1. 程式人生 > >《機器學習實戰》學習(二)——決策樹(DT)

《機器學習實戰》學習(二)——決策樹(DT)

1、決策樹簡述

決策樹學習是一種逼近離散值目標函式的方法,在這種方法中學習到的函式被表示為一棵決策樹。在周志華老師的《機器學習》這本書中專門一章節對決策樹進行了講述。並對id3演算法後的改進演算法也做了相應的介紹。決策樹容易導致過擬合現象,介紹了預剪枝和後剪枝等相關的處理方法。決策樹依賴測試集,可以把測試集生成的樹結構序列化存到檔案中,下次使用可以很快進行載入。
一個牛人對決策樹的總結,我覺得很有道理,所以原文放在這裡,總結如下:
據我瞭解,決策樹是最簡單,也是曾經最常用的分類方法了。決策樹基於樹理論實現資料分類,個人感覺就是資料結構中的B+樹。決策樹是一個預測模型,他代表的是物件屬性與物件值之間的一種對映關係。決策樹計算複雜度不高、輸出結果易於理解、對中間值缺失不敏感、可以處理不相關特徵資料。其比KNN好的是可以瞭解資料的內在含義。但其缺點是容易產生過度匹配的問題,且構建很耗時。決策樹還有一個問題就是,如果不繪製樹結構,分類細節很難明白。所以,生成決策樹,然後再繪製決策樹,最後再分類,才能更好的瞭解資料的分類過程。
決策樹的核心樹的分裂。到底該選擇什麼來決定樹的分叉是決策樹構建的基礎。最好的方法是利用資訊熵實現。熵這個概念很頭疼,很容易讓人迷糊,簡單來說就是資訊的複雜程度。資訊越多,熵越高。所以決策樹的核心是通過計算資訊熵劃分資料集。
來源於:

http://www.cnblogs.com/zhizhan/p/4432943.html

2、ID3生成一個決策樹python程式碼註釋

# -*- coding: utf-8 -*-
"""
@brief 計算給定資料集的資訊熵
@param dataSet 資料集
@return 夏農熵
"""
import operator
from math import log
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)#求取資料集的行數
    labelCounts = {}
    for featVec in dataSet:#讀取資料集中的一行資料
currentLabel = featVec[-1] #取featVec中最後一列的值 #以一行資料中的最後一列值為鍵值進行統計 if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key])/numEntries#將每一類求取概率
shannonEnt -= prob * log(prob,2)#求取資料集的夏農熵 return shannonEnt """ @brief 建立臨時測試集 @param @return dataSet 返回一個測試資料集 @return labels 返回資料集的標籤 """ def createDataSet(): dataSet = [[1,1,'yes'], [1,1,'yes'], [1,0,'no'], [0,1,'no'], [0,1,'no']] labels = ['no surfacing','flippers'] return dataSet,labels """ @brief 劃分資料集 按照給定的特徵劃分資料集 @param[in] dataSet 待劃分的資料集 @param[in] axis 劃分資料集的特徵 @param[in] value 需要返回的特徵的值 @return retDataSet 返回劃分後的資料集 """ def splitDataSet(dataSet, axis, value): retDataSet = []#返回的劃分後的資料集 for featVec in dataSet: #抽取符合劃分特徵的值 if featVec[axis] == value: #如何符合此特徵值 則儲存,儲存劃分後的資料集時 不需要儲存選為劃分的特徵 reducedFeatVec = featVec[:axis] reducedFeatVec.extend(featVec[axis+1 :]) retDataSet.append(reducedFeatVec) return retDataSet """ @brief 遍歷整個資料集,迴圈計算夏農熵和選擇劃分函式,找到最好的劃分方式。 @param[in] dataSet 整個特徵集 待選擇的集 @return bestFeature 劃分資料集最好的劃分特徵列的索引值 """ def chooseBestFeatureToSplit(dataSet): numFeatures = len(dataSet[0])-1 #計算資料集中特徵數目 baseEntropy = calcShannonEnt(dataSet) #計算資料集的夏農熵 bestInfoGain = 0.0 #資訊增益初始值 bestFeature = -1#初始化最佳特徵值 for i in range(numFeatures): featList = [example[i] for example in dataSet] #提取資料集中特徵值 i表示列數 uniqueVals = set(featList) #獲取本列中的特徵值集合 去除列中重複元素 newEntropy = 0.0 #劃分資料後 資料集的夏農熵 #計算每一個特徵值進行劃分 產生的子集的資訊熵 然後將各個子集的資訊熵按照比值求和 for value in uniqueVals : subDataSet = splitDataSet(dataSet,i,value) prob = len(subDataSet)/float(len(dataSet)) #q求取子集的比值 newEntropy += prob * calcShannonEnt(subDataSet) #計算每個資訊增益 infoGain = baseEntropy - newEntropy if(infoGain > bestInfoGain):#獲得最大資訊增益值以及特徵值列的索引值 bestInfoGain = infoGain bestFeature = i return bestFeature """ @brief 計算一個特徵資料列表中 出現次數最多的特徵值以及次數 @param[in] 特徵值列表 @return 返回次數最多的特徵值 例如:[1,1,0,1,1]資料列表 返回 1 """ def majorityCnt(classList): classCount = {} #統計資料列表中每個特徵值出現的次數 for vote in classList: if vote not in classCount.keys(): classCount[vote] = 0 classCount[vote] += 1 #根據出現的次數進行排序 key=operator.itemgetter(1) 意思是按照次數進行排序 #classCount.items() 轉換為資料字典 進行排序 reverse = True 表示由大到小排序 sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse = True) #返回次數最多的一項的特徵值 return sortedClassCount[0][0] """ @brief 遞迴建立一顆樹 @param[in] dataSet 資料集 @param[in] labels 標籤資料 @return myTree 返回數結構 使用字典型別儲存樹結構 """ def createTree(dataSet,labels): classList = [example[-1] for example in dataSet]#獲取資料集中的最後一列 #如果類別完全相同,則停止劃分建立 if classList.count(classList[0]) == len(classList): return classList[0] #如果第一行資料長度為1 表示已經遍歷完所有特徵,則返回出現次數最多的特徵值 if len(dataSet[0]) == 1: return majorityCnt(classList) #獲得資料集的最好劃分特徵值列索引 bestFeat = chooseBestFeatureToSplit(dataSet) bestFeatLabel = labels[bestFeat] myTree ={bestFeatLabel:{}} del(labels[bestFeat]) featValues = [example[bestFeat] for example in dataSet] uniqueVals = set(featValues) for value in uniqueVals: subLabels = labels[:] myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels) return myTree """ @brief 對未知特徵在建立的決策樹上進行分類 @param[in] inputTree @param[in] featLabels @param[in] testVec @return classLabel 返回識別的結果 """ def classify(inputTree,featLabels,testVec): firstStr = list(inputTree.keys())[0] secondDict = inputTree[firstStr] featIndex = featLabels.index(firstStr) for key in secondDict.keys(): if testVec[featIndex] == key : if isinstance(secondDict[key],dict) == True: classLabel = classify(secondDict[key],featLabels,testVec) else: classLabel = secondDict[key] return classLabel """ @brief 儲存構建的決策樹 """ def storeTree(inputTree,filename): import pickle fw = open(filename,'wb') pickle.dump(inputTree,fw) fw.close() """ @brief 讀取文字儲存的決策樹 """ def grabTree(filename): import pickle fr = open(filename,'rb') return pickle.load(fr)

3、使用Matplotlib繪製決策樹python程式碼

import matplotlib.pyplot as plt
#定義決策節點和葉子節點的風格
decisionNode = dict(boxstyle = "sawtooth",fc="0.8")
#boxstyle = "swatooth"意思是註解框的邊緣是波浪線型的,fc控制的註解框內的顏色深度
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")#箭頭符號
"""
@brief 繪製節點
@param[in] nodeTxt 節點顯示文字
@param[in] centerPt 起點位置
@param[in] parentPt 終點位置
@param[in] nodeType 節點風格
"""
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
    xytext=centerPt,textcoords='axes fraction',\
    va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
"""
@brief 修改createPlot函式
"""

def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])

    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops) #繪製子圖
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

"""
@brief 獲取葉節點的數目 採用遞迴廣度遍歷演算法獲得樹的葉子節點數目
@param[in] myTree 輸入字典儲存的樹結構
@return numLeafs 返回葉子節點數目
"""
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        #if type(secondDict[key]).__name__=='dict':
        if isinstance(secondDict[key],dict) == True:
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs +=1
    return numLeafs

"""
@brief 獲得樹的層數 採用遞迴深度遍歷演算法獲得樹深度
@param[in] myTree 輸入字典儲存的樹結構
@return maxDepth 返回最大層深度
"""
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]#獲得資料字典中鍵值列表 並返回第一個值
    secondDict = myTree[firstStr]#獲取第一個鍵值的值
    for key in secondDict.keys():
        #if type(secondDict[key]).__name__ == 'dict':#判斷資料型別是否是字典型別
        if isinstance(secondDict[key],dict) == True:#判斷資料型別
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth :
            maxDepth = thisDepth
    return maxDepth

"""
@brief 在子父節點位置中間顯示一個文字資訊
@param[in] cntrPt 起點座標 子節點座標
@param[in] parentPt 結束座標 父節點座標
@param[in] 在中間位置顯示的字元
"""
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)

"""
@brief 繪製樹
@param[in] myTree
@param[in] parentPt
@param[in] nodeTxt
"""
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree)
    #depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key],dict) == True :
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff +1.0/plotTree.totalD

機器學習實戰書中例子,構建的決策樹繪製如下:

fr=open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age','prescript','astigmatic','tearRate']
lensesTree = trees.createTree(lenses,lensesLabels)
treePlotter.createPlot(lensesTree)

這裡寫圖片描述

4、決策樹總結

通過本章的實踐,程式碼大部分同書中程式碼一樣,實現id3演算法構建的決策樹。明顯可以看出,決策樹邏輯結構簡單,非常直觀的表達出預測種類。也對python語言熟悉了很多,特別是對字典資料型別的使用。確實解決了大部分程式設計問題,使得程式設計更為簡單。後續將對決策樹具體進行應用。應用例項主要是針對周老師《機器學習》書中決策樹章節的課後習題進行練習。更加深入的理解決策樹。