1. 程式人生 > >基於C++語言的決策樹實現

基於C++語言的決策樹實現

   感覺好久都沒有寫過程式了,一直上課沒有時間。最近有點空,然後就寫了下西瓜書中的決策樹的實現。由於本人才疏學淺,採用的實現方式和資料結構可能不合理,沒有考慮程式碼的複雜度和時間複雜度等等,僅寫下自己的實現想法(大神們就打擾了)。該程式是基於C++語言來實現的,演算法就是西瓜書上面的實現演算法,採用最簡單的ID3演算法,用資訊增益來選擇最優劃分,進而進行決策樹的實現(沒有對決策樹進行剪枝操作,以後有時間再改進)。

1、基本概念

決策樹的詳細演算法我就不介紹了,具體參看西瓜書。在這裡,我想寫一下演算法的大體。首先說明一下屬性、屬性值、屬性集、類別、訓練集的概念。

1、屬性:屬性就是用來描述物體特徵的量。比如:描述一個西瓜可以用色澤、根蒂、敲聲、紋理、臍部、觸感等等屬性來說明一個西瓜是說明樣子的。

2、屬性值:屬性值是屬性的取值。比如:一個西瓜的色澤可以用青綠、烏黑、淺白來描述,這些就是西瓜的色澤屬性的屬性值。

3、屬性集:屬性集是物體所有屬性的集合。顧名思義,就是上述列出的屬性打包在一起組成的集合。

4、類別:物體的種類,決策樹的目的就是根據屬性來對物體進行分類。比如:西瓜有好瓜和壞瓜。

5、訓練集:訓練集就是樣本集,所有的學習都需要通過資料來形成,決策樹也是一樣的。需要通過樣本來形成一棵決策樹。訓練集必須包括物體各個屬性的屬性值和類別。

然後說明下演算法的三種情況:

設訓練集為D,屬性集為A,需要生成樹的節點。

1、若訓練集D中的全部樣本的類別都是一樣的,比如都是好瓜或者都是壞瓜。這時候將節點的值標記為樣本的類別。

2、若屬性集A中的屬性為空,即樹的分支已經到底了,樣本的所有屬性已經用完,這時候需要將節點標記為葉節點。或者有訓練集D中的所有樣本,在屬性集A上的取值是一樣的。比如屬性集A={色澤、根蒂、敲聲},而訓練集D={{青綠、蜷縮、濁響、清晰、凹陷、硬滑、好瓜}、{青綠、蜷縮、濁響、稍糊、稍凹、軟粘、好瓜}、{青綠、蜷縮、濁響、模糊、凹陷、軟粘、壞瓜}},這時可以看出,訓練集D在屬性集A上的取值都為:青綠、蜷縮、濁響。這兩種情況需要將節點標記為葉節點,葉節點的屬性值為訓練集中類別最多的類。如上述例子,就需要將葉節點標記為好瓜。

3、在第三步之前需要先從屬性集A中選擇一個最優劃分屬性a,如何選擇最優屬性a,下面再說明。

假設選擇最優屬性a的值的集合為{b1、b2、b3}。例如,假設選擇的屬性是a=色澤,而其對應的屬性集為b={青綠、烏黑、淺白}。

這時候,我們需要對每個屬性值從訓練集D中劃分出屬性a的取值為b1(b2、b3)的子集c。例如,我們的訓練集D={{青綠、蜷縮、濁響、清晰、凹陷、硬滑、好瓜}、{青綠、蜷縮、濁響、稍糊、稍凹、軟粘、好瓜}、{青綠、蜷縮、濁響、模糊、凹陷、軟粘、壞瓜}},然後假設我們根據屬性“紋理”=“凹陷”進行劃分,則子集c={{青綠、蜷縮、濁響、清晰、凹陷、硬滑、好瓜}、{青綠、蜷縮、濁響、模糊、凹陷、軟粘、壞瓜}}。

然後產生一個節點,(1)若子集c為空,則將該節點標記為葉節點,標記的類別為訓練集D中類別最多的類。(2)若c不為空,則將屬性a從屬性集A中剔除,將剩下的屬性集合子集c,從第一步開始繼續劃分(即遞迴)。

下面說明最優屬性a是如何劃分的。首先是資訊熵的概念,資訊熵的定義如下:

                                            Ent(D)=-\sum_{k=1}^{n}p_{k}log_{2}p_{k} 

其中,n為訓練集D的類別數(如子集中的類別為好瓜、壞瓜,則n=2)。pk為第k個類別的樣本在訓練集中的比例。並規定若p_{k}=0,則p_{k}log_{2}p_{k}=0

有了資訊熵就可以寫出資訊增益了,資訊增益定義為:

                                            Gain(D,a)=Ent(D)-\sum_{v=1}^{V}\frac{\mathrm{D_{v}} }{\mathrm{D} }Ent(D_{v})

其中,V為屬性a的可能取值個數,Dv為屬性a對應的屬性值劃分出來的訓練子集(比如上面提到的c子集),D為訓練集。

該演算法劃分最優屬性a是根據資訊增益來劃分的,即資訊增益越大,說明以a作為下一個屬性來生成決策樹最佳。

 舉個例子以便理解。

若訓練集D={{青綠 蜷縮 濁響 清晰 凹陷 硬滑 好瓜} {烏黑 蜷縮 沉悶 清晰 凹陷 硬滑 好瓜 } {青綠 蜷縮 沉悶 稍糊 稍凹 硬滑 壞瓜}},則n=2,資訊熵為

                                             Ent(D)=-\left ( \frac{2}{3}*log_{2} \frac{2}{3}+\frac{1}{3}*log_{2}\frac{1}{3}\right )

以計算"色澤"資訊增益為例(D中色澤屬性值有:青綠*2,烏黑*1):

                      Gain(D,a)=Ent(D)+(\frac{2}{3}*\left (\frac{1}{2} *log_{2}\frac{1}{2}+\frac{1}{2} *log_{2}\frac{1}{2}\right )+\frac{1}{3}\left (1*log_{2}1+0\right ))

2、C++演算法實現

首先我把西瓜書上的訓練集列出來,對理解程式有幫助。

                        

(1)輸入函式

輸入函式我沒有在類內定義,而是直接在主函式中定義,因為我感覺這樣比較好,輸入與類分開。我們可以看到,資料是非常多的,有屬性以及各個樣本的各屬性對應的屬性值。我這裡是主要採用map這個結構來對這張表進行儲存。廢話不多說,直接上程式:

//全域性變數
//定義屬性陣列,存放可能的屬性,包括類別
vector<string> data_Attributes;//對於本資料集來說就是:色澤 根蒂 敲聲 紋理 臍部 觸感 類別
//定義各屬性對應的屬性值
map<string, vector<string>> data_AttValues;//比如:色澤={青綠 烏黑 淺白}
//定義剩餘屬性,不包括類別(這個主要用於後面演算法的遞迴)
vector<string> remain_Attributes;//色澤 根蒂 敲聲 紋理 臍部 觸感
//定義資料表,屬性-屬性值(全部資料的屬性值放在同一個陣列)
map<string, vector<string>>data_Table;//整張表


//輸入資料生成資料集
void data_Input()
{
	//輸入屬性(色澤 根蒂 敲聲 紋理 臍部 觸感 好瓜)
	string input_Line,temp_Attributes;
	cout << "請輸入屬性:" << endl;
	//獲取一行資料,然後繫結到資料流istringstream
	getline(cin, input_Line);
	istringstream input_Attributes(input_Line);
	//將資料流內容(空格不輸出)輸入資料屬性陣列中
	while (input_Attributes >> temp_Attributes)
	{
		data_Attributes.push_back(temp_Attributes);
	}
	//剔除類別這個屬性
	remain_Attributes = data_Attributes;
	remain_Attributes.pop_back();
	//定義樣本數量
	int N = 0;
	cout << "請輸入樣本數量:" << endl;
	cin >> N;
	cin.ignore();//清空cin緩衝區中的留下的換行符
	//輸入資料(屬性值)
	cout << "請輸入樣本:" << endl;
	//一共N個訓練樣本
	for (int j = 0; j < N; j++)
	{
		string temp_AttValues;
		//獲取一行屬性值輸入
		getline(cin, input_Line);
		istringstream input_AttValues(input_Line);
		//將各屬性值輸入到資料表data_table中
		for (int i = 0; i < data_Attributes.size(); i++)
		{
			input_AttValues >> temp_AttValues;
			data_Table[data_Attributes[i]].push_back(temp_AttValues);
		}
	}

	//生成各屬性對應的屬性值集的對映data_AttValues
	for (int i = 0; i < data_Attributes.size(); i++)
	{
        //通過set結構來統計所有樣本中各屬性對應的屬性值的所有可能的取值
        //如:“色澤”的可能取值為:青綠 烏黑 淺白
		set<string> attValues;
		for (int j = 0; j < N; j++)
		{
            //注意:data_Attributes[i]代表某個屬性
            //而data_Table[data_Attributes[i]]是一個數組
			string temp = data_Table[data_Attributes[i]][j];
			//若有重複屬性值,set是不會插入的
			attValues.insert(temp);
		}
		for (set<string>::iterator it = attValues.begin(); it != attValues.end(); it++)
		{
            //將所有可能的屬性值存入data_AttValues[data_Attributes[i]]
			data_AttValues[data_Attributes[i]].push_back(*it);
		}
		
	}
}

(2)決策樹節點類的設計

決策樹類需要包含的東西挺多的,成員變數主要有

樣本資料集的屬性個數:attribute_Num

本節點的屬性:node_Attribute

本節點屬性對應的所有可取的屬性值:node_AttValues

資料集的屬性:data_Attribute

從根節點到本節點未被用於最優劃分屬性的屬性集:remain_Attributes

本節點屬性對應的屬性值與子節點的地址的對映集,即本節點屬性取屬性值後下一個節點的地址的集合(有多個屬性值,所以有多個不同的地址):childNode(為空說明該節點是葉節點)

該節點對應的樣本的資料表:MyDateTable

各屬性對應的屬性值(外部傳進來的,對後面操作有作用):data_AttValues

成員函式主要有

計算資訊熵函式:calc_Entropy()

計算資訊增益並尋找最優劃分屬性a:findBestAttribute()

生成本節點的子節點:generate_ChildNode()

設定節點的屬性:set_NodeAttribute()

訓練成完整的決策樹後,可根據所給樣本的屬性,預測出該樣本的類別:findClass()

class Tree_Node
{
public:
	//建構函式,引數依次為:資料集表(西瓜資料表)、西瓜所有的屬性包括類別、每個屬性可能的取值構成的表、剩餘的未被劃分的屬性
	Tree_Node(map<string, vector<string>> temp_Table, vector<string> temp_Attribute,map<string, vector<string>> data_AttValues, vector<string> temp_remain);
	//生成子節點
	void generate_ChildNode();
	//計算資訊增益 尋找最優劃分屬性
	string findBestAttribute();
	//計算資訊熵
	double calc_Entropy(map<string, vector<string>> temp_Table);
	//設定節點的屬性
	void set_NodeAttribute(string atttribute);
	//根據所給屬性,對資料進行分類
	string findClass(vector<string> attributes);
	virtual ~Tree_Node();
private:
	//屬性個數,不包括類別
	int attribute_Num;
	//本節點的屬性
	string node_Attribute;
	//資料集屬性
	vector<string> data_Attribute;
	//本節點的所有屬性值
	vector<string> node_AttValues;
	//剩餘屬性集
	vector<string>remain_Attributes;
	//子節點,本節點屬性對應的屬性值與子節點地址進行一一對映
	//為空說明該節點為葉節點
	map<string, Tree_Node *> childNode;
	//樣本集合表
	map<string, vector<string>> MyDateTable;
	//定義各屬性對應的屬性值
	map<string, vector<string>> data_AttValues;
};

(3)類的實現

首先是類的建構函式,建構函式主要是對類的成員變數進行初始化:

Tree_Node::Tree_Node(map<string, vector<string>> temp_Table,vector<string> temp_Attribute, map<string, vector<string>> data_AttValues, vector<string> temp_remain)
{
	//全部屬性,包括類別
	data_Attribute = temp_Attribute;
	//屬性個數,不包括類別
	attribute_Num = (int)temp_Attribute.size() - 1;
	//各屬性對應的屬性值
	this->data_AttValues = data_AttValues;
	//屬性表
	MyDateTable = temp_Table;
	//剩餘屬性集
	remain_Attributes = temp_remain;
}

然後是計算資訊熵的成員函式的實現:

//計算資訊熵
double Tree_Node::calc_Entropy(map<string, vector<string>> temp_Table)
{
	map<string, vector<string>> table = temp_Table;
	//資料集中樣本的數量
	int sample_Num = (int)temp_Table[data_Attribute[0]].size();
	//計算資料集中的類別數量
	map<string, int> class_Map;
	for (int i = 0; i < sample_Num; i++)
	{
		//data_Attribute[attribute_Num]對應的就是資料集的類別
		string class_String = table[data_Attribute[attribute_Num]][i];
		class_Map[class_String]++;
	}

	map<string, int>::iterator it = class_Map.begin();
	//存放類別及其對應的數量
	//vector<string> m_Class;
	vector<int> n_Class;
	
	for (; it != class_Map.end(); it++)
	{
		//m_Class.push_back(it->first);
		n_Class.push_back(it->second);
	}
	//計算資訊熵
	double Ent = 0;
	for (int k = 0; k < class_Map.size(); k++)
	{
		//比例
		double p = (double) n_Class[k] / sample_Num;
		if (p == 0)
		{
			//規定了p=0時,plogp=0
			continue;
		}
		//c++中只有log和ln,因此需要應用換底公式
		Ent -= p * (log(p) / log(2));//資訊熵
	}
	
	return Ent;
}

接下來實現資訊增益的計算以及尋找出最優的劃分屬性:

//尋找最優劃分
string Tree_Node::findBestAttribute()
{
	//樣本個數
	int N = (int)MyDateTable[data_Attribute[0]].size();
	//定義用於存放最優屬性
	string best_Attribute;
	//資訊增益
	double gain = 0;
	//對每個剩餘屬性
	for (int i = 0; i < remain_Attributes.size(); i++)
	{
		//定義資訊增益,選取增益最大的屬性來劃分即為最優劃分
		double temp_Gain = calc_Entropy(MyDateTable);//根據公式先將本節點的資訊熵初始化給增益
		//對該屬性的資料集進行分類(獲取各屬性值的資料子集)
		string temp_Att = remain_Attributes[i];//假設選取的屬性
		vector<string> remain_AttValues;//屬性可能的取值
		for (int j = 0; j < data_AttValues[temp_Att].size(); j++)
		{
			remain_AttValues.push_back(data_AttValues[temp_Att][j]);
		}
		
		//對每個屬性值求資訊熵
		for (int k = 0; k < remain_AttValues.size(); k++)
		{
			//屬性值
			string temp_AttValues = remain_AttValues[k];
			int sample_Num = 0;//該屬性值對應樣本數量
			//定義map用來存放該屬性值下的資料子集
			map<string, vector<string>>sub_DataTable;
			for (int l = 0; l < MyDateTable[temp_Att].size(); l++)
			{
				if (temp_AttValues == MyDateTable[temp_Att][l])
				{
					sample_Num++;
					//將符合條件的訓練集存入sub_DataTable
					for (int m = 0; m < data_Attribute.size(); m++)
					{
						sub_DataTable[data_Attribute[m]].push_back(MyDateTable[data_Attribute[m]][l]);
					}
				}
			}
			//累加每個屬值的資訊熵
			temp_Gain -= (double)sample_Num / N * calc_Entropy(sub_DataTable);
		}
		//比較尋找最優劃分屬性
		if (temp_Gain > gain)
		{
			gain = temp_Gain;
			best_Attribute = temp_Att;
		}		
	}

	return best_Attribute;
}

然後是實現如何生成子節點的成員函式,在這之前先實現一個設定節點的屬性的函式,該函式主要用於設定子節點的節點屬性:

void Tree_Node::set_NodeAttribute(string attribute)
{
	//設定節點的屬性
	this->node_Attribute = attribute;
}

生成子節點,生成方式就是按照西瓜書上面的演算法一步一步實現即可,程式碼如下:

void Tree_Node::generate_ChildNode()
{
	//樣本個數
	int N = (int)MyDateTable[data_Attribute[0]].size();
	
	//將資料集中類別種類和數量放入map裡面,只需判斷最後一列即可
	map<string,int> category;
	for (int i = 0; i < N; i++)
	{	
		vector<string> temp_Class;
		temp_Class = MyDateTable[data_Attribute[attribute_Num]];
		category[temp_Class[i]]++;
	}

	//第一種情況
	//只有一個類別,標記為葉節點
	if (1 == category.size())
	{
		map<string, int>::iterator it = category.begin();
		node_Attribute = it->first;
		return;
	}
	//第二種情況
	//先判斷所有屬性是否取相同值
	bool isAllSame = false;
	for (int i = 0; i < remain_Attributes.size(); i++)
	{
		isAllSame = true;
		vector<string> temp;
		temp = MyDateTable[remain_Attributes[i]];
		for (int j = 1; j < temp.size(); j++)
		{
			//只要有一個不同,即可退出
			if (temp[0] != temp[j])
			{
				isAllSame = false;
				break;
			}
		}
		if (isAllSame == false)
		{
			break;
		}
	}
	//若屬性集為空或者樣本中的全部屬性取值相同
	if (remain_Attributes.empty()||isAllSame)
	{
		//找出數量最多的類別及其出現的個數,並將該節點標記為該類
		map<string, int>::iterator it = category.begin();
		node_Attribute = it->first;
		int max = it->second;
		it++;
		for (; it != category.end(); it++)
		{
			int num = it->second;
			if (num > max)
			{
				node_Attribute = it->first;
				max = num;
			}
		}
		return;
	}
	//第三種情況
	//從remian_attributes中劃分最優屬性
	string best_Attribute = findBestAttribute();
	//將本節點設定為最優屬性
	node_Attribute = best_Attribute;
	//對最優屬性的每個屬性值
	for (int i = 0; i < data_AttValues[best_Attribute].size(); i++)
	{
		string best_AttValues = data_AttValues[best_Attribute][i];
		//計算屬性對應的資料集D
		//定義map用來存放該屬性值下的資料子集
		map<string, vector<string>> sub_DataTable;
		for (int j = 0; j < MyDateTable[best_Attribute].size(); j++)
		{
			//尋找最優屬性在資料集中屬性值相同的資料樣本
			if (best_AttValues == MyDateTable[best_Attribute][j])
			{
				//找到對應的資料集,存入子集中sub_DataTable(該樣本的全部屬性都要存入)
				for (int k = 0; k < data_Attribute.size(); k++)
				{
					sub_DataTable[data_Attribute[k]].push_back(MyDateTable[data_Attribute[k]][j]);
				}
			}
		}
		//若子集為空,將分支節點(子節點)標記為葉節點,類別為MyDateTable樣本最多的類
		if (sub_DataTable.empty())
		{
			//生成子節點
			Tree_Node * p = new Tree_Node(sub_DataTable, data_Attribute, data_AttValues, remain_Attributes);
			//找出樣本最多的類,作為子節點的屬性
			map<string, int>::iterator it = category.begin();
			string childNode_Attribute = it->first;
			int max_Num = it->second;
			it++;
			for (; it != category.end(); it++)
			{
				if (it->second > max_Num)
				{
					max_Num = it->second;
					childNode_Attribute = it->first;
				}
			}
			//設定子葉節點屬性
			p->set_NodeAttribute(childNode_Attribute);
			//將子節點存入childNode,預測樣本的時候會用到
			childNode[best_AttValues] = p;
		}
		else//若不為空,則從剩餘屬性值剔除該屬性,呼叫generate_ChildNode繼續往下細分
		{
			vector<string> child_RemainAtt;
			child_RemainAtt = remain_Attributes;
			//找出child_RemainAtt中的與該最佳屬性相等的屬性
			vector<string>::iterator it = child_RemainAtt.begin();
			for (; it != child_RemainAtt.end(); it++)
			{
				if (*it == best_Attribute)
				{
					break;
				}
			}
			//刪除
			child_RemainAtt.erase(it);

			//生成子節點
			Tree_Node * pt = new Tree_Node(sub_DataTable, data_Attribute, data_AttValues, child_RemainAtt);
			//將子節點存入childNode
			childNode[best_AttValues] = pt;
			//子節點再呼叫generate_ChildNode函式
			pt->generate_ChildNode();
		}
	}

}

在最後,我們必須有個預測函式,即如果輸入一個樣本,需要給出該樣本的是什麼類別的。比如:給出西瓜資料:青綠 蜷縮 濁響 清晰 凹陷 硬滑,則呼叫該函式應該輸出:“好瓜”。該函式實現如下:

//輸入為待預測樣本的所有屬性集合
string Tree_Node::findClass(vector<string> attributes)
{
	//若存在子節點
	if (childNode.size() != 0)
	{
		//找出輸入的樣例中與本節點屬性對應的屬性值,以便尋找下個節點,直到找到葉節點
		string attribute_Value;
		for (int i = 0; i < data_AttValues[node_Attribute].size(); i++)
		{
			for (int j = 0; j < attributes.size(); j++)
			{
				//data_AttValues[node_Attribute]為屬性node_Attribute對應的所有可能的取值集合
				if (attributes[j] == data_AttValues[node_Attribute][i])
				{
					//找到了樣例對應的屬性值
					attribute_Value = attributes[j];
					break;
				}
			}
			//找到後就沒必要繼續迴圈了
			if (!attribute_Value.empty())
			{
				break;
			}
		}
		//找出該屬性值對應的子節點的地址,以便進行訪問
		Tree_Node *p = childNode[attribute_Value];
		return p->findClass(attributes);//遞迴尋找,直到找到葉節點為止
	}
	else//不存在子節點說明已經找到分類,類別為本節點的node_Attribute
	{
		return node_Attribute;
	}
}

3、預測結果

先放上我的主函式,主函式主要是呼叫函式,然後格式化輸出。

int main()
{
	//輸入
	data_Input();
	Tree_Node myTree(data_Table, data_Attributes, data_AttValues, remain_Attributes);
	//進行訓練
	myTree.generate_ChildNode();
	//輸入預測樣例,進行預測
	vector<string> predict_Sample;
	string input_Line, temp;
	cout << "請輸入屬性進行預測:" << endl;
	getline(cin, input_Line);
	istringstream input_Sample(input_Line);
	while (input_Sample >> temp)
	{
		//將輸入預測樣例的屬性都存入predict_Sample,以便傳參
		predict_Sample.push_back(temp);
	}
	cout << endl;
	//預測
	cout << "分類結果為:" << myTree.findClass(predict_Sample) << endl;
	system("pause");
	return 0;
}

執行結果:

可以看出,預測結果是正確的。但真正的預測應該不能取樣本中已有的樣例,我這樣做是為了驗證程式的正確性。

好了,文章到此就結束了。關於程式的相關優化和剪枝操作以後有時間再來完善。下面附上源程式程式碼供下載:https://download.csdn.net/download/m0_37543178/10793382。(ps:沒有積分的可以直接找我給你原始碼:))