
Decision Trees -- From Principle to Implementation


1. Introduction

Decision trees show up in virtually every introductory machine-learning book. Their decision process mirrors the way we reason in everyday life, which makes them very easy to understand, and they carry a fair amount of information theory inside, so they also make a nice first application of it. Decision trees come in both regression and classification flavors, but here we focus on classification.

Personally, I think of a decision tree as extracting features from the data and ordering them from most to least discriminative. A familiar application: the guessing games that ask you a series of questions and then guess who or what you are thinking of.

Why is the first question always "male or female"? Why? You will know by the end of this post.

2. The Code

from math import log
import operator

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    # change to discrete values
    return dataSet, labels

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # count how often each class label occurs
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is used for the labels
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]  # all values of this feature
        uniqueVals = set(featList)  # get a set of unique values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # info gain, i.e. reduction in entropy
        if infoGain > bestInfoGain:  # compare this to the best gain so far
            bestInfoGain = infoGain  # if better than current best, set to best
            bestFeature = i
    return bestFeature  # returns an integer

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    # items() -- iteritems() only exists in Python 2
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def classify(inputTree, featLabels, testVec):
    firstStr = next(iter(inputTree))  # root feature name; keys()[0] only works in Python 2
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel

def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')  # pickle needs binary mode
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)
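A minimal sketch of how these pieces fit together, assuming the listing above is saved as trees.py (the file name dt.pkl is just an example):

import trees  # assumes the listing above is saved as trees.py

myDat, labels = trees.createDataSet()
myTree = trees.createTree(myDat, labels[:])    # pass a copy: createTree mutates its labels argument
print(myTree)  # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

print(trees.classify(myTree, labels, [1, 0]))  # 'no'

trees.storeTree(myTree, 'dt.pkl')              # persist the tree with pickle
print(trees.grabTree('dt.pkl'))                # load it back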

3. Algorithm Walkthrough

● Information Gain

Pass in a data set and get back its Shannon entropy, defined as H = -Σᵢ pᵢ·log₂ pᵢ, where pᵢ is the fraction of records belonging to class i.

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # count how often each class label occurs
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt

Once we can compute the entropy, we simply split the data set in whatever way yields the largest information gain: Gain(D, A) = H(D) - Σᵥ (|Dᵥ|/|D|)·H(Dᵥ), where Dᵥ is the subset of D taking value v on feature A.

e.g. running it on the data set below:

[[1, 1, 'yes'],
 [1, 1, 'yes'],
 [1, 0, 'no'],
 [0, 1, 'no'],
 [0, 1, 'no']]

labelCounts is a map (a dict):

currentLabel    labelCounts[currentLabel]    prob
yes             2                            0.4
no              3                            0.6

Information theory then gives H = -(0.4·log₂0.4 + 0.6·log₂0.6) ≈ 0.971.
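As a quick sanity check, this standalone snippet (assuming the calcShannonEnt defined above is in scope) reproduces that number:

from math import log

myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

h = -(0.4 * log(0.4, 2) + 0.6 * log(0.6, 2))  # direct computation
print(round(h, 3))                      # 0.971
print(round(calcShannonEnt(myDat), 3))  # 0.971 -- matches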

● Splitting the Data Set

※ Splitting on a given feature

Input: a data set, the index axis of a feature (counting from 0), and a value of that feature.

Output: the sub-data-set produced by the split, i.e. the records whose feature axis equals value, with that feature column removed.

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
e.g. with myDat =

[[1, 1, 'yes'],
 [1, 1, 'yes'],
 [1, 0, 'no'],
 [0, 1, 'no'],
 [0, 1, 'no']]

calling splitDataSet(myDat, 0, 1) returns

[[1, 'yes'], [1, 'yes'], [0, 'no']]
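The complementary split is also worth seeing; a tiny runnable check (again assuming the splitDataSet from above):

myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(splitDataSet(myDat, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(myDat, 0, 0))  # [[1, 'no'], [1, 'no']]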

※ Choosing the best split

Input: a data set.

Output: the feature whose split produces the largest drop in entropy, i.e. the largest information gain.

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is used for the labels
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]  # all values of this feature
        uniqueVals = set(featList)  # get a set of unique values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # info gain, i.e. reduction in entropy
        if infoGain > bestInfoGain:  # compare this to the best gain so far
            bestInfoGain = infoGain  # if better than current best, set to best
            bestFeature = i
    return bestFeature  # returns an integer
e.g. with myDat =

[[1, 1, 'yes'],
 [1, 1, 'yes'],
 [1, 0, 'no'],
 [0, 1, 'no'],
 [0, 1, 'no']]

calling chooseBestFeatureToSplit(myDat) proceeds as follows:

First pass:  split on feature 0 with value 1,
             then on feature 0 with value 0,
             and compute the weighted entropy of that split.
Second pass: split on feature 1 with value 1,
             then on feature 1 with value 0,
             and compute the weighted entropy of that split.
......
Pick the feature whose split yields the largest information gain (the biggest reduction in entropy).
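To make those numbers concrete, here is a small sketch that prints the gain for each feature (assuming the calcShannonEnt and splitDataSet from above):

myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
base = calcShannonEnt(myDat)            # 0.971
for i in range(len(myDat[0]) - 1):      # for each feature column
    newEnt = 0.0
    for v in set(ex[i] for ex in myDat):
        sub = splitDataSet(myDat, i, v)
        newEnt += len(sub) / float(len(myDat)) * calcShannonEnt(sub)
    print(i, round(base - newEnt, 3))
# prints: 0 0.42   <- feature 0 wins, so chooseBestFeatureToSplit returns 0
#         1 0.171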
  

● Recursively Building the Tree

When the features are used up but the remaining class labels still disagree, we fall back on a majority vote:

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    # items() -- iteritems() only exists in Python 2
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
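An equivalent, more idiomatic version using collections.Counter (just an alternative sketch, not the original code):

from collections import Counter

def majorityCnt2(classList):
    # most_common(1) returns [(label, count)] for the most frequent label
    return Counter(classList).most_common(1)[0][0]

print(majorityCnt2(['no', 'yes', 'no']))  # 'no'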

O(∩_∩)O~ Time to build the tree!

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
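Running it on the example data set (note that createTree mutates the labels list via del, so pass a copy if you still need the original):

myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
myTree = createTree(myDat, labels[:])  # labels[:] keeps the original intact
print(myTree)
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}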

O(∩_∩)O~~ And now use the tree to make decisions!

def classify(inputTree, featLabels, testVec):
    firstStr = next(iter(inputTree))  # root feature name; keys()[0] only works in Python 2
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel
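For example, with the myTree built above (featLabels must be the original, unmutated label list):

labels = ['no surfacing', 'flippers']
print(classify(myTree, labels, [1, 0]))  # 'no'  -- surfaces, but no flippers
print(classify(myTree, labels, [1, 1]))  # 'yes' -- has both features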
