Handling Continuous Values in ID3 Decision Trees + Plotting Figures 4.8 and 4.10 from Zhou Zhihua's "Machine Learning"

Reposted from
https://blog.csdn.net/Leafage_M/article/details/80137305

To sum this post up in one sentence: for the current n samples, sort the continuous attribute and average each pair of adjacent values, giving n-1 candidate split points. The key points (a concrete sketch of ① follows below):
① To evaluate a continuous attribute, try each of the n-1 candidate split points in turn, compute the conditional entropy of the two subsets on either side of the split, and keep the split point that yields the maximum information gain.
② If both discrete and continuous attributes remain, the continuous attribute is represented by its maximum information gain, which is then compared against the discrete attributes' gains to choose the best splitting attribute.
③ If only continuous attributes remain, each split simply uses the split point with the largest information gain.

This also matches the remark in Zhou Zhihua's "Machine Learning":
in a decision tree,
a discrete attribute can be used at most once along any path from the root,
whereas a continuous attribute may be split on multiple times.
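
To make point ① concrete, here is a minimal, self-contained sketch (my addition, not from the original post; the names entropy, density, y are purely illustrative). It finds the best split for the 密度 (density) column of watermelon dataset 3.0; the values and the expected answer (split point about 0.381, gain about 0.262) come from the book's worked example:

# -*- coding: utf-8 -*-
from math import log

# The 17 density readings from watermelon dataset 3.0; 1 = good melon, 0 = bad
density = [0.697, 0.774, 0.634, 0.608, 0.556, 0.403, 0.481, 0.437,
           0.666, 0.243, 0.245, 0.343, 0.639, 0.657, 0.360, 0.593, 0.719]
y = [1] * 8 + [0] * 9

def entropy(ys):
    n = float(len(ys))
    ent = 0.0
    for c in set(ys):
        p = ys.count(c) / n
        ent -= p * log(p, 2)
    return ent

base = entropy(y)
pairs = sorted(zip(density, y))

# The n-1 candidate split points: midpoints of adjacent sorted values
mids = [(pairs[k][0] + pairs[k + 1][0]) / 2.0 for k in range(len(pairs) - 1)]

best_gain, best_mid = 0.0, None
for mid in mids:
    left = [label for v, label in pairs if v <= mid]
    right = [label for v, label in pairs if v > mid]
    w = len(left) / float(len(pairs))
    gain = base - (w * entropy(left) + (1 - w) * entropy(right))
    if gain > best_gain:
        best_gain, best_mid = gain, mid

print("best split %.4f, gain %.3f" % (best_mid, best_gain))

Running this should print a split near 0.38 (the book rounds the midpoint to 0.381) with a gain near 0.26, matching the full implementation below.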

The code at the link above targets Python 3.x; adapted to Python 2.7 it reads as follows:
top.py

#-*- coding:utf-8 -*-
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import collections
from math import log
import operator
import treePlotter

def createDataSet():
    """
    Watermelon dataset 3.0
    :return: dataSet, labels, labels_full
    """
    dataSet = [
        # 1
        ['青綠', '蜷縮', '濁響', '清晰', '凹陷', '硬滑', 0.697, 0.460, '好瓜'],
        # 2
        ['烏黑', '蜷縮', '沉悶', '清晰', '凹陷', '硬滑', 0.774, 0.376, '好瓜'],
        # 3
        ['烏黑', '蜷縮', '濁響', '清晰', '凹陷', '硬滑', 0.634, 0.264, '好瓜'],
        # 4
        ['青綠', '蜷縮', '沉悶', '清晰', '凹陷', '硬滑', 0.608, 0.318, '好瓜'],
        # 5
        ['淺白', '蜷縮', '濁響', '清晰', '凹陷', '硬滑', 0.556, 0.215, '好瓜'],
        # 6
        ['青綠', '稍蜷', '濁響', '清晰', '稍凹', '軟粘', 0.403, 0.237, '好瓜'],
        # 7
        ['烏黑', '稍蜷', '濁響', '稍糊', '稍凹', '軟粘', 0.481, 0.149, '好瓜'],
        # 8
        ['烏黑', '稍蜷', '濁響', '清晰', '稍凹', '硬滑', 0.437, 0.211, '好瓜'],

        # ----------------------------------------------------
        # 9
        ['烏黑', '稍蜷', '沉悶', '稍糊', '稍凹', '硬滑', 0.666, 0.091, '壞瓜'],
        # 10
        ['青綠', '硬挺', '清脆', '清晰', '平坦', '軟粘', 0.243, 0.267, '壞瓜'],
        # 11
        ['淺白', '硬挺', '清脆', '模糊', '平坦', '硬滑', 0.245, 0.057, '壞瓜'],
        # 12
        ['淺白', '蜷縮', '濁響', '模糊', '平坦', '軟粘', 0.343, 0.099, '壞瓜'],
        # 13
        ['青綠', '稍蜷', '濁響', '稍糊', '凹陷', '硬滑', 0.639, 0.161, '壞瓜'],
        # 14
        ['淺白', '稍蜷', '沉悶', '稍糊', '凹陷', '硬滑', 0.657, 0.198, '壞瓜'],
        # 15
        ['烏黑', '稍蜷', '濁響', '清晰', '稍凹', '軟粘', 0.360, 0.370, '壞瓜'],
        # 16
        ['淺白', '蜷縮', '濁響', '模糊', '平坦', '硬滑', 0.593, 0.042, '壞瓜'],
        # 17
        ['青綠', '蜷縮', '沉悶', '稍糊', '稍凹', '硬滑', 0.719, 0.103, '壞瓜']
    ]



# Watermelon dataset 3.0a (density and sugar content only); uncomment below to reproduce Figure 4.10

    # dataSet = [
    #     # 1
    #     [0.697, 0.460, '好瓜'],
    #     # 2
    #     [0.774, 0.376, '好瓜'],
    #     # 3
    #     [0.634, 0.264, '好瓜'],
    #     # 4
    #     [0.608, 0.318, '好瓜'],
    #     # 5
    #     [0.556, 0.215, '好瓜'],
    #     # 6
    #     [0.403, 0.237, '好瓜'],
    #     # 7
    #     [0.481, 0.149, '好瓜'],
    #     # 8
    #     [0.437, 0.211, '好瓜'],
    
    #     # ----------------------------------------------------
    #     # 9
    #     [0.666, 0.091, '壞瓜'],
    #     # 10
    #     [0.243, 0.267, '壞瓜'],
    #     # 11
    #     [0.245, 0.057, '壞瓜'],
    #     # 12
    #     [ 0.343, 0.099, '壞瓜'],
    #     # 13
    #     [ 0.639, 0.161, '壞瓜'],
    #     # 14
    #     [0.657, 0.198, '壞瓜'],
    #     # 15
    #     [0.360, 0.370, '壞瓜'],
    #     # 16
    #     [0.593, 0.042, '壞瓜'],
    #     # 17
    #     [ 0.719, 0.103, '壞瓜']
    # ]

    # Feature list for watermelon dataset 3.0
    labels = ['色澤', '根蒂', '敲擊', '紋理', '臍部', '觸感', '密度', '含糖率']
    # Feature list for watermelon dataset 3.0a
    # labels = ['密度', '含糖率']


    # All values observed for each feature
    labels_full = {}

    for i in range(len(labels)):
        labelList = [example[i] for example in dataSet]
        uniqueLabel = set(labelList)
        labels_full[labels[i]] = uniqueLabel
    print("--------------------------------------")
    for item in labels_full:
        print("item=",unicode(item))
    print("--------------------------------------")
    print("len(labels_full)=",len(labels_full))
    print("len(labels)=",len(labels))

    return dataSet, labels, labels_full


def calcShannonEnt(dataSet):
    """
    Compute the information entropy (Shannon entropy) of a data set
    :param dataSet:
    :return: the entropy of the class labels in dataSet
    """
    # Total number of samples
    numEntries = len(dataSet)

    # Tally of class labels
    labelCounts = collections.defaultdict(int)

    # Walk the data set and collect each sample's class label
    for featVec in dataSet:
        # The class label is the last column
        currentLabel = featVec[-1]

        # Count it
        labelCounts[currentLabel] += 1

    # Entropy accumulator
    shannonEnt = 0.0

    for key in labelCounts:
        # Proportion of this class among all samples
        prob = float(labelCounts[key]) / numEntries

        # Base-2 logarithm
        shannonEnt -= prob * log(prob, 2)

    return shannonEnt
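
# Sanity check (comment only): watermelon dataset 3.0 has 8 good and 9 bad
# melons, so the entropy at the root is
#   Ent(D) = -(8/17)*log2(8/17) - (9/17)*log2(9/17) ≈ 0.998,
# matching the worked example in the book.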


def splitDataSetForSeries(dataSet, axis, value):
    """
    Split the data set on a given value into a "not greater than" part and a
    "greater than" part
    :param dataSet: data set to split
    :param axis: column index of the continuous feature
    :param value: split value
    :return:
    """
    # Samples whose feature value is not greater than the split value
    eltDataSet = []
    # Samples whose feature value is greater than the split value
    gtDataSet = []
    # Split; the feature column itself is kept
    for feat in dataSet:
        if feat[axis] <= value:
            eltDataSet.append(feat)
        else:
            gtDataSet.append(feat)

    return eltDataSet, gtDataSet
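
# Note that, unlike splitDataSet below, the feature column is kept here: a
# continuous feature remains available in both subsets and may be split on
# again deeper in the tree, with a different threshold.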


def splitDataSet(dataSet, axis, value):
    """
    Split the data set on a given discrete feature value
    :param dataSet: data set
    :param axis: column index of the feature
    :param value: only samples whose feature equals this value are returned
    :return:
    """
    # Build a new list so the original data set is not modified
    retDataSet = []

    # Walk the whole data set
    for featVec in dataSet:
        # If this sample has the requested feature value
        if featVec[axis] == value:
            # Keep the columns before the feature...
            reducedFeatVec = featVec[:axis]
            # ...and the columns after it, i.e. drop the feature column itself
            reducedFeatVec.extend(featVec[axis + 1:])
            # Collect the reduced sample
            retDataSet.append(reducedFeatVec)

    return retDataSet
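
# Because splitDataSet removes the matched column, a discrete feature can be
# used at most once along any root-to-leaf path, which is the asymmetry noted
# at the top of this post.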




# This function searches for the split point that maximizes information gain.
def calcInfoGainForSeries(dataSet, i, baseEntropy):
    """
    Compute the information gain of a continuous feature
    :param dataSet: the whole data set
    :param i: column index of the feature
    :param baseEntropy: entropy of the data set before splitting
    :return: the maximum information gain and the corresponding split point
    """

    # Best information gain found so far
    maxInfoGain = 0.0

    # Best split point found so far
    bestMid = -1

    # All values of this feature, sorted ascending. (The original code built
    # dict(zip(featList, classList)) and sorted its items, but a dict silently
    # merges samples with equal feature values, which also skews the weights
    # below; sorting the raw value list avoids both problems.)
    featList = [example[i] for example in dataSet]
    sortedFeatList = sorted(featList)

    # Total number of samples
    numEntries = len(dataSet)

    # The n-1 candidate split points: midpoints of adjacent sorted values.
    # NOTE: the loop variable must not be named i. In Python 2 a list
    # comprehension's loop variable leaks into the enclosing scope, so reusing
    # i would clobber this function's parameter i; hence k.
    midFeatList = [round((sortedFeatList[k] + sortedFeatList[k + 1]) / 2.0, 3)
                   for k in range(numEntries - 1)]

    # Evaluate the information gain of every candidate split point
    for mid in midFeatList:
        # Bi-partition: values not greater than mid vs. values greater than mid
        eltDataSet, gtDataSet = splitDataSetForSeries(dataSet, i, mid)

        # Conditional entropy: weighted sum of the two subsets' entropies
        newEntropy = (len(eltDataSet) / float(numEntries)) * calcShannonEnt(eltDataSet) \
                   + (len(gtDataSet) / float(numEntries)) * calcShannonEnt(gtDataSet)

        # Information gain for this split point
        infoGain = baseEntropy - newEntropy
        # print('split value: ' + str(mid) + ', information gain: ' + str(infoGain))
        if infoGain > maxInfoGain:
            bestMid = mid
            maxInfoGain = infoGain

    return maxInfoGain, bestMid
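
# For reference, on watermelon dataset 3.0 the book's worked example finds the
# best split for 密度 at about 0.381 with gain about 0.262, and for 含糖率 at
# about 0.126 with gain about 0.349; this function should reproduce those values.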


def calcInfoGain(dataSet, featList, i, baseEntropy):
    """
    Compute the information gain of a discrete feature
    :param dataSet: data set
    :param featList: values of the current feature
    :param i: column index of the feature
    :param baseEntropy: entropy of the data set before splitting
    :return:
    """
    # Distinct values of the current feature
    uniqueVals = set(featList)

    # Conditional entropy of the data set given this feature
    newEntropy = 0.0

    # Iterate over every possible value of the feature
    for value in uniqueVals:
        # Subset of samples whose feature equals this value
        subDataSet = splitDataSet(dataSet=dataSet, axis=i, value=value)
        # Weight of the subset
        prob = float(len(subDataSet)) / float(len(dataSet))
        # Accumulate the weighted entropy
        newEntropy += prob * calcShannonEnt(subDataSet)

    # Information gain
    infoGain = baseEntropy - newEntropy

    return infoGain
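
# For the discrete features of dataset 3.0 the book reports 紋理 as the feature
# with the largest information gain (about 0.381), which is why 紋理 ends up at
# the root of Figure 4.8.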


def chooseBestFeatureToSplit(dataSet, labels):
    """
    Choose the best feature to split on, by information gain; handles both
    discrete and continuous features
    :param dataSet:
    :return:
    """
    # Number of features (the last column is the class label)
    numFeatures = len(dataSet[0]) - 1

    # Entropy before splitting
    baseEntropy = calcShannonEnt(dataSet)

    # Best information gain so far
    bestInfoGain = 0.0

    # Index of the best feature
    bestFeature = -1

    # Whether the current best feature is continuous
    flagSeries = 0

    # If it is continuous, the split point that goes with it
    bestSeriesMid = 0.0

    # Compute the information gain of every feature
    for i in range(numFeatures):

        # All values of the current feature
        featList = [example[i] for example in dataSet]

        if isinstance(featList[0], str):
            infoGain = calcInfoGain(dataSet, featList, i, baseEntropy)
        else:
            # print('current feature: ' + str(labels[i]))
            infoGain, bestMid = calcInfoGainForSeries(dataSet, i, baseEntropy)

        # print('feature: ' + labels[i] + ', information gain: ' + str(infoGain))

        # Keep this feature if its gain beats the best so far
        if infoGain > bestInfoGain:
            # New best gain
            bestInfoGain = infoGain
            # New best feature index
            bestFeature = i

            flagSeries = 0
            if not isinstance(dataSet[0][bestFeature], str):
                flagSeries = 1
                bestSeriesMid = bestMid

    # print('feature with the largest information gain: ' + labels[bestFeature])
    if flagSeries:
        return bestFeature, bestSeriesMid
    else:
        return bestFeature
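
# Return convention: a bare column index when the best feature is discrete, or
# a (column index, split value) tuple when it is continuous; createTree below
# tells the two cases apart with isinstance(bestFeat, tuple).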


def majorityCnt(classList):
    """
    Return the most frequent class label
    :param classList:
    :return:
    """
    # Tally of votes per label
    classCount = collections.defaultdict(int)

    # Count every label
    for vote in classList:
        classCount[vote] += 1

    # Sort by count, descending
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

    # The label with the most votes
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """
    Build the decision tree
    :param dataSet: data set
    :param labels: feature names
    :return:
    """
    # Class labels of all samples
    classList = [example[-1] for example in dataSet]

    # If every sample has the same label, stop splitting and return that label
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # If only one column is left, all features have been used up; what remains
    # is the class label column
    if len(dataSet[0]) == 1:
        # Return the majority label
        return majorityCnt(classList)

    # Choose the best feature to split on
    bestFeat = chooseBestFeatureToSplit(dataSet=dataSet, labels=labels)

    # Name of the best feature
    bestFeatLabel = ''

    # Whether this node is continuous (1) or discrete (0)
    flagSeries = 0

    # If continuous, the split point
    midSeries = 0.0

    # A tuple means the best feature is continuous
    if isinstance(bestFeat, tuple):
        # Build the node label from the feature name and the split point
        bestFeatLabel = str(labels[bestFeat[0]]) + '小於' + str(bestFeat[1]) + '?'
        # The split point
        midSeries = bestFeat[1]
        # The column index
        bestFeat = bestFeat[0]
        # Mark as continuous
        flagSeries = 1
    else:
        # Node label is just the feature name
        bestFeatLabel = labels[bestFeat]
        # Mark as discrete
        flagSeries = 0

    # The tree is a nested dict; each internal node is keyed by its label
    myTree = {bestFeatLabel: {}}

    # All values of the chosen feature
    featValues = [example[bestFeat] for example in dataSet]

    # Continuous case
    if flagSeries:
        # Bi-partition on the split point
        eltDataSet, gtDataSet = splitDataSetForSeries(dataSet, bestFeat, midSeries)
        # Copy the feature names; the continuous feature is NOT removed
        subLabels = labels[:]
        # Recurse on the "not greater than" subset
        subTree = createTree(eltDataSet, subLabels)
        myTree[bestFeatLabel]['小於'] = subTree

        # Recurse on the "greater than" subset
        subTree = createTree(gtDataSet, subLabels)
        myTree[bestFeatLabel]['大於'] = subTree

        return myTree

    # Discrete case
    else:
        # Remove the chosen feature from the name list
        del (labels[bestFeat])
        # Distinct values of the feature
        uniqueVals = set(featValues)
        # One branch per value
        for value in uniqueVals:
            # Remaining feature names
            subLabels = labels[:]
            # Recurse on the subset with this value; splitDataSet has already
            # removed the feature column from those rows
            subTree = createTree(splitDataSet(dataSet=dataSet, axis=bestFeat, value=value), subLabels)
            # Hang the subtree under this branch
            myTree[bestFeatLabel][value] = subTree
        return myTree


if __name__ == '__main__':
    dataSet, labels, labels_full = createDataSet()
    myTree = createTree(dataSet, labels)
    print(myTree)
    treePlotter.createPlot(myTree)
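
Once createTree returns, myTree is just a nested dict, e.g. {'紋理': {'清晰': {...}, '稍糊': {...}, '模糊': '壞瓜'}}. A hypothetical helper (my addition, not part of the original post) that walks such a tree to classify one sample could look like the sketch below. Two caveats: it only handles discrete test nodes (a continuous node's key such as '密度小於0.381?' would need the threshold parsed back out), and createTree deletes entries from labels as it runs, so pass it a fresh copy of the feature list:

def classify(tree, featLabels, sample):
    # A leaf stores the class label itself
    if not isinstance(tree, dict):
        return tree
    # The single key at this level names the feature being tested
    node = list(tree.keys())[0]
    idx = featLabels.index(node)    # column of that feature in the raw sample
    return classify(tree[node][sample[idx]], featLabels, sample)

# Example (hypothetical): for a tree built on the six discrete features only,
# classify(myTree, ['色澤', '根蒂', '敲擊', '紋理', '臍部', '觸感'], dataSet[0])
# should return '好瓜'.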

treePlotter.py

#-*- coding:utf-8 -*-
import sys
reload(sys)
sys.setdefaultencoding('utf-8')

import matplotlib
import matplotlib.pyplot as plt

# Make sure the fonts below are installed on your system, otherwise the Chinese
# node labels will not render. Chinese text uses HYQuanTangShiF (falling back to
# SimHei); everything else uses Times New Roman.
matplotlib.rcParams['font.sans-serif'] = ['HYQuanTangShiF', 'SimHei', 'Times New Roman']
plt.rcParams['axes.unicode_minus'] = False


decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


# Return the number of leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':    # a dict is an internal node, anything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs
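
# NOTE: indexing dict.keys() as in myTree.keys()[0] above works in Python 2
# only; under Python 3 it would need list(myTree.keys())[0] (the same idiom
# appears in getTreeDepth and plotTree below).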

# Return the depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':    # a dict is an internal node, anything else is a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    nodeTxt=unicode(nodeTxt)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
    
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, unicode(txtString), va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):    # the first key tells you what feature was split on
    nodeTxt=unicode(nodeTxt)
    numLeafs = getNumLeafs(myTree)    # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]       # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':    # a dict is an internal node, anything else is a leaf
            plotTree(secondDict[key],cntrPt,str(key))    # recurse
        else:    # it's a leaf node: plot it
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
# if you do get a dictionary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) # ticks on, for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

#def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False) # ticks on, for demo purposes
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()

def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

#createPlot(thisTree)

Watermelon dataset 3.0 (built into the code above) reproduces Figure 4.8,
found on page 85 of Zhou Zhihua's "Machine Learning".
[Figure: decision tree generated from watermelon dataset 3.0, i.e. Figure 4.8]

Watermelon dataset 3.0a (also in the code; uncomment it inside createDataSet)
reproduces Figure 4.10, found on page 90.
[Figure: decision tree generated from watermelon dataset 3.0a, i.e. Figure 4.10]