1. 程式人生 > >決策樹的實現原理與matlab程式碼

決策樹的實現原理與matlab程式碼

很久不寫部落格了,感覺很長一段時間只是一味的看書,疏不知一味地看書、寫程式碼會導致自己的思考以及總結能力變得衰弱。所以,我決定還是繼續寫部落格。廢話不多說了,今天想主要記錄資料探勘中的決策樹。希望能夠將自己的理解寫得通俗易懂。

決策樹是一種對例項分類的樹形結構,樹中包含葉子節點與內部節點。內部節點主要是資料中的某一特性,葉子節點是根據資料分析後的最後結果。

先看一組資料:


這組資料的特性包含年齡、工作與否、是否有房、信貸情況以及最終分類結果貸款是否成功,共包含15組樣例。

構建數的形狀可以有多種,如下:

        

如果隨意構建樹,那將會導致有的構建樹比較龐大,分類時代價比較大,有的構建樹比較小,分類代價小。

比如針對是否有房這一列,發現如果樣本有房這一列為是,最終分類結果便是可以貸款,而不需要判斷其他的特性,便可以獲得最終部分分類結果。

因此,構建樹需要以最小的代價實現最快的分類。根據何種標準進行判別呢?

在資訊理論與概率統計中,熵是表示隨機變數不確定的量度,設x是一個取有限個值的離散隨機變數,其概率分佈為:

則隨機變數x的熵定義為


熵越大,其不確定性越大。

隨機變數x在給定條件y下的條件熵為H(y|x);

資訊增益表示得知特徵x的資訊而使得y類資訊的不確定減少的程度。

因此,特徵A對訓練集D的資訊增益g(D,A),定義為集合D的熵H(D)與特徵A給定條件下D的條件熵H(D|A)之差,即


對錶5.1給定的訓練資料集D,計算各特徵對其的資訊增益,分別以A1,A2,A3,A4表示年齡,有工作,有自己的房子和信貸情況四個特徵,則

(1)


(2)


這裡D1,D2,D3分別是D中A1取為青年、中年、老年的樣本子集,同理,求得其他特徵的資訊增益:




接下來根據之前的資訊增益,對決策樹進行生成,這裡主要使用ID3演算法,C4.5演算法與之類似,只是將資訊增益衡量轉為資訊增益比衡量。

主要方法如下:

從根節點開始,對節點計算所有可能的特徵的資訊增益,選擇資訊增益最大的特徵作為該節點的特徵,由該特徵的不同取值建立子節點,再對子節點遞迴呼叫以上方法,構建決策樹。

那麼遞迴何時停止呢?當訓練集中所有例項屬於同一類時,或者所有特徵都選擇完畢時,或者資訊增益小於某個閾值時,則停止遞迴,。

舉例來說,根據之前對錶5.1的熵的計算,由於A3(是否有自己的房子)資訊增益最大,所以以A3為決策樹的根節點的特徵,它將資料集分為兩個子集D1(A3取是)和D2(A3取否),由於D1的分類結果都是可以貸款,所以它成為葉節點,對於D2,則從特徵A1,A2,A4這三個特徵中重新選擇特徵,計算各個特徵的資訊增益:


因此選擇A2作為子樹節點,針對A2是否有工作這個特徵,根據樣本分類結果發現有工作與無工作各自的樣本都屬於同一類,因此將有工作與無工作作為子樹的葉節點。這樣便生成如下的決策樹:


決策樹生成演算法遞迴的產生決策樹,往往對訓練資料分類準確,但對未知資料卻沒那麼準確,即會出現過擬合狀況。解決這個問題可以通過決策樹的剪枝,讓決策樹簡化。本文暫不對決策樹的剪枝進行詳細描述。

接下來,即對決策樹實現的matlab程式碼:

1、首先,定義決策樹的資料結構

tree

{

int pro    //是葉節點(0表示)還是內部節點(1表示)

int value //如果是葉節點,則表示具體的分類結果,如果是內部節點,則表示某個特徵

int parentpro //如果該節點有父節點,則該值表示父節點所表示特徵的具體屬性值

 tree  child[]  //表示該節點的子樹陣列

}

2、根據訓練集資料通過遞迴形成樹:

function tree = maketree(featurelabels,trainfeatures,targets,epsino)
tree=struct('pro',0,'value',-1,'child',[],'parentpro',-1);
[n,m] = size(trainfeatures); %where n represent total numbers of features,m represent total numbers of examples
cn = unique(targets);%different classes
l=length(cn);%totoal numbers of classes
if l==1%if only one class,just use the class to be the lable of the tree and return
    tree.pro=0;%reprensent leaf
    tree.value = cn;
    tree.child=[];
    return
end
if n==0% if feature number equals 0
    H = hist(targets, length(cn)); %histogram of class
   [ma, largest] = max(H); %ma is the number of class who has largest number,largest is the posion in cn
   tree.pro=0;
   tree.value=cn(largest);
   tree.child=[];
   return
end


pnode = zeros(1,length(cn));
%calculate info gain
for i=1:length(cn)
    pnode(i)=length(find(targets==cn(i)))/length(targets);
end
H=-sum(pnode.*log(pnode)/log(2));
maxium=-1;
maxi=-1;
g=zeros(1,n);
for i=1:n
    fn=unique(trainfeatures(i,:));%one feature has fn properties
    lfn=length(fn);
    gf=zeros(1,lfn);
    fprintf('feature numbers:%d\n',lfn);
    for k=1:lfn
        onefeature=find(fn(k)==trainfeatures(i,:));%to each property in feature,,calucute the number of this property
        for j=1:length(cn)
            oneinonefeature=find(cn(j)==targets(:,onefeature));
            ratiofeature=length(oneinonefeature)/length(onefeature);
            fprintf('feature %d, property %d, rationfeature:%f\n',i, fn(k),ratiofeature);
            if(ratiofeature~=0)
                gf(k)=gf(k)+(-ratiofeature*log(ratiofeature)/log(2));
            end
        end  
        ratio=length(onefeature)/m;
        gf(k)=gf(k)*ratio;
    end
    g(i)=H-sum(gf);
    fprintf('%f\n',g(i));
    if g(i)>maxium
        maxium=g(i);
        maxi=i;
    end
end


if maxium<epsino
    H = hist(targets, length(cn)); %histogram of class
   [ma, largest] = max(H); %ma is the number of class who has largest number,largest is the posion in cn
   tree.pro=0;
   tree.value=cn(largest);
   tree.child=[];
   return
end


tree.pro=1;%1 represent it's a inner node,0 represents it's a leaf
tv=featurelabels(maxi);
tree.value=tv;
tree.child=[];
featurelabels(maxi)=[];


%split data according feature
[data,target,splitarr]=splitData(trainfeatures,targets,maxi);
%tree.child=zeros(1,length(data));
%build child tree;
fprintf('split data into %d\n',length(data));
for i=1:length(data)
   disp(data(i));
   fprintf('\n');
   disp(target(i));
   fprintf('\n');
end
fprintf('\n');


for i=1:size(data,1)
    result = zeros(size(data{i}));
    result=data{i};
    temptree=maketree(featurelabels,result,target{i},0);
    tree.pro=1;%1 represent it's a inner node,0 represents it's a leaf
    tree.value=tv;
    tree.child(i)=temptree;
    tree.child(i).parentpro = splitarr(i);
    fprintf('temp tree\n');
    disp(tree.child(1));
    fprintf('\n');
end
disp(tree);
fprintf("now root tree,tree has %d childs\n",size(tree.child,2));
fprintf('\n');
for i=1:size(data,1)
    disp(tree.child(i));
    fprintf('\n');
end
fprintf('one iteration ends\n');
end

3、根據某個特徵,將資料集分成若干子資料集

function [data,target,splitarr]=splitData(oldData,oldtarget,splitindex)
fn=unique(oldData(splitindex,:));
data=cell(length(fn),1);

target=cell(length(fn),1);
splitarr=zeros(size(fn));
for i=1:length(fn)
    fcolumn=find(oldData(splitindex,:)==fn(i));
    data(i) =oldData(:,fcolumn);
    target(i) = oldtarget(:,fcolumn);
    data{i}(splitindex,:)=[];
    splitarr(i)=fn(i);
end    
end

4、列印決策樹

function printTree(tree)
if tree.pro==0
    fprintf('(%d)',tree.value);
    if tree.parentpro~=-1
        fprintf('its parent feature value:%d\n',tree.parentpro);
    end
    return
end
fprintf('[%d]\n',tree.value);
if tree.parentpro~=-1
    fprintf('its parent feature value:%d\n',tree.parentpro);
end
fprintf('its subtree:\n');
childset = tree.child;
for i=1:size(childset,2)
    printTree(childset(i));
end
fprintf('\n');
fprintf('its subtree end\n');
end

5、對某個具體的樣本進行結果預測

function result=classify(data, tree)
while tree.pro==1
    childset=tree.child;
    v=tree.value;
    for i=1:size(childset,2)
        child = childset(i);
        if child.parentpro==data(v);
            tree=child;
            break;
        end
    end
end
result=tree.value;
end

接下來對資料用程式碼進行測試

clear all; close all; clc
featurelabels=[1,2,3,4];
trainfeatures=[1,1,1,1,1,2,2,2,2,2,3,3,3,3,3;%each row of trainfeature represent one feature and each column reprensent each examples 
                0,0,1,1,0,0,0,1,0,0,0,0,1,1,0;
                0,0,0,1,0,0,0,1,1,1,1,1,0,0,0;
                1,2,2,1,1,1,2,2,3,3,3,2,2,3,1
                ];
targets=[0,0,1,1,0,0,0,1,1,1,1,1,1,1,0];%represent classification results according to trainfeatures
tree=maketree(featurelabels,trainfeatures,targets,0);
printTree(tree);
data=[2,0,0,1];
result=classify(data,tree);
fprintf('The result is %d\n',result);

關於決策樹的原理構建大概就結束了,後期可以繼續完成對決策樹的剪枝或者將決策樹由多叉樹轉化為二叉樹,讓決策樹更加高效矮小。原始碼地址:https://github.com/summersunshine1/datamining。