基於C++語言的決策樹實現
感覺好久都沒有寫過程式了,一直上課沒有時間。最近有點空,然後就寫了下西瓜書中的決策樹的實現。由於本人才疏學淺,採用的實現方式和資料結構可能不合理,沒有考慮程式碼的複雜度和時間複雜度等等,僅寫下自己的實現想法(大神們就打擾了)。該程式是基於C++語言來實現的,演算法就是西瓜書上面的實現演算法,採用最簡單的ID3演算法,用資訊增益來選擇最優劃分,進而進行決策樹的實現(沒有對決策樹進行剪枝操作,以後有時間再改進)。
1、基本概念
決策樹的詳細演算法我就不介紹了,具體參看西瓜書。在這裡,我想寫一下演算法的大體。首先說明一下屬性、屬性值、屬性集、類別、訓練集的概念。
1、屬性:屬性就是用來描述物體特徵的量。比如:描述一個西瓜可以用色澤、根蒂、敲聲、紋理、臍部、觸感等等屬性來說明一個西瓜是說明樣子的。
2、屬性值:屬性值是屬性的取值。比如:一個西瓜的色澤可以用青綠、烏黑、淺白來描述,這些就是西瓜的色澤屬性的屬性值。
3、屬性集:屬性集是物體所有屬性的集合。顧名思義,就是上述列出的屬性打包在一起組成的集合。
4、類別:物體的種類,決策樹的目的就是根據屬性來對物體進行分類。比如:西瓜有好瓜和壞瓜。
5、訓練集:訓練集就是樣本集,所有的學習都需要通過資料來形成,決策樹也是一樣的。需要通過樣本來形成一棵決策樹。訓練集必須包括物體各個屬性的屬性值和類別。
然後說明下演算法的三種情況:
設訓練集為D,屬性集為A,需要生成樹的節點。
1、若訓練集D中的全部樣本的類別都是一樣的,比如都是好瓜或者都是壞瓜。這時候將節點的值標記為樣本的類別。
2、若屬性集A中的屬性為空,即樹的分支已經到底了,樣本的所有屬性已經用完,這時候需要將節點標記為葉節點。或者有訓練集D中的所有樣本,在屬性集A上的取值是一樣的。比如屬性集A={色澤、根蒂、敲聲},而訓練集D={{青綠、蜷縮、濁響、清晰、凹陷、硬滑、好瓜}、{青綠、蜷縮、濁響、稍糊、稍凹、軟粘、好瓜}、{青綠、蜷縮、濁響、模糊、凹陷、軟粘、壞瓜}},這時可以看出,訓練集D在屬性集A上的取值都為:青綠、蜷縮、濁響。這兩種情況需要將節點標記為葉節點,葉節點的屬性值為訓練集中類別最多的類。如上述例子,就需要將葉節點標記為好瓜。
3、在第三步之前需要先從屬性集A中選擇一個最優劃分屬性a,如何選擇最優屬性a,下面再說明。
假設選擇最優屬性a的值的集合為{b1、b2、b3}。例如,假設選擇的屬性是a=色澤,而其對應的屬性集為b={青綠、烏黑、淺白}。
這時候,我們需要對每個屬性值從訓練集D中劃分出屬性a的取值為b1(b2、b3)的子集c。例如,我們的訓練集D={{青綠、蜷縮、濁響、清晰、凹陷、硬滑、好瓜}、{青綠、蜷縮、濁響、稍糊、稍凹、軟粘、好瓜}、{青綠、蜷縮、濁響、模糊、凹陷、軟粘、壞瓜}},然後假設我們根據屬性“紋理”=“凹陷”進行劃分,則子集c={{青綠、蜷縮、濁響、清晰、凹陷、硬滑、好瓜}、{青綠、蜷縮、濁響、模糊、凹陷、軟粘、壞瓜}}。
然後產生一個節點,(1)若子集c為空,則將該節點標記為葉節點,標記的類別為訓練集D中類別最多的類。(2)若c不為空,則將屬性a從屬性集A中剔除,將剩下的屬性集合子集c,從第一步開始繼續劃分(即遞迴)。
下面說明最優屬性a是如何劃分的。首先是資訊熵的概念,資訊熵的定義如下:
其中,n為訓練集D的類別數(如子集中的類別為好瓜、壞瓜,則n=2)。pk為第k個類別的樣本在訓練集中的比例。並規定若,則。
有了資訊熵就可以寫出資訊增益了,資訊增益定義為:
其中,V為屬性a的可能取值個數,Dv為屬性a對應的屬性值劃分出來的訓練子集(比如上面提到的c子集),D為訓練集。
該演算法劃分最優屬性a是根據資訊增益來劃分的,即資訊增益越大,說明以a作為下一個屬性來生成決策樹最佳。
舉個例子以便理解。
若訓練集D={{青綠 蜷縮 濁響 清晰 凹陷 硬滑 好瓜} {烏黑 蜷縮 沉悶 清晰 凹陷 硬滑 好瓜 } {青綠 蜷縮 沉悶 稍糊 稍凹 硬滑 壞瓜}},則n=2,資訊熵為
以計算"色澤"資訊增益為例(D中色澤屬性值有:青綠*2,烏黑*1):
2、C++演算法實現
首先我把西瓜書上的訓練集列出來,對理解程式有幫助。
(1)輸入函式
輸入函式我沒有在類內定義,而是直接在主函式中定義,因為我感覺這樣比較好,輸入與類分開。我們可以看到,資料是非常多的,有屬性以及各個樣本的各屬性對應的屬性值。我這裡是主要採用map這個結構來對這張表進行儲存。廢話不多說,直接上程式:
//全域性變數
//定義屬性陣列,存放可能的屬性,包括類別
vector<string> data_Attributes;//對於本資料集來說就是:色澤 根蒂 敲聲 紋理 臍部 觸感 類別
//定義各屬性對應的屬性值
map<string, vector<string>> data_AttValues;//比如:色澤={青綠 烏黑 淺白}
//定義剩餘屬性,不包括類別(這個主要用於後面演算法的遞迴)
vector<string> remain_Attributes;//色澤 根蒂 敲聲 紋理 臍部 觸感
//定義資料表,屬性-屬性值(全部資料的屬性值放在同一個陣列)
map<string, vector<string>>data_Table;//整張表
//輸入資料生成資料集
void data_Input()
{
//輸入屬性(色澤 根蒂 敲聲 紋理 臍部 觸感 好瓜)
string input_Line,temp_Attributes;
cout << "請輸入屬性:" << endl;
//獲取一行資料,然後繫結到資料流istringstream
getline(cin, input_Line);
istringstream input_Attributes(input_Line);
//將資料流內容(空格不輸出)輸入資料屬性陣列中
while (input_Attributes >> temp_Attributes)
{
data_Attributes.push_back(temp_Attributes);
}
//剔除類別這個屬性
remain_Attributes = data_Attributes;
remain_Attributes.pop_back();
//定義樣本數量
int N = 0;
cout << "請輸入樣本數量:" << endl;
cin >> N;
cin.ignore();//清空cin緩衝區中的留下的換行符
//輸入資料(屬性值)
cout << "請輸入樣本:" << endl;
//一共N個訓練樣本
for (int j = 0; j < N; j++)
{
string temp_AttValues;
//獲取一行屬性值輸入
getline(cin, input_Line);
istringstream input_AttValues(input_Line);
//將各屬性值輸入到資料表data_table中
for (int i = 0; i < data_Attributes.size(); i++)
{
input_AttValues >> temp_AttValues;
data_Table[data_Attributes[i]].push_back(temp_AttValues);
}
}
//生成各屬性對應的屬性值集的對映data_AttValues
for (int i = 0; i < data_Attributes.size(); i++)
{
//通過set結構來統計所有樣本中各屬性對應的屬性值的所有可能的取值
//如:“色澤”的可能取值為:青綠 烏黑 淺白
set<string> attValues;
for (int j = 0; j < N; j++)
{
//注意:data_Attributes[i]代表某個屬性
//而data_Table[data_Attributes[i]]是一個數組
string temp = data_Table[data_Attributes[i]][j];
//若有重複屬性值,set是不會插入的
attValues.insert(temp);
}
for (set<string>::iterator it = attValues.begin(); it != attValues.end(); it++)
{
//將所有可能的屬性值存入data_AttValues[data_Attributes[i]]
data_AttValues[data_Attributes[i]].push_back(*it);
}
}
}
(2)決策樹節點類的設計
決策樹類需要包含的東西挺多的,成員變數主要有:
樣本資料集的屬性個數:attribute_Num
本節點的屬性:node_Attribute
本節點屬性對應的所有可取的屬性值:node_AttValues
資料集的屬性:data_Attribute
從根節點到本節點未被用於最優劃分屬性的屬性集:remain_Attributes
本節點屬性對應的屬性值與子節點的地址的對映集,即本節點屬性取屬性值後下一個節點的地址的集合(有多個屬性值,所以有多個不同的地址):childNode(為空說明該節點是葉節點)
該節點對應的樣本的資料表:MyDateTable
各屬性對應的屬性值(外部傳進來的,對後面操作有作用):data_AttValues
成員函式主要有:
計算資訊熵函式:calc_Entropy()
計算資訊增益並尋找最優劃分屬性a:findBestAttribute()
生成本節點的子節點:generate_ChildNode()
設定節點的屬性:set_NodeAttribute()
訓練成完整的決策樹後,可根據所給樣本的屬性,預測出該樣本的類別:findClass()
class Tree_Node
{
public:
//建構函式,引數依次為:資料集表(西瓜資料表)、西瓜所有的屬性包括類別、每個屬性可能的取值構成的表、剩餘的未被劃分的屬性
Tree_Node(map<string, vector<string>> temp_Table, vector<string> temp_Attribute,map<string, vector<string>> data_AttValues, vector<string> temp_remain);
//生成子節點
void generate_ChildNode();
//計算資訊增益 尋找最優劃分屬性
string findBestAttribute();
//計算資訊熵
double calc_Entropy(map<string, vector<string>> temp_Table);
//設定節點的屬性
void set_NodeAttribute(string atttribute);
//根據所給屬性,對資料進行分類
string findClass(vector<string> attributes);
virtual ~Tree_Node();
private:
//屬性個數,不包括類別
int attribute_Num;
//本節點的屬性
string node_Attribute;
//資料集屬性
vector<string> data_Attribute;
//本節點的所有屬性值
vector<string> node_AttValues;
//剩餘屬性集
vector<string>remain_Attributes;
//子節點,本節點屬性對應的屬性值與子節點地址進行一一對映
//為空說明該節點為葉節點
map<string, Tree_Node *> childNode;
//樣本集合表
map<string, vector<string>> MyDateTable;
//定義各屬性對應的屬性值
map<string, vector<string>> data_AttValues;
};
(3)類的實現
首先是類的建構函式,建構函式主要是對類的成員變數進行初始化:
Tree_Node::Tree_Node(map<string, vector<string>> temp_Table,vector<string> temp_Attribute, map<string, vector<string>> data_AttValues, vector<string> temp_remain)
{
//全部屬性,包括類別
data_Attribute = temp_Attribute;
//屬性個數,不包括類別
attribute_Num = (int)temp_Attribute.size() - 1;
//各屬性對應的屬性值
this->data_AttValues = data_AttValues;
//屬性表
MyDateTable = temp_Table;
//剩餘屬性集
remain_Attributes = temp_remain;
}
然後是計算資訊熵的成員函式的實現:
//計算資訊熵
double Tree_Node::calc_Entropy(map<string, vector<string>> temp_Table)
{
map<string, vector<string>> table = temp_Table;
//資料集中樣本的數量
int sample_Num = (int)temp_Table[data_Attribute[0]].size();
//計算資料集中的類別數量
map<string, int> class_Map;
for (int i = 0; i < sample_Num; i++)
{
//data_Attribute[attribute_Num]對應的就是資料集的類別
string class_String = table[data_Attribute[attribute_Num]][i];
class_Map[class_String]++;
}
map<string, int>::iterator it = class_Map.begin();
//存放類別及其對應的數量
//vector<string> m_Class;
vector<int> n_Class;
for (; it != class_Map.end(); it++)
{
//m_Class.push_back(it->first);
n_Class.push_back(it->second);
}
//計算資訊熵
double Ent = 0;
for (int k = 0; k < class_Map.size(); k++)
{
//比例
double p = (double) n_Class[k] / sample_Num;
if (p == 0)
{
//規定了p=0時,plogp=0
continue;
}
//c++中只有log和ln,因此需要應用換底公式
Ent -= p * (log(p) / log(2));//資訊熵
}
return Ent;
}
接下來實現資訊增益的計算以及尋找出最優的劃分屬性:
//尋找最優劃分
string Tree_Node::findBestAttribute()
{
//樣本個數
int N = (int)MyDateTable[data_Attribute[0]].size();
//定義用於存放最優屬性
string best_Attribute;
//資訊增益
double gain = 0;
//對每個剩餘屬性
for (int i = 0; i < remain_Attributes.size(); i++)
{
//定義資訊增益,選取增益最大的屬性來劃分即為最優劃分
double temp_Gain = calc_Entropy(MyDateTable);//根據公式先將本節點的資訊熵初始化給增益
//對該屬性的資料集進行分類(獲取各屬性值的資料子集)
string temp_Att = remain_Attributes[i];//假設選取的屬性
vector<string> remain_AttValues;//屬性可能的取值
for (int j = 0; j < data_AttValues[temp_Att].size(); j++)
{
remain_AttValues.push_back(data_AttValues[temp_Att][j]);
}
//對每個屬性值求資訊熵
for (int k = 0; k < remain_AttValues.size(); k++)
{
//屬性值
string temp_AttValues = remain_AttValues[k];
int sample_Num = 0;//該屬性值對應樣本數量
//定義map用來存放該屬性值下的資料子集
map<string, vector<string>>sub_DataTable;
for (int l = 0; l < MyDateTable[temp_Att].size(); l++)
{
if (temp_AttValues == MyDateTable[temp_Att][l])
{
sample_Num++;
//將符合條件的訓練集存入sub_DataTable
for (int m = 0; m < data_Attribute.size(); m++)
{
sub_DataTable[data_Attribute[m]].push_back(MyDateTable[data_Attribute[m]][l]);
}
}
}
//累加每個屬值的資訊熵
temp_Gain -= (double)sample_Num / N * calc_Entropy(sub_DataTable);
}
//比較尋找最優劃分屬性
if (temp_Gain > gain)
{
gain = temp_Gain;
best_Attribute = temp_Att;
}
}
return best_Attribute;
}
然後是實現如何生成子節點的成員函式,在這之前先實現一個設定節點的屬性的函式,該函式主要用於設定子節點的節點屬性:
void Tree_Node::set_NodeAttribute(string attribute)
{
//設定節點的屬性
this->node_Attribute = attribute;
}
生成子節點,生成方式就是按照西瓜書上面的演算法一步一步實現即可,程式碼如下:
void Tree_Node::generate_ChildNode()
{
//樣本個數
int N = (int)MyDateTable[data_Attribute[0]].size();
//將資料集中類別種類和數量放入map裡面,只需判斷最後一列即可
map<string,int> category;
for (int i = 0; i < N; i++)
{
vector<string> temp_Class;
temp_Class = MyDateTable[data_Attribute[attribute_Num]];
category[temp_Class[i]]++;
}
//第一種情況
//只有一個類別,標記為葉節點
if (1 == category.size())
{
map<string, int>::iterator it = category.begin();
node_Attribute = it->first;
return;
}
//第二種情況
//先判斷所有屬性是否取相同值
bool isAllSame = false;
for (int i = 0; i < remain_Attributes.size(); i++)
{
isAllSame = true;
vector<string> temp;
temp = MyDateTable[remain_Attributes[i]];
for (int j = 1; j < temp.size(); j++)
{
//只要有一個不同,即可退出
if (temp[0] != temp[j])
{
isAllSame = false;
break;
}
}
if (isAllSame == false)
{
break;
}
}
//若屬性集為空或者樣本中的全部屬性取值相同
if (remain_Attributes.empty()||isAllSame)
{
//找出數量最多的類別及其出現的個數,並將該節點標記為該類
map<string, int>::iterator it = category.begin();
node_Attribute = it->first;
int max = it->second;
it++;
for (; it != category.end(); it++)
{
int num = it->second;
if (num > max)
{
node_Attribute = it->first;
max = num;
}
}
return;
}
//第三種情況
//從remian_attributes中劃分最優屬性
string best_Attribute = findBestAttribute();
//將本節點設定為最優屬性
node_Attribute = best_Attribute;
//對最優屬性的每個屬性值
for (int i = 0; i < data_AttValues[best_Attribute].size(); i++)
{
string best_AttValues = data_AttValues[best_Attribute][i];
//計算屬性對應的資料集D
//定義map用來存放該屬性值下的資料子集
map<string, vector<string>> sub_DataTable;
for (int j = 0; j < MyDateTable[best_Attribute].size(); j++)
{
//尋找最優屬性在資料集中屬性值相同的資料樣本
if (best_AttValues == MyDateTable[best_Attribute][j])
{
//找到對應的資料集,存入子集中sub_DataTable(該樣本的全部屬性都要存入)
for (int k = 0; k < data_Attribute.size(); k++)
{
sub_DataTable[data_Attribute[k]].push_back(MyDateTable[data_Attribute[k]][j]);
}
}
}
//若子集為空,將分支節點(子節點)標記為葉節點,類別為MyDateTable樣本最多的類
if (sub_DataTable.empty())
{
//生成子節點
Tree_Node * p = new Tree_Node(sub_DataTable, data_Attribute, data_AttValues, remain_Attributes);
//找出樣本最多的類,作為子節點的屬性
map<string, int>::iterator it = category.begin();
string childNode_Attribute = it->first;
int max_Num = it->second;
it++;
for (; it != category.end(); it++)
{
if (it->second > max_Num)
{
max_Num = it->second;
childNode_Attribute = it->first;
}
}
//設定子葉節點屬性
p->set_NodeAttribute(childNode_Attribute);
//將子節點存入childNode,預測樣本的時候會用到
childNode[best_AttValues] = p;
}
else//若不為空,則從剩餘屬性值剔除該屬性,呼叫generate_ChildNode繼續往下細分
{
vector<string> child_RemainAtt;
child_RemainAtt = remain_Attributes;
//找出child_RemainAtt中的與該最佳屬性相等的屬性
vector<string>::iterator it = child_RemainAtt.begin();
for (; it != child_RemainAtt.end(); it++)
{
if (*it == best_Attribute)
{
break;
}
}
//刪除
child_RemainAtt.erase(it);
//生成子節點
Tree_Node * pt = new Tree_Node(sub_DataTable, data_Attribute, data_AttValues, child_RemainAtt);
//將子節點存入childNode
childNode[best_AttValues] = pt;
//子節點再呼叫generate_ChildNode函式
pt->generate_ChildNode();
}
}
}
在最後,我們必須有個預測函式,即如果輸入一個樣本,需要給出該樣本的是什麼類別的。比如:給出西瓜資料:青綠 蜷縮 濁響 清晰 凹陷 硬滑,則呼叫該函式應該輸出:“好瓜”。該函式實現如下:
//輸入為待預測樣本的所有屬性集合
string Tree_Node::findClass(vector<string> attributes)
{
//若存在子節點
if (childNode.size() != 0)
{
//找出輸入的樣例中與本節點屬性對應的屬性值,以便尋找下個節點,直到找到葉節點
string attribute_Value;
for (int i = 0; i < data_AttValues[node_Attribute].size(); i++)
{
for (int j = 0; j < attributes.size(); j++)
{
//data_AttValues[node_Attribute]為屬性node_Attribute對應的所有可能的取值集合
if (attributes[j] == data_AttValues[node_Attribute][i])
{
//找到了樣例對應的屬性值
attribute_Value = attributes[j];
break;
}
}
//找到後就沒必要繼續迴圈了
if (!attribute_Value.empty())
{
break;
}
}
//找出該屬性值對應的子節點的地址,以便進行訪問
Tree_Node *p = childNode[attribute_Value];
return p->findClass(attributes);//遞迴尋找,直到找到葉節點為止
}
else//不存在子節點說明已經找到分類,類別為本節點的node_Attribute
{
return node_Attribute;
}
}
3、預測結果
先放上我的主函式,主函式主要是呼叫函式,然後格式化輸出。
int main()
{
//輸入
data_Input();
Tree_Node myTree(data_Table, data_Attributes, data_AttValues, remain_Attributes);
//進行訓練
myTree.generate_ChildNode();
//輸入預測樣例,進行預測
vector<string> predict_Sample;
string input_Line, temp;
cout << "請輸入屬性進行預測:" << endl;
getline(cin, input_Line);
istringstream input_Sample(input_Line);
while (input_Sample >> temp)
{
//將輸入預測樣例的屬性都存入predict_Sample,以便傳參
predict_Sample.push_back(temp);
}
cout << endl;
//預測
cout << "分類結果為:" << myTree.findClass(predict_Sample) << endl;
system("pause");
return 0;
}
執行結果:
可以看出,預測結果是正確的。但真正的預測應該不能取樣本中已有的樣例,我這樣做是為了驗證程式的正確性。
好了,文章到此就結束了。關於程式的相關優化和剪枝操作以後有時間再來完善。下面附上源程式程式碼供下載:https://download.csdn.net/download/m0_37543178/10793382。(ps:沒有積分的可以直接找我給你原始碼:))