1. 程式人生 > CART分類與迴歸樹的原理與實現

CART分類與迴歸樹的原理與實現

// cart.cpp : 定義控制檯應用程式的入口點。
//

#include "stdafx.h"
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<fstream>
#include<iostream>
#include<iterator>
#include<map>
#include<set>
#include<string>
#include<vector>
/*******************************************/
/************author Marshall****************/
/**********date 2015.10.3*******************/
/**************version 1.0******************/
/************copyright reserved*************/
/*******************************************/
using namespace std;



// CART decision-tree learner. Only the classification path (build, prune,
// classify, test) is defined in this file; the regression members are declared
// but have no definition here.
//
// Ownership: the object owns the tree hanging off `root` and the three dataset
// vectors; all of them are released by the destructor. Copying is disabled
// because a member-wise copy would free the same pointers twice.
class cart
{
private:
	vector<int>nums_of_value_each_discreteAttri;//number of distinct values per discrete attribute
	int num_of_continuousAttri;
	int ContinuousAttriNums;//number of continuous attributes in every Record
	int labelNums;//how many kinds of label
	unsigned int CL_max_height;//maximum node height allowed while growing the classification tree
	//double miniumginigain;//not need,we have prune method

	//define the record
	class Record
	{
	public:
		vector<int>discrete_attri;//for each discrete attribute,it's value can be 0,1...increased by 1
		vector<double>continuous_attti;
		int label;//0,1,2...
	};

	//define the node
	struct CartNode
	{
		vector<int>remianDiscreteAttriID;//discrete attributes still available for splitting below this node
		int selectedAttriID;//attribute used by this node's split, -1 while unset
		vector<int>selectedDiscreteAttriValues;//discrete values routed to lnode
		bool isSelectedAttriIDDiscrete;
		double continuousAttriPartitionValue;//threshold: values >= it go to rnode
		int label;//if the record drop in this node,its' label should be
		int height;//current node's height
		vector<int>labelcount;//a counter for the records' label that current node holds
		double alpha;//for nonleaf,for prune
		int record_number;//number of records covered by this node
		CartNode*lnode, *rnode;
		CartNode()
		{
			label = -1;
			selectedAttriID = -1;
			isSelectedAttriIDDiscrete = true;
			lnode = rnode = NULL;
			record_number = 0;
		}
	};
	CartNode*root;


	//double threshold;

private:
	//not copyable: declared but intentionally left undefined
	cart(const cart&);
	cart&operator=(const cart&);

	//calculate gini index,for classify
	double calGiniIndex(vector<int>&subdatasetbyID, const vector<Record>*dataset, CartNode*node = NULL);
	double calSquaredresiduals();//calculate squaredresiduals,for regression
	void CL_split_dataset();
	void RE_split_dataset();
	void CL_trim(const vector<Record>*validationdataset);
	void RE_trim();
	//void make_discrete();
	//returns the label shared by every record of the subset, or -1 when labels differ
	int allthesame(vector<int>&subdatasetbyID, const vector<Record>*dataset);
	/*binary partitions of an attribute's values: 3 values give 3 partitions,
	4 values give 7, 5 values give 15*/
	vector<pair<vector<int>, vector<int>>>make_two_heap(const int kk);
	pair<vector<int>, vector<int>>split_dataset(const int&selectedDiscreteAttriID,
		vector<int>&selected, const vector<int>&subdatasetbyID, const vector<Record>*dataset);
	pair<vector<int>, vector<int>>split_dataset(const int&selectedContiuousAttriID,
		const double partition, const vector<int>&subdatasetbyID, const vector<Record>*dataset);
	CartNode* copytree(CartNode*src, CartNode*dst);//deepcopy of a tree,dst should be NUll
	void copynode(CartNode*src, CartNode*dst);
	void cal_alpha(CartNode*node);
	vector<CartNode*>getLeaf(CartNode*node);
	void destroyTree(CartNode*node);
	int labelNode(CartNode*node);
	void create_root();
	void build_tree_classify(vector<int>&subdatasetbyID,
		CartNode*node, const vector<Record>*dataset);
	void build_tree_regression();
public:
	//Start from a well-defined empty state: create_root() tests root against
	//NULL and the destructor deletes the dataset pointers, so none of them may
	//be left indeterminate.
	cart()
		: num_of_continuousAttri(0), ContinuousAttriNums(0), labelNums(0),
		CL_max_height(0), root(NULL),
		traindataset(NULL), validatedataset(NULL), testdataset(NULL)
	{
	}
	void load_adult_dataset();
	int CART_classify(const Record dataset, CartNode*root = NULL);
	void CART_regression();
	void CART_trian(const vector<Record>*dataset, const vector<Record>*validationdataset);
	//train on the datasets stored in this object
	void CART_trian()
	{
		CART_trian(traindataset, validatedataset);
	}
	void set_paras();
	~cart()
	{
		//the tree may never have been built
		if (root != NULL)
			destroyTree(root);
		//deleting a null pointer is a no-op
		delete traindataset;
		delete validatedataset;
		delete testdataset;
	}
	vector<Record>*traindataset;//as it's name
	vector<Record>*validatedataset;
	vector<Record>*testdataset;
	void test(CartNode*node);
	void test();
};
// Classifies every record of the test set with the tree rooted at `node` and
// prints the fraction of misclassified records to stdout (no trailing newline).
void cart::test(CartNode*node)
{
	int misclassified = 0;
	for (size_t idx = 0; idx < testdataset->size(); ++idx)
	{
		const Record&sample = (*testdataset)[idx];
		if (CART_classify(sample, node) != sample.label)
			++misclassified;
	}
	cout << "測試集上的錯誤率為" << double(misclassified) / testdataset->size();

}

void cart::test()
{
	test(this->root);

}



void cart::set_paras()
{
	CL_max_height = 6;


}
// Grows a classification tree on `dataset`, then prunes it using
// `validationdataset`.
void cart::CART_trian(const vector<Record>*dataset, const vector<Record>*validationdataset)
{
	create_root();
	set_paras();

	// The root starts with every record: ids 0 .. size-1.
	vector<int>allRecordIDs(dataset->size());
	for (size_t pos = 0; pos < allRecordIDs.size(); ++pos)
		allRecordIDs[pos] = static_cast<int>(pos);

	build_tree_classify(allRecordIDs, root, dataset);
	CL_trim(validationdataset);
}


// Frees every node of the tree rooted at `treeroot`. A NULL root is accepted
// and ignored.
//
// The traversal is iterative: each node's children are queued on the work
// stack before the node itself is deleted, so the whole subtree is released
// regardless of its depth.
void cart::destroyTree(CartNode*treeroot)
{
	if (treeroot == NULL)
		return;
	vector<CartNode*>pending;
	pending.push_back(treeroot);
	while (!pending.empty())
	{
		CartNode*node = pending.back();
		pending.pop_back();
		if (node->lnode != NULL)
			pending.push_back(node->lnode);
		if (node->rnode != NULL)
			pending.push_back(node->rnode);
		delete node;
	}
}

// Copies the payload of `src` into `dst` (split description, label, height,
// counters, alpha). The child links of `dst` are left exactly as they were.
void cart::copynode(CartNode*src, CartNode*dst)
{
	_ASSERTE(dst != NULL);
	_ASSERTE(src != NULL);

	// Remember dst's own links, copy the whole node, then put the links back.
	CartNode*const keptLeft = dst->lnode;
	CartNode*const keptRight = dst->rnode;
	*dst = *src;
	dst->lnode = keptLeft;
	dst->rnode = keptRight;
}

// Deep copy of the tree rooted at `Srctreeroot`.
// `Dsttreeroot` must be NULL on entry; the freshly allocated root of the copy
// is returned and is owned by the caller.
cart::CartNode* cart::copytree(CartNode*Srctreeroot, CartNode*Dsttreeroot)
{
	_ASSERTE(Dsttreeroot == NULL);
	_ASSERTE(Srctreeroot != NULL);

	CartNode*clone = new CartNode;
	copynode(Srctreeroot, clone);

	// A leaf has nothing below it to duplicate.
	if (Srctreeroot->lnode == NULL)
	{
		_ASSERTE(Srctreeroot->rnode == NULL);
		return clone;
	}

	// Internal nodes always carry both children; copy each subtree in turn.
	_ASSERTE(Srctreeroot->rnode != NULL);
	clone->rnode = copytree(Srctreeroot->rnode, NULL);
	clone->lnode = copytree(Srctreeroot->lnode, NULL);
	_ASSERTE(clone);
	return clone;
}

int cart::CART_classify(const Record rd, CartNode*treeroot)
{
	if (treeroot == NULL)
		treeroot = this->root;
	CartNode*node = treeroot;
	while (true)
	{
		if (node->lnode == NULL)
		{
			_ASSERTE(node->rnode == NULL);
			return node->label;
		}
		if (node->isSelectedAttriIDDiscrete)
		{
			if (find(node->selectedDiscreteAttriValues.begin(),
				node->selectedDiscreteAttriValues.end(),
				rd.discrete_attri[node->selectedAttriID])
				== node->selectedDiscreteAttriValues.end())
			{
				node = node->rnode;
			}
			else
			{
				node = node->lnode;
			}
		}
		else
		{
			if (rd.continuous_attti[node->selectedAttriID] >= node->continuousAttriPartitionValue)
			{
				node = node->rnode;
			}
			else
			{
				node = node->lnode;
			}
		}
	}
	//should not run here
	_ASSERTE(false);
}


void cart::CL_trim(const vector<Record>*validationdataset)
{
	vector<CartNode*>candidateBestTree;
	CartNode*curretroot = root;
	while (curretroot->lnode != NULL)//&&root->rnode!=NULL
	{
		vector<CartNode*>pool;
		pool.push_back(curretroot);
		double min_alpha = 10000000;
		CartNode*tobecut = NULL;
		while (!pool.empty())
		{
			CartNode*node = pool.back();
			pool.pop_back();
			if (node->lnode != NULL)
			{
				_ASSERTE(node->rnode != NULL);
				cal_alpha(node);
				if (node->alpha < min_alpha)
				{
					min_alpha = node->alpha;
					tobecut = node;
				}
				pool.push_back(node->rnode);
				pool.push_back(node->lnode);
			}
		}
		_ASSERTE(tobecut != NULL);
		//then delete tobecut's child and son node
		vector<CartNode*>alltodel, temppool;
		temppool.push_back(tobecut);
		while (!temppool.empty())
		{
			CartNode*nn = temppool.back();
			temppool.pop_back();
			alltodel.push_back(nn);
			if (nn->lnode != NULL)
			{
				_ASSERTE(nn->rnode != NULL);
				temppool.push_back(nn->lnode);
				temppool.push_back(nn->rnode);
			}
		}
		alltodel.erase(find(alltodel.begin(), alltodel.end(), tobecut));
		for (int i = 0; i < alltodel.size(); i++)
			delete alltodel[i];
		tobecut->lnode = tobecut->rnode = NULL;



		candidateBestTree.push_back(curretroot);
		CartNode*nextroot = NULL;
		nextroot = copytree(curretroot, nextroot);
		_ASSERTE(nextroot);
		curretroot = nextroot;
	}

	//get the best tree
	int minError = validationdataset->size();
	CartNode*besttree = NULL;
	int th = -1;
	vector<int>candidateBestTreeErrorNums;
	for (int i = 0; i < candidateBestTree.size(); i++)
	{
		int errorNum = 0;
		for (int j = 0; j < validationdataset->size(); j++)
		{
			errorNum += CART_classify((*validationdataset)[j],
				candidateBestTree[i]) == (*validationdataset)[j].label ? 0 : 1;
		}
		//error /= (*validationdataset).size();
		candidateBestTreeErrorNums.push_back(errorNum);
		if (errorNum < minError)
		{
			minError = errorNum;
			th = i;
		}
	}

	test(candidateBestTree[th]);

	double SE = sqrt(double(minError*(validationdataset->size() - minError)) / validationdataset->size());
	for (int i = candidateBestTree.size() - 1; i >= 0; i--)
	{
		if (candidateBestTreeErrorNums[i] <= minError + SE)
		{
			besttree = candidateBestTree[i];
			th = i;
			break;
		}
	}
	candidateBestTree.erase(candidateBestTree.begin() + th);
	for (int i = 0; i < candidateBestTree.size(); i++)
		destroyTree(candidateBestTree[i]);
	_ASSERTE(besttree != NULL);
	root = besttree;
	cout << "剪枝後在驗證集上的錯誤為" << (double)candidateBestTreeErrorNums[th] / validationdataset->size() << endl;
}


// Computes the cost-complexity parameter of the internal node `node` and
// stores it in node->alpha:
//
//     alpha = (R(t) - R(T_t)) / (|leaves(T_t)| - 1)
//
// R(t) is the share of training records that would be misclassified if
// `node` were collapsed into a leaf predicting node->label, and R(T_t) is the
// share misclassified by the leaves currently below it. Both terms are error
// shares of the whole training set, so they are directly comparable.
void cart::cal_alpha(CartNode*node)
{
	_ASSERTE(node->lnode != NULL&&node->rnode != NULL);
	_ASSERTE(node->label >= 0);
	const double trainSize = double(traindataset->size());

	// error made by the node itself when acting as a leaf
	const double Rt = double(node->record_number - node->labelcount[node->label]) / trainSize;

	// accumulated error of the subtree's leaves
	double RTt = 0;
	vector<CartNode*>leafpool = getLeaf(node);
	for (size_t i = 0; i < leafpool.size(); i++)
	{
		const CartNode*leaf = leafpool[i];
		RTt += double(leaf->record_number - leaf->labelcount[leaf->label]) / trainSize;
	}
	node->alpha = (Rt - RTt) / (leafpool.size() - 1);
}


// Returns all leaves of the subtree rooted at `node`, collected by an
// iterative depth-first walk that visits the right child before the left one.
// Every node is visited once, so no duplicate check is needed.
vector<cart::CartNode*>cart::getLeaf(CartNode*node)
{
	vector<CartNode*>leafpool, que;
	que.push_back(node);
	while (!que.empty())
	{
		CartNode*nn = que.back();
		que.pop_back();
		if (nn->lnode == NULL && nn->rnode == NULL)
		{
			leafpool.push_back(nn);
			continue;
		}
		// internal nodes are expected to have both children
		_ASSERTE(nn->lnode != NULL && nn->rnode != NULL);
		if (nn->lnode != NULL)
			que.push_back(nn->lnode);
		if (nn->rnode != NULL)
			que.push_back(nn->rnode);
	}
	return leafpool;
}


pair<vector<int>, vector<int>>cart::split_dataset(const int&selectedDiscreteAttriID,
	vector<int>&selected, const vector<int>&subdatasetbyID, const vector<Record>*dataset)
{
	vector<int>aa, bb;
	for (int i = 0; i < subdatasetbyID.size(); i++)
	{
		if (find(selected.begin(), selected.end(), (*dataset)[subdatasetbyID[i]].
			discrete_attri[selectedDiscreteAttriID]) == selected.end())
		{
			bb.push_back(subdatasetbyID[i]);
		}
		else
			aa.push_back(subdatasetbyID[i]);
	}
	return pair<vector<int>, vector<int>>(aa, bb);
}

pair<vector<int>, vector<int>>cart::split_dataset(const int&selectedContiuousAttriID,
	const double partition, const vector<int>&subdatasetbyID, const vector<Record>*dataset)
{
	vector<int>aa, bb;
	for (int i = 0; i < subdatasetbyID.size(); i++)
	{
		if ((*dataset)[subdatasetbyID[i]].continuous_attti[selectedContiuousAttriID] >= partition)
		{
			bb.push_back(subdatasetbyID[i]);
		}
		else
			aa.push_back(subdatasetbyID[i]);
	}
	return pair<vector<int>, vector<int>>(aa, bb);

}
// Result buffer shared by select() and Combination(): after Combination()
// returns it holds every generated combination, each as an ordered set.
set<set<int>>solu;

// Recursive worker of Combination(). Extends `selected` with elements of
// `remain` until it holds `toselect` elements and stores each completed set
// in `solu`.
//
// After picking remain[i] only the elements behind position i stay available,
// so each combination is produced by exactly one call chain instead of once
// per ordering of its elements.
void select(set<int>&selected, vector<int>&remain, int toselect)
{
	if (selected.size() == static_cast<size_t>(toselect))
	{
		solu.insert(selected);
		return;
	}
	for (size_t i = 0; i < remain.size(); i++)
	{
		set<int>se = selected;
		se.insert(remain[i]);
		vector<int>re(remain.begin() + i + 1, remain.end());
		select(se, re, toselect);
	}
}

// Fills the global `solu` with all combinations of `toselect` elements taken
// from `remain`; any previous content of `solu` is discarded. Asking for more
// elements than `remain` offers leaves `solu` empty; asking for zero yields
// the single empty set.
void Combination(vector<int>remain, int toselect)
{
	solu.clear();
	set<int>selected;
	select(selected, remain, toselect);
}

// Enumerates the binary partitions of the value set {0, ..., len-1} of the
// discrete attribute `kk`, where len = nums_of_value_each_discreteAttri[kk].
//
// Each returned pair is (.first = values outside the subset, .second = the
// subset), with the subset holding between 1 and len/2 values, so both sides
// are non-empty. Every partition appears once: 3 values give 3 partitions,
// 4 values give 7, 5 values give 15.
vector<pair<vector<int>, vector<int>>>cart::make_two_heap(const int kk)
{
	vector<pair<vector<int>, vector<int>>>toret;
	const int len = nums_of_value_each_discreteAttri[kk];
	vector<int>remain;
	for (int i = 0; i < len; i++)
		remain.push_back(i);

	// all subsets with 1 .. len/2 elements
	set<set<int>>re;
	for (int i = 1; i < len / 2 + 1; i++)
	{
		Combination(vector<int>(remain), i);
		re.insert(solu.begin(), solu.end());
	}

	for (set<set<int>>::iterator it = re.begin(); it != re.end(); it++)
	{
		// With an even number of values a half-sized subset and its
		// complement describe the same partition; keep only the subset that
		// contains value 0.
		if (2 * static_cast<int>(it->size()) == len && it->count(0) == 0)
			continue;

		// complement = all values minus the subset (both ranges are sorted)
		vector<int>aa, bb;
		set_difference(remain.begin(), remain.end(),
			it->begin(), it->end(), back_inserter(aa));
		bb.assign(it->begin(), it->end());

		toret.push_back(pair<vector<int>, vector<int>>(aa, bb));
	}
	return toret;
}

// Allocates the root node unless one already exists. A new root has height 1
// and may split on every discrete attribute (ids 0 .. count-1).
// NOTE(review): relies on `root` comparing equal to NULL before the first
// call - confirm that it is initialised before use.
void cart::create_root()
{
	if (root != NULL)
		return;

	root = new CartNode;
	root->height = 1;
	const int attriCount = static_cast<int>(nums_of_value_each_discreteAttri.size());
	for (int id = 0; id < attriCount; ++id)
		root->remianDiscreteAttriID.push_back(id);
}

// Returns the label shared by every record referenced in `subdatasetbyID`,
// or -1 when the records carry different labels or the subset is empty.
int cart::allthesame(vector<int>&subdatasetbyID, const vector<Record>*dataset)
{
	if (subdatasetbyID.empty())
		return -1;
	const int label = ((*dataset)[subdatasetbyID[0]]).label;
	for (size_t i = 1; i < subdatasetbyID.size(); i++)
		if (((*dataset)[subdatasetbyID[i]]).label != label)
			return -1;
	return label;
}

//build classify tree recursively
void cart::build_tree_classify(vector<int>&subdatasetbyID,
	CartNode*node, const vector<Record>*dataset)
{
	node->record_number = subdatasetbyID.size();
	double basegini = calGiniIndex(subdatasetbyID, dataset, node);
	int currentlabel = allthesame(subdatasetbyID, dataset);
	if (currentlabel >= 0)
	{
		node->label = currentlabel;
		return;
	}
	if (node->height >= CL_max_height)
	{
		node->label = labelNode(node);
		return;
	}
	node->label = labelNode(node);
	double mingini = 10000000000;
	int selected = -1;
	bool isSelectedDiscrete = true;
	vector<int>selectedDiscreteAttriValues;
	pair<vector<int>, vector<int>>splited_subdataset;
	bool lnodeDecreaseDiscreteAttri = false;//is node's lnode's discrete attribute nums decrease
	bool rnodeDecreaseDiscreteAttri = false;



	//for discrete features,calculate giniindex
	for (int i = 0; i < node->remianDiscreteAttriID.size(); i++)
	{
		if (nums_of_value_each_discreteAttri[node->remianDiscreteAttriID[i]] > 2)
		{
			vector<pair<vector<int>, vector<int>>>bipart = make_two_heap(node->remianDiscreteAttriID[i]);
			for (int j = 0; j < bipart.size(); j++)
			{
				pair<vector<int>, vector<int>>two_subdataset = split_dataset(
					node->remianDiscreteAttriID[i], bipart[i].first, subdatasetbyID, dataset);
				if (two_subdataset.first.size() > 0 && two_subdataset.second.size() > 0)
				{
					double gini1 = calGiniIndex(two_subdataset.first, dataset);
					double gini2 = calGiniIndex(two_subdataset.second, dataset);
					double gini = double(two_subdataset.first.size()) / subdatasetbyID.size()*gini1
						+ double(two_subdataset.second.size()) / subdatasetbyID.size()*gini2;
					if (gini < mingini)
					{
						if (bipart[i].first.size() == 1)
							lnodeDecreaseDiscreteAttri = true;
						else
							lnodeDecreaseDiscreteAttri = false;
						if (bipart[i].second.size() == 1)
							rnodeDecreaseDiscreteAttri = true;
						else
							rnodeDecreaseDiscreteAttri = false;
						mingini = gini;
						selected = node->remianDiscreteAttriID[i];
						splited_subdataset = two_subdataset;
						selectedDiscreteAttriValues = bipart[i].first;
					}
				}
			}
		}
		else
		{
			vector<int>aa;
			aa.push_back(0);
			pair<vector<int>, vector<int>>two_subdataset = split_dataset(node->remianDiscreteAttriID[i],
				aa, subdatasetbyID, dataset);
			if (two_subdataset.first.size() > 0 && two_subdataset.second.size() > 0)
			{
				double gini1 = calGiniIndex(two_subdataset.first, dataset);
				double gini2 = calGiniIndex(two_subdataset.second, dataset);
				double gini = double(two_subdataset.first.size()) / subdatasetbyID.size()*gini1
					+ double(two_subdataset.second.size()) / subdatasetbyID.size()*gini2;
				if (gini < mingini)
				{
					mingini = gini;
					selected = node->remianDiscreteAttriID[i];
					splited_subdataset = two_subdataset;
					lnodeDecreaseDiscreteAttri = true;
					rnodeDecreaseDiscreteAttri = true;
					selectedDiscreteAttriValues.clear();
					selectedDiscreteAttriValues.push_back(0);
				}
			}
		}
	}
	// 利用函式物件實現升降排序    
	struct CompNameEx{
		CompNameEx(bool asce, int k, const vector<Record>*dataset) : asce_(asce), kk(k), dataset(dataset)
		{}
		bool operator()(int const& pl, int const& pr)
		{
			return asce_ ? (*dataset)[pl].continuous_attti[kk] < (*dataset)[pr].continuous_attti[kk]
				: (*dataset)[pr].continuous_attti[kk] < (*dataset)[pl].continuous_attti[kk];
			// 《Eff STL》條款21: 永遠讓比較函式對相等的值返回false    
		}
	private:
		bool asce_;
		int kk;
		const vector<Record>*dataset;
	};

	//for continuous features,calculate giniindex
	double partitionpoint;
	for (int i = 0; i < ContinuousAttriNums; i++)
	{
		sort(subdatasetbyID.begin(), subdatasetbyID.end(),
			CompNameEx(true, i, dataset));
		for (int j = 0; j < subdatasetbyID.size() - 1; j++)
		{
			double partition = 0.5*(*dataset)[subdatasetbyID[j]].continuous_attti[i] +
				0.5*(*dataset)[subdatasetbyID[j + 1]].continuous_attti[i];
			pair<vector<int>, vector<int>>two_subdataset =
				split_dataset(i, partition, subdatasetbyID, dataset);
			if (two_subdataset.first.size() > 0 && two_subdataset.second.size() > 0)
			{
				double gini1 = calGiniIndex(two_subdataset.first, dataset);
				double gini2 = calGiniIndex(two_subdataset.second, dataset);
				double gini = double(two_subdataset.first.size()) / subdatasetbyID.size()*gini1
					+ double(two_subdataset.second.size()) / subdatasetbyID.size()*gini2 + log(double(subdatasetbyID.size() - 2) / dataset->size()) / log(2.0);
				if (gini < mingini)
				{
					partitionpoint = partition;
					mingini = gini;
					selected = i;
					isSelectedDiscrete = false;
					splited_subdataset = two_subdataset;
				}
			}
		}
	}

	//we have prune,so regardless of ginigain
	//double ginigain = basegini - mingini;//if not greater than miniumginigain;current node should not grow 

	if (splited_subdataset.first.size() > 0 && splited_subdataset.second.size() > 0)//&&ginigain>miniumginigain)
	{
		CartNode*lchild = new CartNode;
		CartNode*rchild = new CartNode;
		node->lnode = lchild;
		node->rnode = rchild;
		lchild->height = node->height + 1;
		rchild->height = node->height + 1;
		lchild->remianDiscreteAttriID = node->remianDiscreteAttriID;
		rchild->remianDiscreteAttriID = node->remianDiscreteAttriID;
		node->selectedAttriID = selected;
		if (isSelectedDiscrete)
		{
			if (lnodeDecreaseDiscreteAttri)
			{
				lchild->remianDiscreteAttriID.erase(find(lchild->
					remianDiscreteAttriID.begin(), lchild->remianDiscreteAttriID.end(), selected));
			}
			if (rnodeDecreaseDiscreteAttri)
			{
				rchild->remianDiscreteAttriID.erase(find(rchild->
					remianDiscreteAttriID.begin(), rchild->remianDiscreteAttriID.end(), selected));
			}
			node->selectedDiscreteAttriValues = selectedDiscreteAttriValues;
		}
		else
		{
			node->isSelectedAttriIDDiscrete = false;
			node->continuousAttriPartitionValue = partitionpoint;
		}
		//recursively call 	build_tree_classify()
		build_tree_classify(splited_subdataset.first, lchild, dataset);

		build_tree_classify(splited_subdataset.second, rchild, dataset);
	}
}


double cart::calGiniIndex(vector<int>&subdatasetbyID, const vector<Record>*dataset, CartNode*node)
{
	_ASSERTE(subdatasetbyID.size() > 0);
	_ASSERTE(dataset != NULL);
	vector<int>count;
	count.resize(labelNums);
	for (int i = 0; i < subdatasetbyID.size(); i++)
	{
		count[((*dataset)[subdatasetbyID[i]]).label]++;
	}
	if (node != NULL)
	{
		node->labelcount = count;
		node->record_number = subdatasetbyID.size();
	}
	vector<double> probalblity;
	probalblity.resize(labelNums);
	double re = 1;
	for (int i = 0; i < labelNums; i++)
	{
		probalblity[i] = double(count[i]) / subdatasetbyID.size();
		re -= pow(probalblity[i], 2);
	}
	_ASSERTE(re >= 0);
	return re;
}

// Picks the label for `node`: the class whose share among the node's records
// is largest relative to its share among the root's records.
int cart::labelNode(CartNode*node)
{
	int best = -1;
	double bestRatio = 0;
	for (int cls = 0; cls < labelNums; ++cls)
	{
		const double nodeShare = double(node->labelcount[cls]) / node->record_number;
		const double rootShare = double(root->labelcount[cls]) / root->record_number;
		const double ratio = nodeShare / rootShare;
		if (ratio > bestRatio)
		{
			bestRatio = ratio;
			best = cls;
		}
	}
	_ASSERTE(best >= 0);
	return best;
}






// Splits `str` at every occurrence of the separator string `sep` and appends
// the non-empty pieces to `ret_` (existing content of `ret_` is kept).
//
// Leading characters that belong to `sep` are skipped before the first token.
// An empty `sep` cannot split anything, so the whole string is appended as a
// single token. Always returns 0.
int split(const std::string& str, std::vector<std::string>& ret_, std::string sep = ",")
{
	if (str.empty())
	{
		return 0;
	}
	if (sep.empty())
	{
		// find("") matches at every position without advancing, so this case
		// has to be handled up front.
		ret_.push_back(str);
		return 0;
	}

	std::string::size_type pos_begin = str.find_first_not_of(sep);
	while (pos_begin != std::string::npos)
	{
		const std::string::size_type sep_pos = str.find(sep, pos_begin);
		const bool last = (sep_pos == std::string::npos);
		const std::string token = last ? str.substr(pos_begin)
			: str.substr(pos_begin, sep_pos - pos_begin);
		if (!token.empty())
		{
			ret_.push_back(token);
		}
		pos_begin = last ? std::string::npos : sep_pos + sep.length();
	}
	return 0;
}





//說明,因為education,workclass,marital-status,occupation,native country屬性太多,不作考慮
void cart::load_adult_dataset()
{
	vector<Record>*traindataset;//as it's name
	vector<Record>*validatedataset;
	string filename = "adult.data";
	ifstream infile(filename.c_str());
	string temp;
	cout << endl;
	int count = 0;
	//vector<vector<std::string>>ss;
	traindataset = new vector < Record > ;
	validatedataset = new vector < Record > ;
	this->traindataset = traindataset;
	this->validatedataset = validatedataset;
	testdataset = new vector < Record > ;
	//Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked

	/*map<string, int>workclass;
	workclass["Private"] = 0;
	workclass["Self-emp-not-inc"] = 1;
	workclass["Self-emp-inc"] = 2;
	workclass["Federal-gov"] = 3;
	workclass["Local-gov"] = 4;
	workclass["State-gov"] = 5;
	workclass["Without-pay"] = 6;
	workclass["Never-worked"] = 7;*/

	//education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th,
	// 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.

	/*map<string, int>education;
	education["Bachelors"] = 0;
	education["Some-college"] = 1;
	education["11th"] = 2;
	education["HS-grad"] = 3;
	education["Prof-school"] = 4;
	education["Assoc-acdm"] = 5;
	education["Assoc-voc"] = 6;
	education["9th"] = 7;
	education["7th-8th"] = 8;
	education["12th"] = 9;
	education["Masters"] = 10;
	education["1st-4th"] = 11;
	education["10th"] = 12;
	education["Doctorate"] = 13;
	education["5th-6th"] = 14;
	education["Preschool"] = 15;
	*/
	//marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed,
	// Married-spouse-absent, Married-AF-spouse.
	/*map<string, int>marital_status;
	marital_status["Married-civ-spouse"] = 0;
	marital_status["Divorced"] = 1;
	marital_status["Never-married"] = 2;
	marital_status["Separated"] = 3;
	marital_status["Widowed"] = 4;
	marital_status["Married-spouse-absent"] = 5;
	marital_status["Married-AF-spouse"] = 6;*/

	//occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, 
	//Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing,
	// Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
	/*map<string, int>occupation;
	occupation["Tech-support"] = 0;
	occupation["Craft-repair"] = 1;
	occupation["Other-service"] = 2;
	occupation["Sales"] = 3;
	occupation["Exec-managerial"] = 4;
	occupation["Prof-specialty"] = 5;
	occupation["Handlers-cleaners"] = 6;
	occupation["Machine-op-inspct"] = 7;
	occupation["Adm-clerical"] = 8;
	occupation["Farming-fishing"] = 9;
	occupation["Transport-moving"] = 10;
	occupation["Priv-house-serv"] = 11;
	occupation["Protective-serv"] = 12;
	occupation["Armed-Forces"] = 13;
	*/

	//relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.

	map<string, int>relationship;
	relationship["Wife"] = 0;
	relationship["Own-child"] = 1;
	relationship["Husband"] = 2;
	relationship["Not-in-family"] = 3;
	relationship["Other-relative"] = 4;
	relationship["Unmarried"] = 5;

	//race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.

	map<string, int>race;
	race["White"] = 0;
	race["Asian-Pac-Islander"] = 1;
	race["Amer-Indian-Eskimo"] = 2;
	race["Other"] = 3;
	race["Black"] = 4;

	//sex: Female, Male.
	map<string, int>sex;
	sex["Female"] = 0;
	sex["Male"] = 1;

	//native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, 
	//Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran,
	// Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, 
	//Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia,
	// Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, 
	//Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
	map<string, int>label;
	label["<=50K"] = 0;
	label[">50K"] = 1;


	while (getline(infile, temp) && count < 7000)
	{

		Record rd;
		rd.continuous_attti.resize(6);
		rd.discrete_attri.resize(3);
		//cout << temp << endl;

		std::vector<std::string>re;
		split(temp, re, std::string(", "));
		bool desert = false;
		if (re.size() == 15)
		{

			/*age: continuous.
			workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
			fnlwgt: continuous.
			education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
			education-num: continuous.
			marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
			occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
			relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
			race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
			sex: Female, Male.
			capital-gain: continuous.
			capital-loss: continuous.
			hours-per-week: continuous.
			native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.*/


			//age continuous
			rd.continuous_attti[0] = atoi(re[0].c_str());

			//workclass discrete
			/*if (workclass.find(re[1]) != workclass.end())
				rd.discrete_attri[0] = workclass[re[1]];
				else
				desert=true;*/

			//fnlwgt: continuous
			rd.continuous_attti[1] = atoi(re[2].c_str());

			//education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
			/*if (education.find(re[3]) != education.end())
				rd.discrete_attri[1] = education[re[3]];
				else
				desert=true;*/

			//education-num: continuous.
			rd.continuous_attti[2] = atoi(re[4].c_str());

			//marital-status
			/*if (marital_status.find(re[5]) != marital_status.end())
				rd.discrete_attri[1] = marital_status[re[5]];
				else
				desert=true;*/

			//relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
			if (relationship.find(re[7]) != relationship.end())
				rd.discrete_attri[0] = relationship[re[7]];
			else
				desert = true;

			//race
			if (race.find(re[8]) != race.end())
				rd.discrete_attri[1] = race[re[8]];
			else
				desert = true;

			//sex
			if (sex.find(re[9]) != sex.end())
				rd.discrete_attri[2] = sex[re[9]];
			else
				desert = true;

			//capital-gain: continuous.
			rd.continuous_attti[3] = atoi(re[10].c_str());

			//capital-loss: continuous.

			rd.continuous_attti[4] = atoi(re[11].c_str());
			//hours-per-week: continuous
			rd.continuous_attti[5] = atoi(re[12].c_str());

			if (label.find(re[14]) != label.end())
				rd.label = label[re[14]];
			else
				desert = true;
			if (!desert)
				if (count < 3500)
				{
					traindataset->push_back(rd);
				}
				else if (count < 4500)
				{
					validatedataset->push_back(rd);
				}
				else
					testdataset->push_back(rd);
		}
		count++;
	}
	ContinuousAttriNums = 6;
	labelNums = 2;
	int aa[3] = { 6, 5, 2 };
	nums_of_value_each_discreteAttri.push_back(6);
	nums_of_value_each_discreteAttri.push_back(5);
	nums_of_value_each_discreteAttri.push_back(2);


}

int _tmain(int argc, _TCHAR* argv[])
{

	cart cart;
	cart.load_adult_dataset();
	cart.CART_trian();
	cart.test();
	system("pause");
	return 0;
}

可能不太完善,大體框架是這樣了,具體細節可能處理不好。歡迎大家指點。