機器學習:決策樹 CART 演算法在分類與迴歸的應用(上)
阿新 • 發佈:2019-02-07
// =====================================================================================
// NOTE(review): C++ listing of a CART-style decision tree (build, test, prune).
// The text stored here is DAMAGED and will not compile as-is:
//   * The whole program is flattened onto six physical lines, so every `//` comment
//     now runs to the end of its line and swallows the code that follows it.
//   * Spans that began with `<` and ended with `>` appear to have been stripped:
//     the `#include` targets, every template argument list (`template double cal_chi`,
//     `typedef map MAP_REST_COUNT`, `vector > inputData`, `queue que`,
//     `priority_queue, MyCompare>`), and parts of loop headers/bodies
//     (`for (int i = 0; i> item; ...`, `for(int index=0;indexrecord_number = ...`).
// Restore the code from the original source before building; the gaps should not be
// filled in by guesswork. The code text below is kept exactly as found.
//
// What the visible text establishes, definition by definition:
//   CHI              Per its comment, a chi-square table for confidence level 0.95.
//                    Declared with 18 elements but only 16 initializers are listed.
//   cal_chi          Per its comment, computes a chi-square value from a 2-D array;
//                    most of the body is missing.
//   MyTriple/MyCompare  Only the tail `right.first; } };` survives. They are used
//                    further down as the element type and comparator of a
//                    priority_queue (calalpha, prune, main).
//   MAP_REST_COUNT, MAP_ATTR_REST, VEC_STATI
//                    Per the comment: hold, for each attribute value, the number of
//                    records of each class. Template arguments are missing.
//   Globals          ATTR_NUM = 6 (number of attributes); X (header tokens, filled by
//                    the read functions); rest_number (number of classes); classes
//                    (class label / record count pairs); total_record_number;
//                    inputData (training rows); testinputData (test rows).
//   class node       Tree node: parent/leftchild/rightchild links, cond (split
//                    condition string), decision (class chosen at the node),
//                    precision, record_number, size, index, alpha. All constructors
//                    set the child pointers to NULL, precision 0.0, record_number 0,
//                    size 1, index 0, alpha 1.0. printInfo() reads rightchild->index
//                    whenever leftchild is non-NULL, i.e. it assumes both children are
//                    set together. Its output label is spelled "decisoin"; that is a
//                    runtime string and is left unchanged. printTree() prints the node
//                    and then recurses left, then right.
//   readtestInput    Opens `filename`; prints an error and returns -1 on failure.
//                    Reads the header line, then each remaining line into a vector of
//                    ATTR_NUM + 2 strings appended to testinputData, and sets
//                    total_record_number. Loop headers are truncated - presumably
//                    `i < ATTR_NUM` and `i < conts.size()`; verify against the original.
//   readInput        Same as readtestInput but fills inputData, then copies the label
//                    counts from `catg` into `classes` and sets rest_number.
//   statistic        Builds one attribute-value -> class -> count map per attribute
//                    and appends it to `stati`; the outer part of the loop is missing.
//   splitInput       Splits inputData into LinputData / RinputData for a condition;
//                    the body is missing.
//   printinputData   Prints inputData[j][i] for i in [0, ATTR_NUM + 2) and every row j,
//                    tab-separated, followed by one newline at the very end.
//   printStati       Dumps a VEC_STATI; the loop header is partly missing.
//   split            Recursively grows the tree: records root->record_number, calls
//                    statistic(), evaluates candidate conditions (selection logic is
//                    missing), increments `size` on root and all ancestors, creates the
//                    two children with `new node(root)`, gives each child the majority
//                    class as `decision`, and calls splitInput(). The tail is missing.
//   pruneprecision   (Name/signature partly missing; called from main as
//                    pruneprecision(root, testinputData).) Sets record_number and
//                    precision on root's children from the test rows and recurses.
//                    NOTE(review): the two `new node(root)` results are overwritten by
//                    root->leftchild / root->rightchild right away and never freed.
//                    NOTE(review): total_record_number is overwritten with the size of
//                    the subset passed to each recursive call - confirm this is intended,
//                    since calR2/calalpha divide by it afterwards.
//   calR2            Leaf: (1 - precision) * record_number / total_record_number.
//                    Internal node: sum of calR2 over both children.
//   index            Level-order traversal assigning index 1, 2, 3, ...
//   calalpha         Level-order traversal that re-assigns index and, for every
//                    internal node, sets alpha = (r1 - r2) / (size - 1) and pushes
//                    (alpha, size, index) onto pq.
//   prune            Takes pq.top(), finds the node whose index equals the triple's
//                    third field, sets both child pointers to NULL (no delete is
//                    visible) and subtracts (size - 1) from that node and all ancestors.
//   test             Skips the header line of `filename`; for each row prints the first
//                    token, loads attribute values into a map keyed by X[i], then walks
//                    the tree comparing independent[pre] with post from "pre=post".
//                    At a leaf with labels > 0 it prints decision plus precision and
//                    breaks. NOTE(review): with labels <= 0 there is no break; the loop
//                    goes on to parse the leaf's own cond - confirm this is intended.
//   main             Reads "watermelon.txt", builds and prints the tree (middle part
//                    missing), reads "testwatermelon.txt", runs test(..., 0),
//                    pruneprecision and calalpha, then (after another gap) test(..., 1)
//                    and system("pause").
// =====================================================================================
#include #include #include #include #include #include #include #include #include #include #include using namespace std; //置信水平取0.95時的卡方表 const double CHI[18] = { 0.004,0.103,0.352,0.711,1.145,1.635,2.167,2.733,3.325,3.94,4.575,5.226,5.892,6.571,7.261,7.962 }; /*根據多維陣列計算卡方值*/ template double cal_chi(Comparable **arr, int row, int col) { vector rowsum(row); vector colsum(col); Comparable totalsum = static_cast(0);//強制將0轉換為Comparable型 //cout<<"observation"<right.first; } }; /* 下面這三個資料結構是來存在在哪種屬性下的某一類的個數*/ typedef map MAP_REST_COUNT; typedef map MAP_ATTR_REST; typedef vector VEC_STATI; const int ATTR_NUM = 6; //自變數的維度 vector X(ATTR_NUM); int rest_number; //因變數的種類數,即類別數 vector > classes; //把類別、對應的記錄數存放在一個數組中 int total_record_number; //總的記錄數 vector > inputData; //原始輸入資料 vector > testinputData; //測試輸入資料 class node { public: node* parent; //父節點 node* leftchild; //左孩子節點 node* rightchild; //右孩子節點 string cond; //分枝條件 string decision; //在該節點上作出的類別判定 double precision; //判定的正確率 int record_number; //該節點上涵蓋的記錄個數 int size; //子樹包含的葉子節點的數目 int index; //層次遍歷樹,給節點標上序號 double alpha; //表面誤差率的增加量 node() { parent = NULL; leftchild = NULL; rightchild = NULL; precision = 0.0; record_number = 0; size = 1; index = 0; alpha = 1.0; } node(node* p) { parent = p; leftchild = NULL; rightchild = NULL; precision = 0.0; record_number = 0; size = 1; index = 0; alpha = 1.0; } node(node* p, string c, string d) :cond(c), decision(d) { parent = p; leftchild = NULL; rightchild = NULL; precision = 0.0; record_number = 0; size = 1; index = 0; alpha = 1.0; } void printInfo() { cout << "index:" << index << "\tdecisoin:" << decision << "\tprecision:" << precision << "\tcondition:" << cond << "\tsize:" << size; if (parent != NULL) cout << "\tparent index:" << parent->index; if (leftchild != NULL) cout << "\tleftchild:" << leftchild->index << "\trightchild:" << rightchild->index; cout << endl; } void printTree() { printInfo(); if (leftchild != NULL) leftchild->printTree(); if (rightchild != NULL) 
rightchild->printTree(); } }; /* 讀取測試檔案資料,採取的是c++字串流的讀取方式 得到結果:testinputData 資料來源 */ int readtestInput(string filename) { ifstream ifs(filename.c_str()); if (!ifs) { cerr << "open inputfile failed!" << endl; return -1; } map catg; string line; getline(ifs, line); string item; istringstream strstm(line); strstm >> item; for (int i = 0; i> item; X[i] = item; } while (getline(ifs, line)) { vector conts(ATTR_NUM + 2); istringstream strstm(line); //strstm.str(line); for (int i = 0; i> item; conts[i] = item; if (i == conts.size() - 1) catg[item]++; } testinputData.push_back(conts); } total_record_number = testinputData.size(); ifs.close(); return 0; } /* 讀取檔案資料,採取的是c++字串流的讀取方式 得到結果:inputData 資料來源 classes 分類標籤以及個數(first:哺乳類,second:6) rest_number 分類的種類數 */ int readInput(string filename) { ifstream ifs(filename.c_str()); if (!ifs) { cerr << "open inputfile failed!" << endl; return -1; } map catg; string line; getline(ifs, line); string item; istringstream strstm(line); strstm >> item; for (int i = 0; i> item; X[i] = item; } while (getline(ifs, line)) { vector conts(ATTR_NUM + 2); istringstream strstm(line); //strstm.str(line); for (int i = 0; i> item; conts[i] = item; if (i == conts.size() - 1) catg[item]++; } inputData.push_back(conts); } total_record_number = inputData.size(); ifs.close(); map::const_iterator itr = catg.begin();//將catg歸類結果放入classes中 while (itr != catg.end()) { classes.push_back(make_pair(itr->first, itr->second)); itr++; } rest_number = classes.size();//標籤分為幾類 return 0; } /*根據inputData作出一個統計stati,統計的是在哪種屬性下的某類的個數。*/ void statistic(vector > &inputData, VEC_STATI &stati) { for (int i = 1; isecond).find(rest); if (iter == (itr->second).end()) { (itr->second).insert(make_pair(rest, 1)); } else { iter->second += 1; } } } stati.push_back(attr_rest); } } /*依據某條件作出分枝時,inputData被分成兩部分*/ void splitInput(vector > &inputData, int fitIndex, string cond, vector > &LinputData, vector > &RinputData) { for (int i = 0; i > &inputData) { for (int i = 0; i < ATTR_NUM + 2; 
++i) { for (int j = 0; j < inputData.size(); ++j) { cout << inputData[j][i] << "\t"; } }cout << endl; } void printStati(VEC_STATI &stati) { for (int i = 0; ifirst; MAP_REST_COUNT::const_iterator iter = (itr->second).begin(); while (iter != (itr->second).end()) { cout << "\t" << iter->first << "\t" << iter->second; iter++; } itr++; cout << endl; } cout << endl; } } void split(node *root, vector > &inputData, vector > classes) { //root->printInfo(); root->record_number = inputData.size(); VEC_STATI stati; statistic(inputData, stati); //printStati(stati); //for(int i=0;i > fitleftclasses;//左樹的分類標籤以及個數 vector > fitrightclasses;//右樹的分類標籤以及個數 int fitleftnumber;//左樹記錄數 int fitrightnumber; for (int i = 0; ifirst; //判定的條件,即到達左孩子的條件,屬性 //cout<<"cond 為"< > leftclasses(classes); //左孩子節點上類別、及對應的數目 vector > rightclasses(classes); //右孩子節點上類別、及對應的數目 int leftnumber = 0; //左孩子節點上包含的類別數目 int rightnumber = 0; //右孩子節點上包含的類別數目 for (int j = 0; jsecond).find(rest);// if (iter2 == (itr->second).end()) { //沒找到,則對應類別以及類別樹就全部在右樹 leftclasses[j].second = 0; rightnumber += rightclasses[j].second; } else { //找到,則右邊樹對應的種類以及個數就是總體的減去左邊的種類數 leftclasses[j].second = iter2->second; leftnumber += leftclasses[j].second; rightclasses[j].second -= (iter2->second); rightnumber += rightclasses[j].second; } } /**if(leftnumber==0 || rightnumber==0){ cout<<"左右有一邊為空"<cond<size)++; travel = travel->parent; } node *LChild = new node(root); //建立左右孩子 node *RChild = new node(root); root->leftchild = LChild; root->rightchild = RChild; int maxLcount = 0; int maxRcount = 0; string Ldicision, Rdicision; for (int i = 0; imaxLcount) { maxLcount = fitleftclasses[i].second; Ldicision = fitleftclasses[i].first; } if (fitrightclasses[i].second>maxRcount) { maxRcount = fitrightclasses[i].second; Rdicision = fitrightclasses[i].first; } } LChild->decision = Ldicision; RChild->decision = Rdicision; //LChild->precision = 1.0*maxLcount / fitleftnumber; //RChild->precision = 1.0*maxRcount / fitrightnumber; /*遞迴對左右孩子進行分裂*/ vector > 
LinputData, RinputData; splitInput(inputData, fitIndex, fitCond, LinputData, RinputData); //cout<<"左邊inputData行數:"< > &testinputData) { int i=0; int fitIndex; total_record_number = testinputData.size(); node *LChild= new node(root); node *RChild= new node(root); vector > LinputData, RinputData; LChild =root->leftchild; RChild = root->rightchild; if (root->leftchild == NULL) return; string cond = root->cond;//分支條件是字串:屬性=屬性下的分類,一下是對字串的操作 string::size_type pos = cond.find("="); string pre = cond.substr(0, pos);//將字串前0-pos的位置的子字串賦予pre string post = cond.substr(pos + 1);//在此節點上的分支 for(int index=0;indexrecord_number = LinputData.size(); RChild->record_number = RinputData.size(); //printinputData(LinputData); //printinputData(RinputData); /*計算正確率*/ for (int j = 0; j < LinputData.size(); ++j) { string rest = LinputData[j][ATTR_NUM + 1];//左樹這一行的標籤 if (rest == LChild->decision) i++; } if (LChild->record_number == 0) LChild->precision = 0; else LChild->precision=1.0*i/LChild->record_number; i = 0; for (int j = 0; j < RinputData.size(); ++j) { string rest = RinputData[j][ATTR_NUM + 1];//右樹這一行的標籤 if (rest == RChild->decision) i++; } if (RChild->record_number == 0) RChild->precision=0; else RChild->precision = 1.0*i/RChild->record_number; if(LChild->leftchild!=NULL) pruneprecision(LChild,LinputData); if(RChild->leftchild!=NULL) pruneprecision(RChild, RinputData); } /*計運算元樹的誤差代價*/ double calR2(node *root) { if (root->leftchild == NULL)//葉子結點是沒有左右子樹的 return (1 - root->precision)*root->record_number / total_record_number; else return calR2(root->leftchild) + calR2(root->rightchild); } /*層次遍歷樹,給節點標上序號*/ void index(node *root) { int i = 1; queue que; que.push(root); while (!que.empty()) { node* n = que.front(); que.pop(); n->index = i++; if (n->leftchild != NULL) { que.push(n->leftchild); que.push(n->rightchild); } } } /*層次遍歷樹,給節點標上序號。同時計算alpha*/ void calalpha(node *root, priority_queue, MyCompare> &pq) { int i = 1; queue que; que.push(root); while (!que.empty()) { node* n = 
que.front(); que.pop(); n->index = i++; if (n->leftchild != NULL) { que.push(n->leftchild); que.push(n->rightchild); //計算表面誤差率的增量 double r1 = (1 - n->precision)*n->record_number / total_record_number; //節點的誤差代價 double r2 = calR2(n); n->alpha = (r1 - r2) / (n->size - 1); pq.push(MyTriple(n->alpha, n->size, n->index)); } } } /*剪枝*/ void prune(node *root, priority_queue, MyCompare> &pq) { MyTriple triple = pq.top(); int i = triple.third; queue que; que.push(root); while (!que.empty()) { node* n = que.front(); que.pop(); if (n->index == i) { cout << "將要剪掉" << i << "的左右子樹" << endl; n->leftchild = NULL; n->rightchild = NULL; int s = n->size - 1; node *trav = n; while (trav != NULL) { trav->size -= s; trav = trav->parent; } break; } else if (n->leftchild != NULL) { que.push(n->leftchild); que.push(n->rightchild); } } } void test(string filename, node *root,int labels) { ifstream ifs(filename.c_str()); if (!ifs) { cerr << "open inputfile failed!" << endl; return; } string line; getline(ifs, line); string item; istringstream strstm(line); //跳過第一行 map independent; //自變數,即分類的依據 while (getline(ifs, line)) { istringstream strstm(line); //strstm.str(line); strstm >> item; cout << item << "\t"; for (int i = 0; i> item; independent[X[i]] = item; } node *trav = root; while (trav != NULL) { if (trav->leftchild == NULL) { if (labels >0) { cout << (trav->decision) << "\t置信度:" << (trav->precision) << endl; break; } else cout << (trav->decision) << endl; } string cond = trav->cond;//分支條件是字串:屬性=屬性下的分類,一下是對字串的操作 string::size_type pos = cond.find("="); string pre = cond.substr(0, pos);//將字串前0-pos的位置的子字串賦予pre string post = cond.substr(pos + 1); if (independent[pre] == post) trav = trav->leftchild; else trav = trav->rightchild; } } ifs.close(); } int main() { string inputFile = "watermelon.txt"; readInput(inputFile); VEC_STATI stati,teststati; //最原始的統計 statistic(inputData, stati); // for(int i=0;iprintTree(); cout << "剪枝前使用該決策樹最多進行" << root->size - 1 << "次條件判斷" << endl; string testFile = 
"testwatermelon.txt"; readtestInput(testFile); test(testFile, root,0); /*進行剪枝*/ pruneprecision(root,testinputData); //root->printTree(); priority_queue, MyCompare> pq; calalpha(root,pq); /*//檢驗一個是不是表面誤差增量最小的被剪掉了 while(!pq.empty()){ MyTriple triple=pq.top(); pq.pop(); cout<size - 1 << "次條件判斷" << endl; test(testFile, root,1); /*priority_queue pq; calalpha(root, pq); root->printTree(); prune(root, pq); cout << "剪枝後使用該決策樹最多進行" << root->size - 1 << "次條件判斷" << endl; test(testFile, root);*/ system("pause"); return 0; }