決策樹——機器學習實戰完整版(python 3)
阿新 • • 發佈:2019-01-07
import matplotlib.pyplot as plt # boxstyle是文字框型別 fc是邊框粗細 sawtooth是鋸齒形 '''xy是終點座標 xytext是起點座標 可能疑問:為什麼說是終點,但是卻是箭頭從這出發的? 解答:arrowstyle="<-" 看到沒有,這是個反向的箭頭''' decisionNode=dict(boxstyle="sawtooth",fc="0.8") leafNode=dict(boxstyle="round4",fc="0.8") arrow_args=dict(arrowstyle="<-") #createPlot 主函式,呼叫即可畫出決策樹,其中呼叫登了剩下的所有的函式,inTree的形式必須為巢狀的決策樹 def createPlot(): fig=plt.figure(1,facecolor='white') # 新建一個畫布,背景設定為白色的 fig.clf()# 將畫圖清空 createPlot.ax1=plt.subplot(111,frameon=False)# 設定一個多圖展示,但是設定多圖只有一個, # 但是設定引數是111,構建了一個1*1的模組,並操作物件指向第一個圖。 plotNode('decision', (0.5,0.1),(0.1,0.5),decisionNode) plotNode('leaf', (0.8,0.5),(0.3,0.7),leafNode) plt.show() def plotNode(nodeTxt,centerPt,parentPt,nodeType):#plotNode函式有nodeTxt,centerPt, parentPt, nodeType這四個引數。 # nodeTxt是註釋的文字資訊。centerPt表示那個節點框的位置。 # parentPt表示那個箭頭的起始位置(終點座標)。nodeType表示的是節點的型別, # 也就會用我們之前定義的全域性變數。#xytext是起點座標 #va="center",ha="center"是座標的水平中心和垂直中心 createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\ xytext=centerPt,textcoords='axes fraction',\ va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)#annotate是註釋的意思, # 也就是作為原來那個框的註釋,也是新增一些新的東西 #arrowprops=arrow_args是結點的顏色 def getNumLeafs(myTree): numLeafs=0 firstSides = list(myTree.keys()) firstStr = firstSides[0] #firstStr=myTree.keys()[0]# 找到輸入的第一個元素,第一個關鍵詞為劃分資料集類別的標籤 secondDict=myTree[firstStr]# mytree經過第一個特徵值分類後的字典 for key in secondDict.keys():#測試資料是否為字典形式 if type(secondDict[key]).__name__=='dict':# type(secondDict[key]).__name__輸出的是括號裡面的變數的型別,即判斷secondDict[key]對應的內容是否為字典型別 numLeafs+=getNumLeafs(secondDict[key]) else: numLeafs+=1 return numLeafs def getTreeDepth(myTree): maxDepth=0 firstSides = list(myTree.keys()) firstStr = firstSides[0] #firstStr=myTree.keys()[0]# 找到輸入的第一個元素,第一個關鍵詞為劃分資料集類別的標籤 secondDict=myTree[firstStr]# mytree經過第一個特徵值分類後的字典 for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict': thisDepth=1+getTreeDepth(secondDict[key]) else: thisDepth=1 if thisDepth>maxDepth: maxDepth=thisDepth return maxDepth def retrieveTree(i): listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},\ {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}] return listOfTrees[i] def plotMidText(cntrPt,parentPt,txtString):#在座標點cntrPt和parentPt連線線上的中點,顯示文字txtString #parentPt表示那個箭頭的起始位置(終點座標),cntrPt葉節點的位置,箭頭的終點 xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]#x軸座標 yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]#y軸座標 createPlot.ax1.text(xMid,yMid,txtString)#在(xMid, yMid)處顯示txtString def plotTree(myTree,parentPt,nodeTxt): # nodeTxt是註釋的文字資訊 numLeafs=getNumLeafs(myTree) depth=getTreeDepth(myTree) firstSides = list(myTree.keys()) firstStr = firstSides[0] cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)#cntrPt用來記錄當前要畫的樹的樹根的結點位置 # plotTree.xOff和plotTree.yOff是用來追蹤已經繪製的節點位置,plotTree.totalW為這個數的寬度,葉節點數 #cntrPt用來記錄當前要畫的樹的樹根的結點位置在plotTree函式中,它是這樣計算的 # cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff) # numLeafs記錄當前的樹中葉子結點個數。我們希望樹根在這些所有葉子節點的中間 # plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW 這裡的 # 1.0 + numLeafs 需要拆開來理解,也就是 # plotTree.xOff + float(numLeafs) / 2.0 / plotTree.totalW + 1.0 / 2.0 / plotTree.totalW # plotTree.xOff + 1 / 2 * float(numLeafs) / plotTree.totalW + 0.5 / plotTree.totalW # 因為xOff的初始值是 - 0.5 / plotTree.totalW ,是往左偏了0.5 / plotTree.tatalW的, # 這裡正好加回去。這樣cntrPt記錄的x座標正好是所有葉子結點的中心點''' plotMidText(cntrPt,parentPt,nodeTxt)#顯示節點 plotNode(firstStr,cntrPt,parentPt,decisionNode)#firstStr為需要顯示的文字,cntrPt為文字的中心點, # parentPt為箭頭指向文字的起始點,decisionNode為文字屬性 secondDict=myTree[firstStr]#子樹 plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD#totalD是這個數的深度,深度移下一層,初始值為1 for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict': plotTree(secondDict[key],cntrPt,str(key)) else: plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW#x座標平移一個單位 plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)#畫葉節點 plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))#顯示箭頭文字 plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD#下移一層深度 def createPlot(inTree):#是主函式,呼叫了plotTree,plotTree又呼叫了其他的函式 fig=plt.figure(1,facecolor='white')#建立一個畫布,背景為白色 fig.clf()#畫布清空 axprops=dict(xticks=[],ytichs=[])#定義橫縱座標軸,無內容 createPlot.ax1 = plt.subplot(111, frameon=False) #createPlot.ax1=plt.subplot(111,frameon=False,**axprops)#去掉x、y軸,xticks=[],ytichs=[]無內容, #**表示此引數是字典引數 # ax1是函式createPlot的一個屬性,這個可以在函式裡面定義也可以在函式定義後加入也可以 # createPlot.ax1 = plt.subplot(111, frameon = False, **axprops) #frameon表示是否繪製座標軸矩形,無座標軸,111代表1X1個圖,第一個 plotTree.totalW=float(getNumLeafs(inTree)) plotTree.totalD=float(getTreeDepth(inTree)) plotTree.xOff=-0.5/plotTree.totalW#如果葉子結點的座標是 1/totalW , 2/totalW, 3/totalW, …, 1 的話, # 就正好在寬度的最右邊,為了讓座標在寬度的中間,需要減去0.5 / totalW 。 # 所以createPlot函式中,初始化 plotTree.xOff 的值為-0.5/plotTree.totalW。 # 這樣每次 xOff + 1/totalW ,正好是下1個結點的準確位置 plotTree.yOff=1.0 #yOff的初始值為1,每向下遞迴一次,這個值減去 1 / totalD plotTree(inTree,(0.5,1.0),'') plt.show()