1. 程式人生 > >機器學習演算法(2) 決策樹

機器學習演算法(2) 決策樹

基於決策樹的基本思想(ID3演算法),學習資訊增益的計算,決策樹的構建、使用、儲存。

例子來自《Machine Learning in Action》 Peter Harrington

熵值計算

計算資料集合中分類的數量與概率,根據公式求得熵。

from math import log

"""計算熵值"""
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}                            # 用於儲存分類標籤的種類和個數
    for featVec in
dataSet: currentLabel = featVec[-1] # 當前資料點的分類標籤 if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key])/numEntries shannonEnt -= prob * log(prob,2
) #以2為底求對數 return shannonEnt

測試資料

提供一個如下的資料集合,用過兩個特徵對生物是否屬於魚類進行確認。

不浮出水面是否可以生存 是否有腳蹼 屬於魚類
1
2
3
4
5
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0
, 1, 'no'], [0, 1, 'no'], labels = ['no surfacing','flippers'] return dataSet, labels

測試

def testShannonEnt():
    myDat,labels = createDataSet()
    print (calcShannonEnt(myDat))

結果

0.9709505944546686

當在資料集中再加入一種分類

"""建立測試資料集合"""
dataSet = [[1, 1, 'yes'],
           [1, 1, 'yes'],
           [1, 0, 'no'],
           [0, 1, 'no'],
           [0, 1, 'no'],
           [1, 1, 'maybe'],]

重新計算熵,可得結果:

1.4591479170272448

熵增大,即混亂度(不確定性)增大。值的變化符合熵的定義。

劃分資料集

以下程式碼包含三個輸入變數,具體含義見註釋。其中dataSet中所包含的資料點,每一個數據點都有多個特徵。axis表示接下來按照第幾個特徵進行劃分資料,value表示返回的資料集第axis特徵的特徵值等於多少。

"""劃分資料集"""
'''
dataSet:帶劃分資料集
axis:劃分資料集的特徵(第axis個,從零開始計數)
value:需要返回的特徵的值

'''
def splitDataSet(dataSet, axis, value):
    retDataSet = []             
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]    
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)          # 當前資料點 除去當前特徵後儲存
    return retDataSet

測試資料劃分

"""測試劃分資料集"""
def testSplitData():
    myDat,labels = trees.createDataSet()
    print(trees.splitDataSet(myDat,0,1))
    print(trees.splitDataSet(myDat,0,0))

結果

[[1, 'yes'], [1, 'yes'], [0, 'no']]
[[1, 'no'], [1, 'no']]

第一行的分類結果表示,按照第1個特徵對資料集進行劃分,返回的結果是第一個特徵值為1的資料點。第二行返回的是第一個特徵值等於0的資料點。劃分結果和預想的一致。

尋找最好的劃分方式

依照演算法,尋找資訊增益最大的分類方式作為最好的分類方式

"""尋找最好的劃分方式"""

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1         # 獲得特徵個數    
    baseEntropy = calcShannonEnt(dataSet)     # 原始的熵值
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):             # 對於每一個特徵都進行迭代
        featList = [example[i] for example in dataSet]   # 提取當前特徵在每個資料點中的值
        uniqueVals = set(featList)           #轉換為一個set集合(沒有重複元素)
        newEntropy = 0.0
        for value in uniqueVals:
            # 針對資料集合,對第i個特徵進行分類,返回值是特徵值為value的
            subDataSet = splitDataSet(dataSet, i, value)  
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)     
        infoGain = baseEntropy - newEntropy     
        if (infoGain > bestInfoGain):       # 比較每次分類資訊增益
            bestInfoGain = infoGain         # 如果大,就替換當前的值
            bestFeature = i
    return bestFeature  

測試

"""測試最好的劃分方式"""
def testChooseBestFeatureToSplit():
    myDat,labels = trees.createDataSet()
    print(trees.chooseBestFeatureToSplit(myDat))

結果

當前資料利用第0個特徵分類資訊增益最大。

0

遞迴構建決策樹

由遞迴構成樹停止的條件有兩個:
1. 所有的標籤的類都相同
2. 所有的特徵都用完了

具體實現見程式碼

"""建立樹"""
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]    # 分類標籤的值
    if classList.count(classList[0]) == len(classList):
        return classList[0]      # 所有的標籤的類都相同 返回這個類標籤
    if len(dataSet[0]) == 1:     # 如果所有的特徵都用完了,則停止
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 獲得資訊增益最大的分類特徵
    bestFeatLabel = labels[bestFeat]              # 獲得當前特徵的具體含義
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])   # 刪除已分類的特徵
    featValues = [example[bestFeat] for example in dataSet]   # 當前分類特徵下的資料點特徵值
    uniqueVals = set(featValues)     # 轉換為list型別
    for value in uniqueVals:
        subLabels = labels[:]       # 拷貝標籤
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree

"""返回出現次數最多的分類名稱"""
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

測試

"""測試建立樹"""
def testCreateTree():
    myDat,labels = trees.createDataSet()
    myTree = trees.createTree(myDat,labels)
    print(myTree)

結果

以字典的形式返回決策樹

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

利用決策樹判斷新資料點

依照資料點的每一個特徵根據決策樹進行劃分,知道得出型別。

"""利用決策樹判斷新資料點"""
def classify(inputTree,featLabels,testVec):
    firstSides = list(inputTree.keys())   # 第一個分類特徵
    firstStr = firstSides[0]        #找到輸入的第一個元素
    secondDict = inputTree[firstStr]      # 二級字典
    featIndex = featLabels.index(firstStr)   # 當前特徵值在資料集的位置,返回時索引
    key = testVec[featIndex]             # 拿到新資料點的當前特徵的特徵值
    valueOfFeat = secondDict[key]        # 根據特徵值 劃分資料點
    if isinstance(valueOfFeat, dict):  # 如果不是葉節點,迭代;
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat     # 如果是葉節點,返回標籤類
    return classLabel

測試

"""測試決策樹判斷新資料點"""
def testClassify():
    myDat,labels = trees.createDataSet()
    myTree = trees.createTree(myDat,labels)
    myDat,labels = trees.createDataSet()
    print(trees.classify(myTree,labels,[1,0]))
    print(trees.classify(myTree,labels,[1,1]))

結果

返回判斷結果

no
yes

儲存決策樹

決策樹的建立比較耗時,為了方便一次建立多次使用。可以把建立的決策樹序列化,儲存到磁碟上,需要的時候再讀取使用。

"""序列化並寫入磁碟"""
def storeTree(inputTree,filename):
    fw = open(filename,'wb+')   # 要以二進位制格式開啟檔案
    pickle.dump(inputTree,fw)
    fw.close()

"""讀取磁碟並反序列化"""   
def grabTree(filename):
    fr = open(filename,'rb')    # 要以二進位制格式開啟檔案
    return pickle.load(fr)

測試

"""測試決策樹儲存"""   
def testStoreAndGrabTree():
    myDat,labels = trees.createDataSet()
    myTree = trees.createTree(myDat,labels)
    trees.storeTree(myTree,'trees.txt')
    reloadMyTree = trees.grabTree('trees.txt')
    print(reloadMyTree)  

結果

可以從磁碟得到之前的決策樹

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}