1. 程式人生 > >機器學習(周志華) 參考答案 第四章 決策樹 python重寫版與畫樹演算法

機器學習(周志華) 參考答案 第四章 決策樹 python重寫版與畫樹演算法

機器學習(周志華西瓜書) 參考答案 總目錄

機器學習(周志華) 參考答案 第四章 決策樹

3.試程式設計實現基於資訊熵進行劃分選擇的決策樹演算法,併為表4.3中資料生成一棵決策樹。
最近在學著用python,所以用py重寫了以前的決策樹程式碼,在寫的過程中還是發現matlab在矩陣操作上要方便很多。

這次的程式碼為了更好的適用性直接用原文字作為輸入,在這裡規定如果樣本的某屬性值是數字則認定為連續型,其他認為是離散型。

輸入的第一行是每個屬性的標籤,第二行開始是每本個樣本具體的屬性值,每個屬性值由’,’隔開。
在函式傳入時有輸入是否有序號,用haveID標記
決策樹的屬性選擇使用了資訊增益,沒實現其他方法

程式碼與輸入壓縮包

輸入示例

編號,色澤,根蒂,敲聲,紋理,臍部,觸感,密度,含糖率,好壞
1,青綠,蜷縮,濁響,清晰,凹陷,硬滑,0.697,0.46,好瓜
2,烏黑,蜷縮,沉悶,清晰,凹陷,硬滑,0.744,0.376,好瓜
3,烏黑,蜷縮,濁響,清晰,凹陷,硬滑,0.634,0.264,好瓜
4,青綠,蜷縮,沉悶,清晰,凹陷,硬滑,0.608,0.318,好瓜
5,淺白,蜷縮,濁響,清晰,凹陷,硬滑,0.556,0.215,好瓜
6,青綠,稍蜷,濁響,清晰,稍凹,軟粘,0.403,0.237,好瓜
7,烏黑,稍蜷,濁響,稍糊,稍凹,軟粘,0.481,0.149,好瓜
8,烏黑,稍蜷,濁響,清晰,稍凹,硬滑,0.437,0.211,好瓜
9,烏黑,稍蜷,沉悶,稍糊,稍凹,硬滑,0.666,0.091,壞瓜
10,青綠,硬挺,清脆,清晰,平坦,軟粘,0.243,0.267,壞瓜
11,淺白,硬挺,清脆,模糊,平坦,硬滑,0.245,0.057,壞瓜
12,淺白,蜷縮,濁響,模糊,平坦,軟粘,0.343,0.099,壞瓜
13,青綠,稍蜷,濁響,稍糊,凹陷,硬滑,0.639,0.161,壞瓜
14,淺白,稍蜷,沉悶,稍糊,凹陷,硬滑,0.657,0.198,壞瓜
15,烏黑,稍蜷,濁響,清晰,稍凹,軟粘,0.36,0.37,壞瓜
16,淺白,蜷縮,濁響,模糊,平坦,硬滑,0.593,0.042,壞瓜
17,青綠,蜷縮,沉悶,稍糊,稍凹,硬滑,0.719,0.103,壞瓜

決策樹輸出樹的陣列結構,節點按樹的先根遍歷排序,每個節點4個屬性值,依次為

  1. 父節點的index,規定根節點該位置的index指向自己
  2. 如果節點是葉節點,記錄分類;如果是非葉節點,記錄當前位置最佳的劃分屬性標籤
  3. 記錄所有的非根節點連線父節點的具體屬性值,沒有則為空
  4. 如果該屬性是連續屬性,則記錄閾值,沒有則為空。

決策樹輸出示例

[[0, '紋理', [], []],
[0, '密度', '清晰', []], 
[1, '壞瓜', '小於', 0.3815], 
[1, '好瓜', '大於', 0.3815], 
[0, '觸感', '稍糊', []], 
[4, '壞瓜', '硬滑', []], 
[4
, '好瓜', '軟粘', []], [0, '壞瓜', '模糊', []]]

然後設計了一個m叉樹的繪製演算法,並改造成決策樹的繪製演算法(最好屬性的標籤在2個字,並沒有去根據屬性標籤的長度動態畫方框),繪圖工具使用的是matplotlib。

繪製方法採用層次遍歷,自底向上繪製節點,並要求:

  1. 所有節點與線段之間沒有重合
  2. 每個節點有個最小距離限制
  3. 如果是非葉節點,則它的位置在它子節點正中間
  4. 同一層的節點在同一水平線上(這條matlab畫的太蛋疼)

生成的樹與決策樹示例

通過西瓜資料集3得到的決策樹樹形

這裡寫圖片描述

通過西瓜資料集3得到的決策樹

這裡寫圖片描述

通過西瓜資料集2得到的決策樹

這裡寫圖片描述

這裡畫樹沒去注意樹的寬度,可以進一步美化,將相同父節點的節點按子節點順序排放,讓子節點少的節點更靠外。不過沒有去實現–

決策樹的遞迴演算法與屬性選擇演算法書上講的很詳細,這裡不再重複,具體內容看程式碼註釋

# -*- coding: utf-8 -*-
"""
Created on Mon Jan 16 12:01:08 2017

@author: icefire
"""

from dtreeplot import dtreeplot
import math

#屬性類
class property:
    def __init__(self,idnum,attribute): 
        self.is_continuity=False     #連續型屬性標記
        self.attribute=attribute     #屬性標籤
        self.subattributes=[]     #屬性子標籤
        self.id=idnum     #屬性排在輸入文字的第幾位
        self.index={}     #屬性子標籤的索引值

#決策樹生成類
class dtree():
    '''
    建構函式
    filename:輸入檔名
    haveID:輸入是否帶序號
    property_set:為空則計算全部屬性,否則記錄set中的屬性
    '''
    def __init__(self,filename,haveID,property_set):

        self.data=[]
        self.data_property=[]
        #讀入資料
        self.__dataread(filename,haveID)
        #判斷選擇的屬性集合
        if len(property_set)>0:
            tmp_data_property=[]
            for i in property_set:
                tmp_data_property.append(self.data_property[i])
            tmp_data_property.append(self.data_property[-1])
        else:
            tmp_data_property=self.data_property

        #決策樹樹形陣列結構
        self.treelink=[]

        #決策樹主遞迴
        self.__TreeGenerate(range(0,len(self.data[-1])),tmp_data_property,0,[],[])

        #決策樹繪製
        dtreeplot(self.treelink,6,1,-6)

    '''
    決策樹主遞迴
    data_set:當前樣本集合
    property_set:當前熟悉集合
    father:父節點索引值
    attribute:父節點連線當前節點的子屬性值
    threshold:如果是連續引數就是閾值,否則為空
    '''    
    def __TreeGenerate(self,data_set,property_set,father,attribute,threshold):
        #新增一個節點
        self.treelink.append([])
        #新節點的位置
        curnode=len(self.treelink)-1
        #記錄新節點的父親節點
        self.treelink[curnode].append(father)

        #結束條件1:所有樣本同一分類
        current_data_class=self.__count(data_set,property_set[-1])
        if(len(current_data_class)==1):
            self.treelink[curnode].append(self.data[-1][data_set[0]])
            self.treelink[curnode].append(attribute)
            self.treelink[curnode].append(threshold)
            return

        #結束條件2:所有樣本相同屬性,選擇分類數多的一類作為分類
        if all(len(self.__count(data_set,property_set[i]))==1 for i in range(0,len(property_set)-1)):
            max_count=-1;
            for dataclass in property_set[-1].subattributes:
                if current_data_class[dataclass]>max_count:
                    max_attribute=dataclass
                    max_count=current_data_class[dataclass] 
            self.treelink[curnode].append(max_attribute)
            self.treelink[curnode].append(attribute)
            self.treelink[curnode].append(threshold)
            return

        #資訊增益選擇最優屬性與閾值            
        prop,threshold = self.__entropy_paraselect(data_set,property_set)

        #記錄當前節點的最優屬性標籤與父節點連線當前節點的子屬性值     
        self.treelink[curnode].append(prop.attribute)
        self.treelink[curnode].append(attribute)

        #從屬性集合中移除當前屬性
        property_set.remove(prop)

        #判斷是否是連續屬性
        if(prop.is_continuity):
            #連續屬性分為2子屬性,大於和小於
            tmp_data_set=[[],[]]
            for i in data_set:
                tmp_data_set[self.data[prop.id][i]>threshold].append(i)
            for i in [0,1]:
                self.__TreeGenerate(tmp_data_set[i],property_set[:],curnode,prop.subattributes[i],threshold)
        else:
            #離散屬性有多子屬性
            tmp_data_set=[[] for i in range(0,len(prop.subattributes))]
            for i in data_set:
                tmp_data_set[prop.index[self.data[prop.id][i]]].append(i)

            for i in range(0,len(prop.subattributes)):
                if len(tmp_data_set[i])>0:
                    self.__TreeGenerate(tmp_data_set[i],property_set[:],curnode,prop.subattributes[i],[])
                else:
                    #如果某一個子屬性不存沒有對應的樣本,則選擇父節點分類更多的一項作為分類
                    self.treelink.append([])
                    max_count=-1;
                    tnode=len(self.treelink)-1
                    for dataclass in property_set[-1].subattributes:
                        if current_data_class[dataclass]>max_count:
                            max_attribute=dataclass
                            max_count=current_data_class[dataclass]
                    self.treelink[tnode].append(curnode)
                    self.treelink[tnode].append(max_attribute)
                    self.treelink[tnode].append(prop.subattributes[i])
                    self.treelink[tnode].append(threshold)    

        #為沒有4個值得節點用空列表補齊4個值                
        for i in range(len(self.treelink[curnode]),4):
            self.treelink[curnode].append([])

    '''
    資訊增益算則最佳屬性
    data_set:當前樣本集合
    property_set:當前屬性集合
    '''        
    def __entropy_paraselect(self,data_set,property_set):
        #分離散和連續型分別計算資訊增益,選擇最大的一個
        max_ent=-10000
        for i in range(0,len(property_set)-1):
            prop_id=property_set[i].id
            if(property_set[i].is_continuity):
                   tmax_ent=-10000 
                   xlist=self.data[prop_id][:]
                   xlist.sort()
                   #連續型求出相鄰大小值的平局值作為待選的最佳閾值
                   for j in range(0,len(xlist)-1):
                       xlist[j]=(xlist[j+1]+xlist[j])/2
                   for j in range(0,len(xlist)-1):                   
                       if(i>0 and xlist[j]==xlist[j-1]):
                           continue
                       cur_ent = 0
                       nums=[[0,0],[0,0]]
                       for k in data_set:
                           nums[self.data[prop_id][k]>xlist[j]][property_set[-1].index[self.data[-1][k]]]+=1
                       for k in [0,1]:
                           subattribute_sum=nums[k][0]+nums[k][1]
                           if(subattribute_sum > 0):
                               p=nums[k][0]/subattribute_sum
                               cur_ent +=(p*math.log(p+0.00001,2)+(1-p)*math.log(1-p+0.00001,2))*subattribute_sum/len(data_set)
                       if(cur_ent>tmax_ent):
                           tmax_ent = cur_ent
                           tmp_threshold = xlist[j]
                   if(tmax_ent > max_ent):
                        max_ent=tmax_ent;
                        bestprop = property_set[i];
                        best_threshold = tmp_threshold;
            else:
                #直接統計並計算
                cur_ent=0
                nums=[[0,0] for i in range(0,len(property_set[i].subattributes))]
                for j in data_set:
                    nums[property_set[i].index[self.data[prop_id][j]]][property_set[-1].index[self.data[-1][j]]]+=1
                for j in range(0,len(property_set[i].subattributes)):
                    subattribute_sum=nums[j][0]+nums[j][1]
                    if(subattribute_sum>0):
                        p=nums[j][0]/subattribute_sum
                        cur_ent += (p*math.log(p+0.00001,2)+(1-p)*math.log(1-p+0.00001,2))*subattribute_sum/len(data_set)
                if(cur_ent > max_ent):
                    max_ent=cur_ent;
                    bestprop = property_set[i];
                    best_threshold = [];                                   

        return bestprop,best_threshold

    '''
    計算當前樣本在某個屬性下的分類情況
    '''    
    def __count(self,data_set,prop):
        out={}

        rowdata=self.data[prop.id]
        for i in data_set:
            if rowdata[i] in out:
                out[rowdata[i]]+=1
            else:
                out[rowdata[i]]=1;

        return out

    '''
    輸入資料處理
    '''        
    def __dataread(self,filename,haveID):
        file = open(filename, 'r')
        linelen=0
        first=1;    
        while 1:
            #按行讀
            line = file.readline()

            if not line:
                break

            line=line.strip('\n')
            rowdata = line.split(',')
            #如果有編號就去掉第一列
            if haveID:
                del rowdata[0]

            if(linelen==0):
                #處理第一行,初始化屬性類物件,記錄屬性的標籤
                for i in range(0,len(rowdata)):
                    self.data.append([])
                    self.data_property.append(property(i,rowdata[i]))
                    self.data_property[i].attribute=rowdata[i]
                linelen=len(rowdata)
            elif(linelen==len(rowdata)):
                if(first==1):
                    #處理第二行,記錄屬性是否是連續型和子屬性
                    for i in range(0,len(rowdata)):           
                        if(isnumeric(rowdata[i])):
                            self.data_property[i].is_continuity=True
                            self.data[i].append(float(rowdata[i]))
                            self.data_property[i].subattributes.append("小於")
                            self.data_property[i].index["小於"]=0
                            self.data_property[i].subattributes.append("大於")
                            self.data_property[i].index["大於"]=1
                        else:
                            self.data[i].append(rowdata[i])
                else:
                    #處理後面行,記錄子屬性
                    for i in range(0,len(rowdata)):         
                        if(self.data_property[i].is_continuity):
                            self.data[i].append(float(rowdata[i]))
                        else:
                            self.data[i].append(rowdata[i])
                            if rowdata[i] not in self.data_property[i].subattributes:
                                self.data_property[i].subattributes.append(rowdata[i])
                                self.data_property[i].index[rowdata[i]]=len(self.data_property[i].subattributes)-1
                first=0           
            else:
                continue
'''
判斷是否是數字
'''
def isnumeric(s):
    return all(c in "0123456789.-" for c in s)

filename="西瓜3.data"            
property_set=range(0,6)
link=dtree(filename,True,property_set)

下面是決策樹的繪製類,步驟是:

  1. 處理決策樹生成的陣列結構,生成樹形結構與層次結構
  2. 根據層次結構從下往上繪製
  3. 對於每層,首先求非葉葉節點的初始位置,值為它最邊緣兩個子節點中間;然後計算非葉節點
  4. 對於每層第一個非葉節點前和最後一個非葉節點後的節點,直接用最小距離繪製。
  5. 對於兩個非葉節點中間的葉節點,如果兩個非葉節點中間足夠大,則中間的葉節點均勻分佈,否則按最小距離繪製,並記錄非葉節點的新位置,計算出相對於初始位置的偏移。

程式碼如下:

# -*- coding: utf-8 -*-
"""
Created on Sun Jan 15 10:19:20 2017

@author: icefire
"""

import numpy as np
from matplotlib import pyplot as plt

'''
樹的節點類
data:樹的陣列結構的一項,4值
height:節點的高
'''
class treenode:
    def __init__(self,data,height):
        self.father=data[0]     #父節點
        self.children=[]     #子節點列表
        self.data=data[1]    #節點標籤
        self.height=height
        self.pos=0;         #節點計算時最終位置,計算時只儲存相對位置
        self.offset=0;       #節點最終位置與初始位置的相對值
        self.data_to_father=data[2]     #連結父節點的屬性值
        #如果有閾值,則加入閾值
        if type(data[3])!=list:
            self.data_to_father=self.data_to_father+str(data[3]);

'''
樹的繪製類
link:樹的陣列結構
minspace:節點間的距離
r:節點的繪製半徑
lh:層高
'''        
class dtreeplot:
    def __init__(self,link,minspace,r,lh):

        s=len(link)
        #所有節點的列表,第一項為根節點
        treenodelist=[]
        #節點的層次結構
        treelevel=[]

        #處理樹的陣列結構
        for i in range(0,s):
            #根節點的index與其父節點的index相同
            if link[i][0]==i:
                treenodelist.append(treenode(link[i],0))
            else:
                treenodelist.append(treenode(link[i],treenodelist[link[i][0]].height+1))
                treenodelist[link[i][0]].children.append(treenodelist[i]);
                treenodelist[i].father=treenodelist[link[i][0]];
            #如果有新一層的節點則新建一層
            if len(treelevel)==treenodelist[i].height:
                treelevel.append([])
            treelevel[treenodelist[i].height].append(treenodelist[i])

        #控制繪製圖像的座標軸
        self.right=0
        self.left=0
        #反轉層次,從底往上畫
        treelevel.reverse()  
        #計算每個節點的位置
        self.__calpos(treelevel,minspace)   
        #繪製樹形                          
        self.__drawtree(treenodelist[0] ,r,lh,0)
        plt.xlim(xmin=self.left,xmax=self.right+minspace)
        plt.ylim(ymin=len(treelevel)*lh+lh/2,ymax=lh/2)
        plt.show()

    '''
    逐一繪製計算每個節點的位置
    nodes:節點集合
    l,r:左右區間
    start:當前層的初始繪製位置
    minspace:使用的最小間距
    '''      
    def __calonebyone(self,nodes,l,r,start,minspace):
        for i in range(l,r):
                nodes[i].pos=max(nodes[i].pos,start)
                start=nodes[i].pos+minspace;
        return start;

    '''
    計算每個節點的位置與相對偏移
    treelevel:樹的層次結構
    minspace:使用的最小間距
    '''
    def __calpos(self,treelevel,minspace):
        #按層次畫
        for nodes in treelevel:
            #記錄非葉節點
            noleaf=[]
            num=0;
            for node in nodes:
                if len(node.children)>0:
                    noleaf.append(num)
                    node.pos=(node.children[0].pos+node.children[-1].pos)/2
                num=num+1

            start=minspace

            #如果全是非葉節點,直接繪製                        
            if(len(noleaf))==0:
                self.__calonebyone(nodes,0,len(nodes),0,minspace)
            else:
                start=nodes[noleaf[0]].pos-noleaf[0]*minspace
                self.left=min(self.left,start-minspace)
                start=self.__calonebyone(nodes,0,noleaf[0],start,minspace)
                for i in range(0,len(noleaf)):
                    nodes[noleaf[i]].offset=max(nodes[noleaf[i]].pos,start)-nodes[noleaf[i]].pos
                    nodes[noleaf[i]].pos=max(nodes[noleaf[i]].pos,start)

                    if(i<len(noleaf)-1):
                            #計算兩個非葉節點中間的間隔,如果足夠大就均勻繪製
                            dis=(nodes[noleaf[i+1]].pos-nodes[noleaf[i]].pos)/(noleaf[i+1]-noleaf[i])
                            start=nodes[noleaf[i]].pos+max(minspace,dis)
                            start=self.__calonebyone(nodes,noleaf[i]+1,noleaf[i+1],start,max(minspace,dis))
                    else:
                         start=nodes[noleaf[i]].pos+minspace
                         start=self.__calonebyone(nodes,noleaf[i]+1,len(nodes),start,minspace)

    '''
    採用先根遍歷繪製樹
    treenode:當前遍歷的節點
    r:半徑
    lh:層高
    curoffset:每層節點的累計偏移
    '''                     
    def __drawtree(self,treenode,r,lh,curoffset):
         #加上當前的累計偏差得到最終位置
         treenode.pos=treenode.pos+curoffset

         if(treenode.pos>self.right):
             self.right=treenode.pos

         #如果是葉節點則畫圈,非葉節點畫方框
         if(len(treenode.children)>0):
             drawrect(treenode.pos,(treenode.height+1)*lh,r)
             plt.text(treenode.pos, (treenode.height+1)*lh, treenode.data+'=?', color=(0,0,1),ha='center', va='center')
         else:
             drawcircle(treenode.pos,(treenode.height+1)*lh,r)
             plt.text(treenode.pos, (treenode.height+1)*lh, treenode.data, color=(1,0,0),ha='center', va='center')

         num=0;
         #先根遍歷
         for node in treenode.children:
             self.__drawtree(node,r,lh,curoffset+treenode.offset)

             #繪製父節點到子節點的連線
             num=num+1

             px=(treenode.pos-r)+2*r*num/(len(treenode.children)+1)
             py=(treenode.height+1)*lh-r-0.02

             #按到圓到方框分開畫
             if(len(node.children)>0):
                 px1=node.pos
                 py1=(node.height+1)*lh+r
                 off=np.array([px-px1,py-py1])
                 off=off*r/np.linalg.norm(off)

             else:
                 off=np.array([px-node.pos,-lh+1])
                 off=off*r/np.linalg.norm(off)
                 px1=node.pos+off[0]
                 py1=(node.height+1)*lh+off[1]

             #計算父節點與子節點連線的方向與角度
             plt.plot([px,px1],[py,py1],color=(0,0,0))  
             pmx=(px1+px)/2-(1-2*(px<px1))*0.4     
             pmy=(py1+py)/2+0.4
             arc=np.arctan(off[1]/(off[0]+0.0000001))
             #繪製文字以及旋轉
             plt.text(pmx,pmy, node.data_to_father, color=(1,0,1),ha='center', va='center', rotation=arc/np.pi*180)

'''
畫圓
'''          
def drawcircle(x,y,r):
     theta = np.arange(0, 2 * np.pi, 2 * np.pi / 1000)
     theta = np.append(theta, [2 * np.pi])
     x1=[]
     y1=[]
     for tha in theta:
         x1.append(x + r * np.cos(tha))
         y1.append(y + r * np.sin(tha))
     plt.plot(x1, y1,color=(0,0,0))

'''
畫矩形
'''       
def drawrect(x,y,r):
     x1=[x-r,x+r,x+r,x-r,x-r]
     y1=[y-r,y-r,y+r,y+r,y-r]
     plt.plot(x1, y1,color=(0,0,0))