1. 程式人生 > >機器學習:決策樹cart演算法在分類與迴歸的應用(上)

機器學習:決策樹cart演算法在分類與迴歸的應用(上)

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

using namespace std;

// Chi-square table used with a confidence level of 0.95.
// NOTE(review): the 16 values appear to be the lower-tail 0.05 quantiles of the chi-square
// distribution for 1..16 degrees of freedom (0.004 for df=1 ... 7.962 for df=16) — confirm
// against a published table before relying on them. The array has 18 slots but only 16
// initializers, so CHI[16] and CHI[17] are zero-initialized.
const double CHI[18] = { 0.004,0.103,0.352,0.711,1.145,1.635,2.167,2.733,3.325,3.94,4.575,5.226,5.892,6.571,7.261,7.962 };
/* Compute a chi-square value from a multi-dimensional (row x col) array. */
// NOTE(review): this region was damaged during extraction. The template parameter list after
// "template" is missing. The vector/static_cast template arguments are missing. The comment
// line containing "observation" has swallowed the rest of cal_chi and the beginning of the
// comparator whose closing braces follow it. Restore from the original source before use.
template
double cal_chi(Comparable **arr, int row, int col) {
	vector rowsum(row);
	vector colsum(col);
	Comparable totalsum = static_cast(0);// explicitly convert 0 to the Comparable type
	//cout<<"observation"<right.first;
		   }
		};
/* The three data structures below record how many records of each class occur under
   each attribute (value). */
// NOTE(review): the template arguments of these typedefs were lost during extraction. Judging
// from statistic()/split(), MAP_REST_COUNT maps a class label to an int count,
// MAP_ATTR_REST maps an attribute value to a MAP_REST_COUNT, and VEC_STATI is a vector of
// MAP_ATTR_REST (one per attribute). Presumably the keys are std::string — confirm against the
// original source.
typedef map MAP_REST_COUNT;
typedef map MAP_ATTR_REST;
typedef vector VEC_STATI;

const int ATTR_NUM = 6;       // number of independent variables (attributes)
// Attribute names, read from the header line of the input file.
// NOTE(review): template argument stripped during extraction (presumably vector<string>).
vector X(ATTR_NUM);
int rest_number;        // number of distinct values of the dependent variable, i.e. number of classes
vector > classes;      // each class label paired with its record count
// Total number of records.
// NOTE(review): overwritten by readInput, readtestInput and pruneprecision. Confirm which data
// set it refers to before calR2/calalpha divide by it.
int total_record_number;
vector > inputData;      // raw (training) input data
vector > testinputData;      // test input data

/*
 * A node of the binary CART decision tree.
 *
 * Internal nodes carry a branching condition `cond` of the form "attribute=value":
 * records whose attribute equals `value` are routed to `leftchild`, all others to
 * `rightchild`. The rest of the file treats a node without a left child as a leaf and
 * creates / removes both children together.
 * A node does not own its children: nothing in this class frees them.
 */
class node {
public:
	node* parent = nullptr;      // parent node; nullptr for the root
	node* leftchild = nullptr;   // left child (records that satisfy `cond`)
	node* rightchild = nullptr;  // right child (records that do not)
	std::string cond;            // branching condition, "attribute=value"
	std::string decision;        // class label predicted at this node
	double precision = 0.0;      // fraction of the records reaching this node whose label equals `decision`
	int record_number = 0;       // number of records covered by this node
	int size = 1;                // number of leaves in the subtree rooted here
	int index = 0;               // sequence number assigned by a level-order traversal
	double alpha = 1.0;          // increase in surface error rate per removed leaf if this subtree is pruned

	node() = default;
	explicit node(node* p) : parent(p) {}
	node(node* p, std::string c, std::string d) : parent(p), cond(c), decision(d) {}

	// Print this node's fields on one line; child indices are printed only for children that exist.
	void printInfo() const {
		std::cout << "index:" << index << "\tdecision:" << decision << "\tprecision:" << precision
			<< "\tcondition:" << cond << "\tsize:" << size;
		if (parent != nullptr)
			std::cout << "\tparent index:" << parent->index;
		// Each child is checked on its own so a node with only one child cannot crash this.
		if (leftchild != nullptr)
			std::cout << "\tleftchild:" << leftchild->index;
		if (rightchild != nullptr)
			std::cout << "\trightchild:" << rightchild->index;
		std::cout << std::endl;
	}

	// Print this node and then, in pre-order, every node of its subtree.
	void printTree() const {
		printInfo();
		if (leftchild != nullptr)
			leftchild->printTree();
		if (rightchild != nullptr)
			rightchild->printTree();
	}
};
/* Read the test data file, parsing it with C++ string streams.
Result: testinputData, the test data set (one row of ATTR_NUM + 2 columns per line;
the last column is counted as the class label).
Returns 0 on success, -1 if the file cannot be opened.
NOTE(review): the two lines containing "i> item;" were damaged during extraction. The loop
condition and the following "strstm >> item;" are fused. The template arguments of map/vector
are missing.
NOTE(review): this also overwrites the globals X (attribute names from the header line) and
total_record_number. The per-class counts in catg are computed but not used in this function.
*/
int readtestInput(string filename) {
	ifstream ifs(filename.c_str());
	if (!ifs) {
		cerr << "open inputfile failed!" << endl;
		return -1;
	}
	map catg;
	string line;
	getline(ifs, line);
	string item;
	istringstream strstm(line);
	strstm >> item;
	for (int i = 0; i> item;
		X[i] = item;
	}
	while (getline(ifs, line)) {
		vector conts(ATTR_NUM + 2);
		istringstream strstm(line);
		//strstm.str(line);
		for (int i = 0; i> item;
			conts[i] = item;
			if (i == conts.size() - 1)
				catg[item]++;
		}
		testinputData.push_back(conts);
	}
	total_record_number = testinputData.size();
	ifs.close();
	return 0;
}
/* Read the training data file, parsing it with C++ string streams.
 Results: inputData   the raw training rows (ATTR_NUM + 2 columns each; the last column is
                       counted as the class label)
		   classes     each class label with its record count (e.g. first: "mammal", second: 6)
		   rest_number the number of distinct classes
 Also stores the attribute names from the header line in X and sets total_record_number.
 Returns 0 on success, -1 if the file cannot be opened.
 NOTE(review): the two lines containing "i> item;" were damaged during extraction. The loop
 condition and the following "strstm >> item;" are fused. The template arguments of
 map/vector are missing.
*/
int readInput(string filename) {
	ifstream ifs(filename.c_str());
	if (!ifs) {
		cerr << "open inputfile failed!" << endl;
		return -1;
	}
	map catg;
	string line;
	getline(ifs, line);
	string item;
	istringstream strstm(line);
	strstm >> item;
	for (int i = 0; i> item;
		X[i] = item;
	}
	while (getline(ifs, line)) {
		vector conts(ATTR_NUM + 2);
		istringstream strstm(line);
		//strstm.str(line);
		for (int i = 0; i> item;
			conts[i] = item;
			if (i == conts.size() - 1)
				catg[item]++;
		}
		inputData.push_back(conts);
	}
	total_record_number = inputData.size();
	ifs.close();
	map::const_iterator itr = catg.begin();// copy the per-class counts from catg into classes
	while (itr != catg.end()) {
		classes.push_back(make_pair(itr->first, itr->second));
		itr++;
	}
	rest_number = classes.size();// number of distinct labels
	return 0;
}

/* Build the statistics table `stati` from inputData: for each attribute, how many records of
   each class occur under each of its values. One MAP_ATTR_REST entry is appended to stati per
   attribute, counting from column 1 of the data. */
// NOTE(review): the first loop line was damaged during extraction. The loop condition, the
// inner loop headers and the lookup producing `itr`/`rest` are fused into
// "isecond).find(rest);". Restore from the original source before use.
void statistic(vector > &inputData, VEC_STATI &stati) {
	for (int i = 1; isecond).find(rest);
				if (iter == (itr->second).end()) {
					(itr->second).insert(make_pair(rest, 1));
				}
				else {
					iter->second += 1;
				}
			}
		}
		stati.push_back(attr_rest);
	}
}

/* When branching on a condition, inputData is divided into two parts (LinputData and RinputData). */
// NOTE(review): the second line below fuses the body of splitInput with the header of the
// next function. Judging from the rest of the text, that header is presumably
// "void printinputData(vector<vector<string> > &inputData)". The loop after it is that debug
// printer: it prints the data column by column, all on one line, followed by a single newline.
// splitInput's own body is lost. Restore from the original source before use.
void splitInput(vector > &inputData, int fitIndex, string cond, vector > &LinputData, vector > &RinputData) {
	for (int i = 0; i > &inputData) {
	for (int i = 0; i < ATTR_NUM + 2; ++i) {
		for (int j = 0; j < inputData.size(); ++j) {
			cout << inputData[j][i] << "\t";
		}
	}cout << endl;
}
// Debug helper: print the statistics table `stati`. For each attribute, one line is printed per
// attribute value, followed by its (class label, count) pairs; a blank line separates attributes.
// NOTE(review): the first loop line was damaged during extraction. The loop header and the
// iterator setup are fused into "ifirst;".
void printStati(VEC_STATI &stati) {
	for (int i = 0; ifirst;
			MAP_REST_COUNT::const_iterator iter = (itr->second).begin();
			while (iter != (itr->second).end()) {
				cout << "\t" << iter->first << "\t" << iter->second;
				iter++;
			}
			itr++;
			cout << endl;
		}
		cout << endl;
	}
}

// Grow the CART tree below `root` from the training rows in `inputData`. `classes` holds each
// class label with its record count in inputData.
// Steps:
//   1. Set root->record_number.
//   2. Build the per-attribute statistics.
//   3. Evaluate candidate binary splits "attribute=value" (records with that value go left,
//      the rest right) by counting each class on both sides.
//   4. Create two children, each deciding the majority class on its side.
//   5. Split inputData accordingly (splitInput). The original comment says the children are
//      then split recursively.
// NOTE(review): heavily damaged by extraction. Template arguments are missing and several
// loop headers are fused with later statements. The split-selection criterion is not visible.
void split(node *root, vector > &inputData, vector > classes) {
	//root->printInfo();
	root->record_number = inputData.size();
	VEC_STATI stati;
	statistic(inputData, stati);
	//printStati(stati);
	//for(int i=0;i > fitleftclasses;//左樹的分類標籤以及個數
	vector > fitrightclasses;// class labels and counts of the chosen right subtree
	int fitleftnumber;// record count of the chosen left subtree
	int fitrightnumber;
	for (int i = 0; ifirst;     // branching condition, i.e. the condition for reaching the left child (attribute)
											   //cout<<"cond 為"< > leftclasses(classes);     //左孩子節點上類別、及對應的數目
			vector > rightclasses(classes);    // classes and their record counts at the right child
			int leftnumber = 0;       // records at the left child (sum of the per-class counts)
			int rightnumber = 0;      // records at the right child (sum of the per-class counts)
			for (int j = 0; jsecond).find(rest);//
				if (iter2 == (itr->second).end()) {      // not found: every record of this class goes to the right subtree
					leftclasses[j].second = 0;
					rightnumber += rightclasses[j].second;
				}
				else {       // found: the right subtree's count for this class is the total minus the left count
					leftclasses[j].second = iter2->second;
					leftnumber += leftclasses[j].second;
					rightclasses[j].second -= (iter2->second);
					rightnumber += rightclasses[j].second;
				}
			}
			// NOTE(review): the block comment opened on the next line lost its intended "*/" during
			// extraction. As extracted, everything up to the end of the "/*...*/" comment about
			// recursion further down is commented out. That includes child creation and the
			// majority vote.
			/**if(leftnumber==0 || rightnumber==0){
			cout<<"左右有一邊為空"<cond<size)++;
		travel = travel->parent;
	}

	node *LChild = new node(root);        //建立左右孩子
	node *RChild = new node(root);
	root->leftchild = LChild;
	root->rightchild = RChild;
	int maxLcount = 0;
	int maxRcount = 0;
	string Ldicision, Rdicision;
	for (int i = 0; imaxLcount) {
			maxLcount = fitleftclasses[i].second;
			Ldicision = fitleftclasses[i].first;
		}
		if (fitrightclasses[i].second>maxRcount) {
			maxRcount = fitrightclasses[i].second;
			Rdicision = fitrightclasses[i].first;
		}
	}
	LChild->decision = Ldicision;
	RChild->decision = Rdicision;
	//LChild->precision = 1.0*maxLcount / fitleftnumber;
	//RChild->precision = 1.0*maxRcount / fitrightnumber;

	/*遞迴對左右孩子進行分裂*/
	vector > LinputData, RinputData;
	splitInput(inputData, fitIndex, fitCond, LinputData, RinputData);
	// NOTE(review): the next comment line has swallowed the end of split() and the header of
	// the following function. The recursive calls below suggest that header was
	// "void pruneprecision(node *root, vector<vector<string> > &testinputData)" — confirm.
	//cout<<"左邊inputData行數:"< > &testinputData) {
	// pruneprecision: route the rows of testinputData through the tree and recompute
	// record_number and precision for the children of every internal node.
	//   - Rows satisfying root->cond go left; all other rows go right.
	//   - precision = fraction of rows whose label (column ATTR_NUM + 1) equals the child's
	//     decision, or 0 when no row reaches the child.
	//   - The node passed in is not updated itself, only its children.
	int i=0;
	int fitIndex;
	// NOTE(review): this overwrites the global with the size of the rows passed to *this* call.
	// The recursive calls below pass subsets, so afterwards total_record_number holds the size
	// of the last subset visited, not the whole test set. calR2/calalpha later divide by it.
	total_record_number = testinputData.size();
	// NOTE(review): both allocations leak. They are overwritten by root->leftchild /
	// root->rightchild two lines later (and also on the early return for leaves).
	node *LChild= new node(root);
	node *RChild= new node(root);
	vector > LinputData, RinputData;
	LChild =root->leftchild;
	RChild = root->rightchild;
	if (root->leftchild == NULL)
		return;
	string cond = root->cond;// branching condition is a string "attribute=value"; the following lines parse it
	string::size_type pos = cond.find("=");
	string pre = cond.substr(0, pos);// pre = the substring before '=' (the attribute name)
	string post = cond.substr(pos + 1);// post = the value after '=' (the branch taken at this node)
	// NOTE(review): the next line was damaged during extraction. The loop that divides
	// testinputData into LinputData / RinputData is fused with the record_number assignment.
	for(int index=0;indexrecord_number = LinputData.size();
	RChild->record_number = RinputData.size();
	//printinputData(LinputData);
	//printinputData(RinputData);
	/* compute the precision of each child */
	for (int j = 0; j < LinputData.size(); ++j) {
		string rest = LinputData[j][ATTR_NUM + 1];// label of this row (left subtree)
		if (rest == LChild->decision)
			i++;
	}
	if (LChild->record_number == 0)
		LChild->precision = 0;
	else
		LChild->precision=1.0*i/LChild->record_number;
	i = 0;
	for (int j = 0; j < RinputData.size(); ++j) {
		string rest = RinputData[j][ATTR_NUM + 1];// label of this row (right subtree)
		if (rest == RChild->decision)
			i++;
	}
	if (RChild->record_number == 0)
		RChild->precision=0;
	else
		RChild->precision = 1.0*i/RChild->record_number;
	if(LChild->leftchild!=NULL)
		pruneprecision(LChild,LinputData);

	if(RChild->leftchild!=NULL)
		pruneprecision(RChild, RinputData);
}
/*
 * Error cost R(T_t) of the subtree rooted at `root`.
 * A leaf (a node without a left child) contributes
 *     (1 - precision) * record_number / total_record_number,
 * i.e. the share of all records that it misclassifies. An internal node costs the sum of
 * its two children's costs.
 * NOTE(review): total_record_number is a global that several functions overwrite — confirm it
 * holds the intended denominator when this is called.
 */
double calR2(node *root) {
	if (root->leftchild != NULL) {
		const double leftCost = calR2(root->leftchild);
		const double rightCost = calR2(root->rightchild);
		return leftCost + rightCost;
	}
	return (1 - root->precision)*root->record_number / total_record_number;
}
/*
 * Number the nodes of the tree in level (breadth-first) order, starting at 1, and store
 * each number in node::index.
 * A node with a left child is treated as internal and both of its children are visited,
 * matching how the rest of this file builds the tree. A null root is ignored.
 */
void index(node *root) {
	if (root == NULL)
		return;
	queue<node*> pending;   // template argument was missing ("queue que;"), which did not compile
	pending.push(root);
	int next = 1;
	while (!pending.empty()) {
		node* current = pending.front();
		pending.pop();
		current->index = next++;
		if (current->leftchild != NULL) {
			pending.push(current->leftchild);
			pending.push(current->rightchild);
		}
	}
}



/*
 * Number the nodes in level order (starting at 1, like index()) and compute alpha for every
 * internal node t. This is the cost-complexity increment
 *     alpha(t) = (R(t) - R(T_t)) / (size(t) - 1)
 * where:
 *   - R(t) = (1 - precision) * record_number / total_record_number is t's error cost if it
 *     were a leaf;
 *   - R(T_t) = calR2(t) is the error cost of t's subtree;
 *   - size(t) is the subtree's leaf count.
 * Each result is stored in node::alpha and pushed onto `pq` as MyTriple(alpha, size, index).
 * A null root is ignored.
 * NOTE(review): the priority_queue template arguments in the signature below appear to have
 * been stripped (presumably priority_queue<MyTriple, vector<MyTriple>, MyCompare>). The line
 * is left exactly as it was so it keeps matching its callers.
 */
void calalpha(node *root, priority_queue, MyCompare> &pq) {
	if (root == NULL)
		return;
	int i = 1;
	queue<node*> que;   // template argument was missing ("queue que;"), which did not compile
	que.push(root);
	while (!que.empty()) {
		node* n = que.front();
		que.pop();
		n->index = i++;
		if (n->leftchild != NULL) {
			que.push(n->leftchild);
			que.push(n->rightchild);
			// Increase in surface error rate if n's subtree were collapsed into n.
			double r1 = (1 - n->precision)*n->record_number / total_record_number;      // error cost of n as a leaf
			double r2 = calR2(n);                                                       // error cost of n's subtree
			n->alpha = (r1 - r2) / (n->size - 1);
			pq.push(MyTriple(n->alpha, n->size, n->index));
		}
	}
}

/*剪枝*/
void prune(node *root, priority_queue, MyCompare> &pq) {
	MyTriple triple = pq.top();
	int i = triple.third;
	queue que;
	que.push(root);
	while (!que.empty()) {
		node* n = que.front();
		que.pop();
		if (n->index == i) {
			cout << "將要剪掉" << i << "的左右子樹" << endl;
			n->leftchild = NULL;
			n->rightchild = NULL;
			int s = n->size - 1;
			node *trav = n;
			while (trav != NULL) {
				trav->size -= s;
				trav = trav->parent;
			}
			break;
		}
		else if (n->leftchild != NULL) {
			que.push(n->leftchild);
			que.push(n->rightchild);
		}
	}
}

// Classify every data line of `filename` with the tree rooted at `root`.
// Output per line: the first column of the line, then the predicted class. When labels > 0,
// the leaf's precision is also printed as the confidence ("置信度").
// Routing: at each internal node whose cond is "attribute=value", the record goes left when
// its attribute equals value, otherwise right. The attribute names come from the global X.
// NOTE(review): the line containing "i> item;" was damaged during extraction. The loop header
// and the following "strstm >> item;" are fused. The map template arguments are missing.
void test(string filename, node *root,int labels) {
	ifstream ifs(filename.c_str());
	if (!ifs) {
		cerr << "open inputfile failed!" << endl;
		return;
	}
	string line;
	getline(ifs, line);
	string item;
	istringstream strstm(line);     // skip the first (header) line; NOTE(review): this stream is never read — it is shadowed inside the loop
	map independent;       // independent variables, i.e. the attributes used for classification (attribute name -> value)
	while (getline(ifs, line)) {
		istringstream strstm(line);
		//strstm.str(line);
		strstm >> item;
		cout << item << "\t";
		for (int i = 0; i> item;
			independent[X[i]] = item;
		}
		node *trav = root;
		while (trav != NULL) {
			if (trav->leftchild == NULL) {
				if (labels >0) {
					cout << (trav->decision) << "\t置信度:" << (trav->precision) << endl;
					break;
				}
				else
					// NOTE(review): no break here. The loop still ends because a leaf's children
					// are NULL, but it first parses the leaf's cond and may insert a key into
					// `independent` through operator[].
					cout << (trav->decision) << endl;
			}
			string cond = trav->cond;// branching condition is a string "attribute=value"; the following lines parse it
			string::size_type pos = cond.find("=");
			string pre = cond.substr(0, pos);// pre = the substring before '=' (the attribute name)
			string post = cond.substr(pos + 1);
			if (independent[pre] == post)
				trav = trav->leftchild;
			else
				trav = trav->rightchild;
		}
	}
	ifs.close();
}

// Driver. Intended sequence:
//   1. Train the tree on "watermelon.txt" and report how many condition checks it needs.
//   2. Classify "testwatermelon.txt".
//   3. Recompute node precisions on the test data and compute alpha for every internal node.
//   4. Prune and classify again with confidences.
// NOTE(review): the return codes of readInput/readtestInput are ignored. system("pause") only
// works on Windows.
int main() {
	string inputFile = "watermelon.txt";
	readInput(inputFile);
	VEC_STATI stati,teststati;        // statistics of the full training set
	statistic(inputData, stati);
	// NOTE(review): the next comment line has swallowed several statements during extraction.
	// Presumably these are the creation of `root`, the call to split(...) and root->printTree().
	// As extracted, `root` is never declared in main.
	//  for(int i=0;iprintTree();
	cout << "剪枝前使用該決策樹最多進行" << root->size - 1 << "次條件判斷" << endl;
	string testFile = "testwatermelon.txt";
	readtestInput(testFile);
	test(testFile, root,0);
	/* prune */
	pruneprecision(root,testinputData);
	//root->printTree();
	priority_queue, MyCompare> pq;
	calalpha(root,pq);
	// NOTE(review): the block comment opened on the next line lost its intended end during
	// extraction. As extracted, it runs to the "*/" after the last commented test(...) call, so
	// the prune call and the second test(testFile, root,1) run are commented out.
	/*//檢驗一個是不是表面誤差增量最小的被剪掉了
	while(!pq.empty()){
	MyTriple triple=pq.top();
	pq.pop();
	cout<size - 1 << "次條件判斷" << endl;
	test(testFile, root,1);
	/*priority_queue pq;
	calalpha(root, pq);
	root->printTree();
	prune(root, pq);
	cout << "剪枝後使用該決策樹最多進行" << root->size - 1 << "次條件判斷" << endl;
	test(testFile, root);*/
	system("pause");
	return 0;
}