決策樹C4 5分類演算法的C++實現
分享一下我老師大神的人工智慧教程!零基礎,通俗易懂!http://blog.csdn.net/jiangjunshow
也歡迎大家轉載本篇文章。分享知識,造福人民,實現我們中華民族偉大復興!
一、前言
當年實習公司佈置了一個任務讓寫一個決策樹,以前並未接觸資料探勘的東西,但作為一個數據挖掘最基本的知識點,還是應該有所理解的。
程式的原始碼可以點選
決策樹是一個預測模型;他代表的是物件屬性與物件值之間的一種對映關係。樹中每個節點表示某個物件,而每個分叉路徑則代表的某個可能的屬性值,而每個葉結點則對應從根節點到該葉節點所經歷的路徑所表示的物件的值。決策樹僅有單一輸出,若欲有複數輸出,可以建立獨立的決策樹以處理不同輸出。 資料探勘中決策樹是一種經常要用到的技術,可以用於分析資料,同樣也可以用來作預測(就像上面的銀行官員用他來預測貸款風險)。從資料產生決策樹的機器學習技術叫做決策樹學習, 通俗說就是決策樹。(來自維基百科)
1986年Quinlan提出了著名的ID3演算法。在ID3演算法的基礎上,1993年Quinlan又提出了C4.5演算法。為了適應處理大規模資料集的需要,後來又提出了若干改進的演算法,其中SLIQ (super-vised learning in quest)和SPRINT (scalable parallelizableinduction of decision trees)是比較有代表性的兩個演算法,此處暫且略過。
本文實現了C4.5的演算法,在ID3的基礎上計算資訊增益,從而更加準確的反應資訊量。其實通俗的說就是構建一棵加權的最短路徑Haffman樹,讓權值最大的節點為父節點。
二、基本概念
下面簡要介紹一下ID3演算法:
ID3演算法的核心是:在決策樹各級結點上選擇屬性時,用資訊增益(information gain)作為屬性的選擇標準,以使得在每一個非葉結點進行測試時,能獲得關於被測試記錄最大的類別資訊。
其具體方法是:檢測所有的屬性,選擇資訊增益最大的屬性產生決策樹結點,由該屬性的不同取值建立分支,再對各分支的子集遞迴呼叫該方法建立決策樹結點的分支,直到所有子集僅包含同一類別的資料為止。最後得到一棵決策樹,它可以用來對新的樣本進行分類。
某屬性的資訊增益按下列方法計算:
資訊熵是夏農提出的,用於描述資訊不純度(不穩定性),其計算公式是Info(D)。
其中:Pi為子集合中不同性(而二元分類即正樣例和負樣例)的樣例的比例;j是屬性A中的索引,D是集合樣本,Dj是D中屬性A上值等於j的樣本集合。
這樣資訊收益可以定義為樣本按照某屬性劃分時造成熵減少的期望,可以區分訓練樣本中正負樣本的能力。資訊增益定義為結點與其子結點的資訊熵之差,公式為Gain(A)。
ID3演算法的優點是:演算法的理論清晰,方法簡單,學習能力較強。其缺點是:只對比較小的資料集有效,且對噪聲比較敏感,當訓練資料集加大時,決策樹可能會隨之改變。
C4.5演算法繼承了ID3演算法的優點,並在以下幾方面對ID3演算法進行了改進:
1) 用資訊增益率來選擇屬性,克服了用資訊增益選擇屬性時偏向選擇取值多的屬性的不足,公式為GainRatio(A);
2) 在樹構造過程中進行剪枝;
3) 能夠完成對連續屬性的離散化處理;
4) 能夠對不完整資料進行處理。
C4.5演算法與其它分類演算法如統計方法、神經網路等比較起來有如下優點:產生的分類規則易於理解,準確率較高。其缺點是:在構造樹的過程中,需要對資料集進行多次的順序掃描和排序,因而導致演算法的低效。此外,C4.5只適合於能夠駐留於記憶體的資料集,當訓練集大得無法在記憶體容納時程式無法執行。
三、資料集
實現的C4.5資料集合如下:
它記錄了再不同的天氣狀況下,是否出去覓食的資料。
四、程式程式碼
程式引入狀態樹作為統計和計算屬性的資料結構,它記錄了每次計算後,各個屬性的統計資料,其定義如下:
[cpp] view plain copy print ?- struct attrItem
- {
- std::vector<int> itemNum; //itemNum[0] = itemLine.size()
- //itemNum[1] = decision num
- set<int> itemLine;
- };
- struct attributes
- {
- string attriName;
- vector<double> statResult;
- map<string, attrItem*> attriItem;
- };
- vector<attributes*> statTree;
決策樹節點資料結構如下:
[cpp] view plain copy print ?
- struct TreeNode
- {
- std::string m_sAttribute;
- int m_iDeciNum;
- int m_iUnDecinum;
- std::vector<TreeNode*> m_vChildren;
- };
程式原始碼如下所示(程式中有詳細註解):
[cpp] view plain copy print ?
- #include "DecisionTree.h"
- int main(int argc, char* argv[]){
- string filename = "source.txt";
- DecisionTree dt ;
- int attr_node = 0;
- TreeNode* treeHead = nullptr;
- set<int> readLineNum;
- vector<int> readClumNum;
- int deep = 0;
- if (dt.pretreatment(filename, readLineNum, readClumNum) == 0)
- {
- dt.CreatTree(treeHead, dt.getStatTree(), dt.getInfos(), readLineNum, readClumNum, deep);
- }
- return 0;
- }
- /*
- * @function CreatTree 預處理函式,負責讀入資料,並生成資訊矩陣和屬性標記
- * @param: filename 檔名
- * @param: readLineNum 可使用行set
- * @param: readClumNum 可用屬性set
- * @return int 返回函式執行狀態
- */
- int DecisionTree::pretreatment(string filename, set<int>& readLineNum, vector<int>& readClumNum)
- {
- ifstream read(filename.c_str());
- string itemline = "";
- getline(read, itemline);
- istringstream iss(itemline);
- string attr = "";
- while(iss >> attr)
- {
- attributes* s_attr = new attributes();
- s_attr->attriName = attr;
- //初始化屬性名
- statTree.push_back(s_attr);
- //初始化屬性對映
- attr_clum[attr] = attriNum;
- attriNum++;
- //初始化可用屬性列
- readClumNum.push_back(0);
- s_attr = nullptr;
- }
- int i = 0;
- //新增具體資料
- while(true)
- {
- getline(read, itemline);
- if(itemline == "" || itemline.length() <= 1)
- {
- break;
- }
- vector<string> infoline;
- istringstream stream(itemline);
- string item = "";
- while(stream >> item)
- {
- infoline.push_back(item);
- }
- infos.push_back(infoline);
- readLineNum.insert(i);
- i++;
- }
- read.close();
- return 0;
- }
- int DecisionTree::statister(vector<vector<string>>& infos, vector<attributes*>& statTree,
- set<int>& readLine, vector<int>& readClumNum)
- {
- //yes的總行數
- int deciNum = 0;
- //統計每一行
- set<int>::iterator iter_end = readLine.end();
- for (set<int>::iterator line_iter = readLine.begin(); line_iter != iter_end; ++line_iter)
- {
- bool decisLine = false;
- if (infos[*line_iter][attriNum - 1] == "yes")
- {
- decisLine = true;
- deciNum++;
- }
- //如果該列未被鎖定並且為屬性列,進行統計
- for (int i = 0; i < attriNum - 1; i++)
- {
- if (readClumNum[i] == 0)
- {
- std::string tempitem = infos[*line_iter][i];
- auto map_iter = statTree[i]->attriItem.find(tempitem);
- //沒有找到
- if (map_iter == (statTree[i]->attriItem).end())
- {
- //新建
- attrItem* attritem = new attrItem();
- attritem->itemNum.push_back(1);
- decisLine ? attritem->itemNum.push_back(1) : attritem->itemNum.push_back(0);
- attritem->itemLine.insert(*line_iter);
- //建立屬性名->item對映
- (statTree[i]->attriItem)[tempitem] = attritem;
- attritem = nullptr;
- }
- else