1. 程式人生 > >機器學習演算法的Python實現 (3):決策樹剪枝處理

機器學習演算法的Python實現 (3):決策樹剪枝處理

更新,經評論提醒,我發現自己搞錯了比較根本的定義。CART決策樹假設決策樹是二叉樹,這裡給出的程式碼生成的決策樹不是二叉樹。所以下面的程式碼用”基於基尼指數生成的決策樹“來形容更加適當一點。

-------------------------------------------

本文資料參照 機器學習-周志華 一書中的決策樹一章。可作為此章課後習題4的答案

程式碼則參照《機器學習實戰》一書的內容,並做了一些修改。

CART決策樹 使用基尼指數(Gini Index)來選擇劃分屬性。其公式如下:

本文內容包括未剪枝CART決策樹、預剪枝CART決策樹以及後剪枝決策樹

本文使用的Python庫包括

  • numpy
  • pandas
  • math
  • operator
  • matplotlib
  • copy
  • re
本文使用的資料如下
Idx color root knocks texture navel touch label
1 dark_green curl_up little_heavily distinct sinking hard_smooth 1
2 black curl_up heavily distinct sinking hard_smooth 1
3 black curl_up little_heavily distinct sinking hard_smooth 1
6 dark_green little_curl_up
little_heavily distinct little_sinking soft_stick 1
7 black little_curl_up little_heavily little_blur little_sinking soft_stick 1
10 dark_green stiff clear distinct even soft_stick 0
14 light_white little_curl_up heavily little_blur sinking hard_smooth 0
15 black little_curl_up little_heavily distinct
little_sinking soft_stick 0
16 light_white curl_up little_heavily blur even hard_smooth 0
17 dark_green curl_up heavily little_blur little_sinking hard_smooth 0
4 dark_green curl_up heavily distinct sinking hard_smooth 1
5 light_white curl_up little_heavily distinct sinking hard_smooth 1
8 black little_curl_up little_heavily distinct little_sinking hard_smooth 1
9 black little_curl_up heavily little_blur little_sinking hard_smooth 0
11 light_white stiff clear blur even hard_smooth 0
12 light_white curl_up little_heavily blur even soft_stick 0
13 dark_green little_curl_up little_heavily little_blur sinking hard_smooth 0
其中,前11個數據用作訓練集(1,2,3,6,7,10,14,15,16,17,4)後6個數據用作測試集(5,8,9,11,12,13) 注:書上原來是將前10個數據作為訓練,後7個數據測試。但是這樣得到的決策樹與書上例子不同。而如上調整後得到的結果相同,因此應該是書上的資料集劃分與其實際使用的劃分不同。為了起到對照作用,本文采取書上圖例中決策樹對應的劃分 未剪枝決策樹: 程式碼與ID3決策樹相似,主要將資訊增益換成了基尼指數的計算。
# -*- coding: utf-8 -*-


from numpy import *
import numpy as np
import pandas as pd
from math import log
import operator

#計算資料集的基尼指數
def calcGini(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    #給所有可能分類建立字典
    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    Gini=1.0
    #以2為底數計算夏農熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        Gini-=prob*prob
    return Gini


#對離散變數劃分資料集,取出該特徵取值為value的所有樣本
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


#對連續變數劃分資料集,direction規定劃分的方向,
#決定是劃分出小於value的資料樣本還是大於value的資料樣本集
def splitContinuousDataSet(dataSet,axis,value,direction):
    retDataSet=[]
    for featVec in dataSet:
        if direction==0:
            if featVec[axis]>value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        else:
            if featVec[axis]<=value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
    return retDataSet


#選擇最好的資料集劃分方式
def chooseBestFeatureToSplit(dataSet,labels):
    numFeatures=len(dataSet[0])-1
    bestGiniIndex=100000.0
    bestFeature=-1
    bestSplitDict={}
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        #對連續型特徵進行處理
        if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
            #產生n-1個候選劃分點
            sortfeatList=sorted(featList)
            splitList=[]
            for j in range(len(sortfeatList)-1):
                splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)
            
            bestSplitGini=10000
            slen=len(splitList)
            #求用第j個候選劃分點劃分時,得到的資訊熵,並記錄最佳劃分點
            for j in range(slen):
                value=splitList[j]
                newGiniIndex=0.0
                subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
                subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
                prob0=len(subDataSet0)/float(len(dataSet))
                newGiniIndex+=prob0*calcGini(subDataSet0)
                prob1=len(subDataSet1)/float(len(dataSet))
                newGiniIndex+=prob1*calcGini(subDataSet1)
                if newGiniIndex<bestSplitGini:
                    bestSplitGini=newGiniIndex
                    bestSplit=j
            #用字典記錄當前特徵的最佳劃分點
            bestSplitDict[labels[i]]=splitList[bestSplit]
            
            GiniIndex=bestSplitGini
        #對離散型特徵進行處理
        else:
            uniqueVals=set(featList)
            newGiniIndex=0.0
            #計算該特徵下每種劃分的資訊熵
            for value in uniqueVals:
                subDataSet=splitDataSet(dataSet,i,value)
                prob=len(subDataSet)/float(len(dataSet))
                newGiniIndex+=prob*calcGini(subDataSet)
            GiniIndex=newGiniIndex
        if GiniIndex<bestGiniIndex:
            bestGiniIndex=GiniIndex
            bestFeature=i
    #若當前節點的最佳劃分特徵為連續特徵,則將其以之前記錄的劃分點為界進行二值化處理
    #即是否小於等於bestSplitValue
    if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':      
        bestSplitValue=bestSplitDict[labels[bestFeature]]        
        labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
        for i in range(shape(dataSet)[0]):
            if dataSet[i][bestFeature]<=bestSplitValue:
                dataSet[i][bestFeature]=1
            else:
                dataSet[i][bestFeature]=0
    return bestFeature


#特徵若已經劃分完,節點下的樣本還沒有統一取值,則需要進行投票
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    return max(classCount)


#主程式,遞迴產生決策樹
def createTree(dataSet,labels,data_full,labels_full):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet,labels)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    if type(dataSet[0][bestFeat]).__name__=='str':
        currentlabel=labels_full.index(labels[bestFeat])
        featValuesFull=[example[currentlabel] for example in data_full]
        uniqueValsFull=set(featValuesFull)
    del(labels[bestFeat])
    #針對bestFeat的每個取值,劃分出一個子樹。
    for value in uniqueVals:
        subLabels=labels[:]
        if type(dataSet[0][bestFeat]).__name__=='str':
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value]=createTree(splitDataSet\
         (dataSet,bestFeat,value),subLabels,data_full,labels_full)
    if type(dataSet[0][bestFeat]).__name__=='str':
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value]=majorityCnt(classList)
    return myTree


df=pd.read_csv('watermelon_4_2.csv')
data=df.values[:11,1:].tolist()
data_full=data[:]
labels=df.columns.values[1:-1].tolist()
labels_full=labels[:]
myTree=createTree(data,labels,data_full,labels_full)


import plotTree
plotTree.createPlot(myTree)

plotTree的程式見我之前的ID3決策樹博文: 得到的決策樹結果為: 與書上的圖4.5一致 接下來進行剪枝操作 預剪枝決策樹: 預剪枝是在決策樹生成過程中,在劃分節點時,若該節點的劃分沒有提高其在訓練集上的準確率,則不進行劃分。 程式碼如下:
# -*- coding: utf-8 -*-


from numpy import *
import numpy as np
import pandas as pd
from math import log
import operator
import copy
import re


#計算資料集的基尼指數
def calcGini(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    #給所有可能分類建立字典
    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    Gini=1.0
    #以2為底數計算夏農熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        Gini-=prob*prob
    return Gini


#對離散變數劃分資料集,取出該特徵取值為value的所有樣本
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


#對連續變數劃分資料集,direction規定劃分的方向,
#決定是劃分出小於value的資料樣本還是大於value的資料樣本集
def splitContinuousDataSet(dataSet,axis,value,direction):
    retDataSet=[]
    for featVec in dataSet:
        if direction==0:
            if featVec[axis]>value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        else:
            if featVec[axis]<=value:
                reducedFeatVec=featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
    return retDataSet


#選擇最好的資料集劃分方式
def chooseBestFeatureToSplit(dataSet,labels):
    numFeatures=len(dataSet[0])-1
    bestGiniIndex=100000.0
    bestFeature=-1
    bestSplitDict={}
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        #對連續型特徵進行處理
        if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
            #產生n-1個候選劃分點
            sortfeatList=sorted(featList)
            splitList=[]
            for j in range(len(sortfeatList)-1):
                splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)
            
            bestSplitGini=10000
            slen=len(splitList)
            #求用第j個候選劃分點劃分時,得到的資訊熵,並記錄最佳劃分點
            for j in range(slen):
                value=splitList[j]
                newGiniIndex=0.0
                subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
                subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
                prob0=len(subDataSet0)/float(len(dataSet))
                newGiniIndex+=prob0*calcGini(subDataSet0)
                prob1=len(subDataSet1)/float(len(dataSet))
                newGiniIndex+=prob1*calcGini(subDataSet1)
                if newGiniIndex<bestSplitGini:
                    bestSplitGini=newGiniIndex
                    bestSplit=j
            #用字典記錄當前特徵的最佳劃分點
            bestSplitDict[labels[i]]=splitList[bestSplit]
            
            GiniIndex=bestSplitGini
        #對離散型特徵進行處理
        else:
            uniqueVals=set(featList)
            newGiniIndex=0.0
            #計算該特徵下每種劃分的資訊熵
            for value in uniqueVals:
                subDataSet=splitDataSet(dataSet,i,value)
                prob=len(subDataSet)/float(len(dataSet))
                newGiniIndex+=prob*calcGini(subDataSet)
            GiniIndex=newGiniIndex
        if GiniIndex<bestGiniIndex:
            bestGiniIndex=GiniIndex
            bestFeature=i
    #若當前節點的最佳劃分特徵為連續特徵,則將其以之前記錄的劃分點為界進行二值化處理
    #即是否小於等於bestSplitValue
    #並將特徵名改為 name<=value的格式
    if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':      
        bestSplitValue=bestSplitDict[labels[bestFeature]]        
        labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
        for i in range(shape(dataSet)[0]):
            if dataSet[i][bestFeature]<=bestSplitValue:
                dataSet[i][bestFeature]=1
            else:
                dataSet[i][bestFeature]=0
    return bestFeature


#特徵若已經劃分完,節點下的樣本還沒有統一取值,則需要進行投票
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    return max(classCount)


#由於在Tree中,連續值特徵的名稱以及改為了  feature<=value的形式
#因此對於這類特徵,需要利用正則表示式進行分割,獲得特徵名以及分割閾值
def classify(inputTree,featLabels,testVec):
    firstStr=inputTree.keys()[0]
    if '<=' in firstStr:
        featvalue=float(re.compile("(<=.+)").search(firstStr).group()[2:])
        featkey=re.compile("(.+<=)").search(firstStr).group()[:-2]
        secondDict=inputTree[firstStr]
        featIndex=featLabels.index(featkey)
        if testVec[featIndex]<=featvalue:
            judge=1
        else:
            judge=0
        for key in secondDict.keys():
            if judge==int(key):
                if type(secondDict[key]).__name__=='dict':
                    classLabel=classify(secondDict[key],featLabels,testVec)
                else:
                    classLabel=secondDict[key]
    else:
        secondDict=inputTree[firstStr]
        featIndex=featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex]==key:
                if type(secondDict[key]).__name__=='dict':
                    classLabel=classify(secondDict[key],featLabels,testVec)
                else:
                    classLabel=secondDict[key]
    return classLabel


def testing(myTree,data_test,labels):
    error=0.0
    for i in range(len(data_test)):
        if classify(myTree,labels,data_test[i])!=data_test[i][-1]:
            error+=1
    print 'myTree %d' %error
    return float(error)
    
def testingMajor(major,data_test):
    error=0.0
    for i in range(len(data_test)):
        if major!=data_test[i][-1]:
            error+=1
    print 'major %d' %error
    return float(error)


#主程式,遞迴產生決策樹
def createTree(dataSet,labels,data_full,labels_full,data_test):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    temp_labels=copy.deepcopy(labels)
    bestFeat=chooseBestFeatureToSplit(dataSet,labels)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    if type(dataSet[0][bestFeat]).__name__=='str':
        currentlabel=labels_full.index(labels[bestFeat])
        featValuesFull=[example[currentlabel] for example in data_full]
        uniqueValsFull=set(featValuesFull)
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    del(labels[bestFeat])
    #針對bestFeat的每個取值,劃分出一個子樹。
    for value in uniqueVals:
        subLabels=labels[:]
        if type(dataSet[0][bestFeat]).__name__=='str':
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value]=createTree(splitDataSet\
         (dataSet,bestFeat,value),subLabels,data_full,labels_full,\
         splitDataSet(data_test,bestFeat,value))
    if type(dataSet[0][bestFeat]).__name__=='str':
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value]=majorityCnt(classList)
    #進行測試,若劃分沒有提高準確率,則不進行劃分,返回該節點的投票值作為節點類別
    
    if testing(myTree,data_test,temp_labels)<testingMajor(majorityCnt(classList),data_test):
        return myTree
    return majorityCnt(classList)


df=pd.read_csv('watermelon_4_2.csv')
data=df.values[:11,1:].tolist()
data_full=data[:]
data_test=df.values[11:,1:].tolist()
labels=df.columns.values[1:-1].tolist()
labels_full=labels[:]
myTree=createTree(data,labels,data_full,labels_full,data_test)


import plotTree
plotTree.createPlot(myTree)

得到的結果如圖所示: 與書上的圖4.6一致 後剪枝決策樹: 後剪枝決策樹先生成一棵完整的決策樹,再從底往頂進行剪枝處理。在以下程式碼中,使用的是深度優先搜尋。 決策樹生成部分與未剪枝CART決策樹相同,在其程式碼最後附上以下後剪枝程式碼即可:
#由於在Tree中,連續值特徵的名稱以及改為了  feature<=value的形式
#因此對於這類特徵,需要利用正則表示式進行分割,獲得特徵名以及分割閾值
def classify(inputTree,featLabels,testVec):
    firstStr=inputTree.keys()[0]
    if '<=' in firstStr:
        featvalue=float(re.compile("(<=.+)").search(firstStr).group()[2:])
        featkey=re.compile("(.+<=)").search(firstStr).group()[:-2]
        secondDict=inputTree[firstStr]
        featIndex=featLabels.index(featkey)
        if testVec[featIndex]<=featvalue:
            judge=1
        else:
            judge=0
        for key in secondDict.keys():
            if judge==int(key):
                if type(secondDict[key]).__name__=='dict':
                    classLabel=classify(secondDict[key],featLabels,testVec)
                else:
                    classLabel=secondDict[key]
    else:
        secondDict=inputTree[firstStr]
        featIndex=featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex]==key:
                if type(secondDict[key]).__name__=='dict':
                    classLabel=classify(secondDict[key],featLabels,testVec)
                else:
                    classLabel=secondDict[key]
    return classLabel
#測試決策樹正確率
def testing(myTree,data_test,labels):
    error=0.0
    for i in range(len(data_test)):
        if classify(myTree,labels,data_test[i])!=data_test[i][-1]:
            error+=1
    #print 'myTree %d' %error
    return float(error)
#測試投票節點正確率
def testingMajor(major,data_test):
    error=0.0
    for i in range(len(data_test)):
        if major!=data_test[i][-1]:
            error+=1
    #print 'major %d' %error
    return float(error)
#後剪枝
def postPruningTree(inputTree,dataSet,data_test,labels):
    firstStr=inputTree.keys()[0]
    secondDict=inputTree[firstStr]
    classList=[example[-1] for example in dataSet]
    featkey=copy.deepcopy(firstStr)
    if '<=' in firstStr:
        featkey=re.compile("(.+<=)").search(firstStr).group()[:-2]
        featvalue=float(re.compile("(<=.+)").search(firstStr).group()[2:])
    labelIndex=labels.index(featkey)
    temp_labels=copy.deepcopy(labels)
    del(labels[labelIndex])
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            if type(dataSet[0][labelIndex]).__name__=='str':
                inputTree[firstStr][key]=postPruningTree(secondDict[key],\
                 splitDataSet(dataSet,labelIndex,key),splitDataSet(data_test,labelIndex,key),copy.deepcopy(labels))
            else:
                inputTree[firstStr][key]=postPruningTree(secondDict[key],\
                splitContinuousDataSet(dataSet,labelIndex,featvalue,key),\
                splitContinuousDataSet(data_test,labelIndex,featvalue,key),\
                copy.deepcopy(labels))
    if testing(inputTree,data_test,temp_labels)<=testingMajor(majorityCnt(classList),data_test):
        return inputTree
    return majorityCnt(classList)


data=df.values[:11,1:].tolist()
data_test=df.values[11:,1:].tolist()
labels=df.columns.values[1:-1].tolist()
myTree=postPruningTree(myTree,data,data_test,labels)


import plotTree
plotTree.createPlot(myTree)
得到的決策樹如下圖: 與教材P83的圖4.7一致。 總結: 後剪枝決策樹比起預剪枝決策樹保留了更多的分支。在一般情形下,後剪枝決策樹的欠擬合風險很小,泛化效能往往優於預剪枝決策樹。但同時其訓練時間花銷也比較大。