1. 程式人生 > >決策樹——機器學習實戰完整版(python 3)

決策樹——機器學習實戰完整版(python 3)

import matplotlib.pyplot as plt
# boxstyle是文字框型別 fc是邊框粗細 sawtooth是鋸齒形
'''xy是終點座標
xytext是起點座標
可能疑問:為什麼說是終點,但是卻是箭頭從這出發的?
解答:arrowstyle="<-" 看到沒有,這是個反向的箭頭'''
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")

#createPlot 主函式,呼叫即可畫出決策樹,其中呼叫登了剩下的所有的函式,inTree的形式必須為巢狀的決策樹

def createPlot():
    fig=plt.figure(1,facecolor='white') # 新建一個畫布,背景設定為白色的
    fig.clf()# 將畫圖清空
    createPlot.ax1=plt.subplot(111,frameon=False)# 設定一個多圖展示,但是設定多圖只有一個,
    # 但是設定引數是111,構建了一個1*1的模組,並操作物件指向第一個圖。
    plotNode('decision', (0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('leaf', (0.8,0.5),(0.3,0.7),leafNode)
    plt.show()

def plotNode(nodeTxt,centerPt,parentPt,nodeType):#plotNode函式有nodeTxt,centerPt, parentPt, nodeType這四個引數。
                                                 # nodeTxt是註釋的文字資訊。centerPt表示那個節點框的位置。
                                                 #  parentPt表示那個箭頭的起始位置(終點座標)。nodeType表示的是節點的型別,
                                                #  也就會用我們之前定義的全域性變數。#xytext是起點座標  #va="center",ha="center"是座標的水平中心和垂直中心
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
                            xytext=centerPt,textcoords='axes fraction',\
                            va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)#annotate是註釋的意思,
                                                                        # 也就是作為原來那個框的註釋,也是新增一些新的東西
                                                 #arrowprops=arrow_args是結點的顏色
def getNumLeafs(myTree):
    numLeafs=0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    #firstStr=myTree.keys()[0]# 找到輸入的第一個元素,第一個關鍵詞為劃分資料集類別的標籤
    secondDict=myTree[firstStr]# mytree經過第一個特徵值分類後的字典
    for key in secondDict.keys():#測試資料是否為字典形式
        if type(secondDict[key]).__name__=='dict':#  type(secondDict[key]).__name__輸出的是括號裡面的變數的型別,即判斷secondDict[key]對應的內容是否為字典型別
            numLeafs+=getNumLeafs(secondDict[key])
        else:  numLeafs+=1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth=0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    #firstStr=myTree.keys()[0]# 找到輸入的第一個元素,第一個關鍵詞為劃分資料集類別的標籤
    secondDict=myTree[firstStr]# mytree經過第一個特徵值分類後的字典
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else: thisDepth=1
        if thisDepth>maxDepth: maxDepth=thisDepth
    return maxDepth
def retrieveTree(i):
    listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},\
                 {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]
def plotMidText(cntrPt,parentPt,txtString):#在座標點cntrPt和parentPt連線線上的中點,顯示文字txtString   #parentPt表示那個箭頭的起始位置(終點座標),cntrPt葉節點的位置,箭頭的終點

    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]#x軸座標
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]#y軸座標
    createPlot.ax1.text(xMid,yMid,txtString)#在(xMid, yMid)處顯示txtString
def plotTree(myTree,parentPt,nodeTxt):  # nodeTxt是註釋的文字資訊
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)#cntrPt用來記錄當前要畫的樹的樹根的結點位置
    # plotTree.xOff和plotTree.yOff是用來追蹤已經繪製的節點位置,plotTree.totalW為這個數的寬度,葉節點數
    #cntrPt用來記錄當前要畫的樹的樹根的結點位置在plotTree函式中,它是這樣計算的
    # cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # numLeafs記錄當前的樹中葉子結點個數。我們希望樹根在這些所有葉子節點的中間
    # plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW 這裡的
    # 1.0 + numLeafs 需要拆開來理解,也就是
    # plotTree.xOff + float(numLeafs) / 2.0 / plotTree.totalW + 1.0 / 2.0 / plotTree.totalW
    # plotTree.xOff + 1 / 2 * float(numLeafs) / plotTree.totalW + 0.5 / plotTree.totalW
    # 因為xOff的初始值是 - 0.5 / plotTree.totalW ,是往左偏了0.5 / plotTree.tatalW的,
    # 這裡正好加回去。這樣cntrPt記錄的x座標正好是所有葉子結點的中心點'''
    plotMidText(cntrPt,parentPt,nodeTxt)#顯示節點
    plotNode(firstStr,cntrPt,parentPt,decisionNode)#firstStr為需要顯示的文字,cntrPt為文字的中心點,
                                                   # parentPt為箭頭指向文字的起始點,decisionNode為文字屬性
    secondDict=myTree[firstStr]#子樹
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD#totalD是這個數的深度,深度移下一層,初始值為1
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW#x座標平移一個單位
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)#畫葉節點
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))#顯示箭頭文字
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD#下移一層深度

def createPlot(inTree):#是主函式,呼叫了plotTree,plotTree又呼叫了其他的函式
    fig=plt.figure(1,facecolor='white')#建立一個畫布,背景為白色
    fig.clf()#畫布清空
    axprops=dict(xticks=[],ytichs=[])#定義橫縱座標軸,無內容
    createPlot.ax1 = plt.subplot(111, frameon=False)
    #createPlot.ax1=plt.subplot(111,frameon=False,**axprops)#去掉x、y軸,xticks=[],ytichs=[]無內容,          #**表示此引數是字典引數
    # ax1是函式createPlot的一個屬性,這個可以在函式裡面定義也可以在函式定義後加入也可以
    # createPlot.ax1 = plt.subplot(111, frameon = False, **axprops) #frameon表示是否繪製座標軸矩形,無座標軸,111代表1X1個圖,第一個
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW#如果葉子結點的座標是 1/totalW , 2/totalW, 3/totalW, …, 1 的話,
                                 # 就正好在寬度的最右邊,為了讓座標在寬度的中間,需要減去0.5 / totalW 。
                                 # 所以createPlot函式中,初始化 plotTree.xOff 的值為-0.5/plotTree.totalW。
                                # 這樣每次 xOff + 1/totalW ,正好是下1個結點的準確位置
    plotTree.yOff=1.0               #yOff的初始值為1,每向下遞迴一次,這個值減去 1 / totalD
    plotTree(inTree,(0.5,1.0),'')
    plt.show()