程式人生 » python3實現決策樹(機器學習實戰)

python3實現決策樹(機器學習實戰)


import operator
from math import log
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of a data set.

    Args:
        dataSet: list of feature vectors; the last element of each
            vector is the class label.

    Returns:
        Entropy (in bits) of the class-label distribution; 0.0 for an
        empty data set (the loop never runs, so no division occurs).
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        # Tally each class label (dict.get replaces the manual key check).
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        # NOTE: the original printed prob here — debug residue, removed.
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


mydata = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(calcShannonEnt(mydata))


def splitDataSet(dataSet, axis, value):
    """Extract the rows of dataSet whose feature `axis` equals `value`.

    Args:
        dataSet: list of feature vectors.
        axis: index of the feature to filter on.
        value: value that feature must equal.

    Returns:
        A new list of vectors with the matched feature column removed;
        input rows are not mutated.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature with the highest information gain (ID3 criterion).

    Args:
        dataSet: list of feature vectors; last element is the class label.

    Returns:
        Index of the best feature, or -1 if no split improves entropy.
    """
    numFeatures = len(dataSet[0]) - 1  # last column is the label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


mydata = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(chooseBestFeatureToSplit(mydata))


def majorityCnt(classList):
    """Return the most common class label (majority vote) in classList.

    Used as the leaf value when no features remain to split on.
    """
    classCount = {}
    for vote in classList:
        # BUGFIX: original did `classCount += 1` (TypeError on a dict);
        # count the vote itself.
        classCount[vote] = classCount.get(vote, 0) + 1
    # BUGFIX: `.iteritems()` is Python 2 only, and `itemgette` was a typo.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    Args:
        dataSet: list of feature vectors; last element is the class label.
        labels: feature names, parallel to the feature columns. NOTE: this
            list is mutated (`del labels[bestFeat]`), as in the book's code.

    Returns:
        A class label (leaf) or {featureName: {value: subtree, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # all examples share one class: pure leaf
    # BUGFIX: stop when no FEATURES remain (row length 1 = label only);
    # the original tested `len(dataSet) == 1`, which is the row count.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    for value in set(featValues):
        subLabels = labels[:]  # copy so sibling branches see the same labels
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


labels = ['no surface', 'flippers']
print(createTree(mydata, labels))