
Building, Visualizing, and Making Decisions with Decision Trees

1. Overview

In the previous post, we introduced two decision tree construction algorithms, ID3 and C4.5:
Decision Tree Construction Algorithms – ID3 and C4.5
In this post, we look at how to use these two algorithms, together with some other tools, to build and display our decision tree.

2. Building a Decision Tree with C4.5

With the ID3 and C4.5 algorithms introduced in the previous post, we can recursively select, at each level, the current best feature together with that feature's best split value, and thereby build the complete decision tree:

[Figure: flowchart of the recursive decision tree construction process]

The flowchart is quite clear. The basic idea is: for numeric features, we split into only two branches (left and right), both to keep the number of subtrees down and to avoid the complexity that many-way splits would introduce; for string-valued (categorical) features, we create one subtree per distinct value. This works because numeric features usually take many distinct values, while categorical features do not.
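Before walking through the full implementation, it may help to see the gain ratio arithmetic in isolation. The following standalone sketch is our own illustration, not part of the original post's code: the entropy helper is ours, and the hard-coded class counts come from the data set used below (9 of its 14 samples are 'yes'). It computes the information gain ratio of one candidate binary split with the textbook C4.5 definitions; the split used here mirrors the 日期 > 728 partition that appears at the root of the tree printed below.

# -*- coding: UTF-8 -*-
from math import log2


def entropy(counts):
    """Shannon entropy of a class-count distribution."""
    total = sum(counts)
    return -sum(c / total * log2(c / total) for c in counts if c)


# 14 samples, 9 "yes" and 5 "no" -> base entropy of the system
base = entropy([9, 5])                               # ~= 0.940

# a candidate binary split that puts 13 samples (8 yes / 5 no) on one
# side and a single "yes" sample on the other
weighted = 13 / 14 * entropy([8, 5]) + 1 / 14 * entropy([1])
splitInfo = entropy([13, 1])                         # the C4.5 penalty term
gainRatio = (base - weighted) / splitInfo            # ~= 0.129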

2.1. Python Implementation

# -*- coding: UTF-8 -*-
# {{{
import operator
from math import log


def createDataSet():
    """
    Build the sample data set

    :return: data set and feature label list
    """
    dataSet = [[706, 'hot', 'sunny', 'high', 'false', 'no'],
               [706, 'hot', 'sunny', 'high', 'true', 'no'],
               [706, 'hot', 'overcast', 'high', 'false', 'yes'],
               [709, 'cool', 'rain', 'normal', 'false', 'yes'],
               [710, 'cool', 'overcast', 'normal', 'true', 'yes'],
               [712, 'mild', 'sunny', 'high', 'false', 'no'],
               [714, 'cool', 'sunny', 'normal', 'false', 'yes'],
               [715, 'mild', 'rain', 'normal', 'false', 'yes'],
               [720, 'mild', 'sunny', 'normal', 'true', 'yes'],
               [721, 'mild', 'overcast', 'high', 'true', 'yes'],
               [722, 'hot', 'overcast', 'normal', 'false', 'yes'],
               [723, 'mild', 'sunny', 'high', 'true', 'no'],
               [726, 'cool', 'sunny', 'normal', 'true', 'no'],
               [730, 'mild', 'sunny', 'high', 'false', 'yes']]
    labels = ['日期', '氣候', '天氣', '氣溫', '寒冷']
    return dataSet, labels


def classCount(dataSet):
    """
    Count how many samples fall into each class label

    :param dataSet: data set
    :return: dict mapping class label to its count
    """
    labelCount = {}
    for one in dataSet:
        if one[-1] not in labelCount.keys():
            labelCount[one[-1]] = 0
        labelCount[one[-1]] += 1
    return labelCount


def calcShannonEntropy(dataSet):
    """
    Compute the Shannon entropy of the data set

    :param dataSet: data set
    :return: entropy
    """
    labelCount = classCount(dataSet)
    numEntries = len(dataSet)
    Entropy = 0.0
    for i in labelCount:
        prob = float(labelCount[i]) / numEntries
        Entropy -= prob * log(prob, 2)
    return Entropy


def majorityClass(dataSet):
    """
    Find the most frequent class label

    :param dataSet: data set
    :return: majority class label
    """
    labelCount = classCount(dataSet)
    sortedLabelCount = sorted(labelCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedLabelCount[0][0]


def splitDataSet(dataSet, i, value):
    """
    Split on a non-numeric (categorical) feature

    Keep the samples whose i-th feature equals value, with that feature column removed

    :param dataSet: data set
    :param i: feature index
    :param value: feature value to match
    :return: sub data set
    """
    subDataSet = []
    for one in dataSet:
        if one[i] == value:
            reduceData = one[:i]
            reduceData.extend(one[i + 1:])
            subDataSet.append(reduceData)
    return subDataSet


def splitContinuousDataSet(dataSet, i, value, direction):
    """
    Split on a numeric feature

    Partition dataSet around the i-th feature, using value as the threshold

    :param dataSet: data set
    :param i: feature index
    :param value: split threshold
    :param direction: 0 for the "> value" branch, 1 for the "<= value" branch
    :return: sub data set
    """
    subDataSet = []
    for one in dataSet:
        if direction == 0:
            if one[i] > value:
                reduceData = one[:i]
                reduceData.extend(one[i + 1:])
                subDataSet.append(reduceData)
        if direction == 1:
            if one[i] <= value:
                reduceData = one[:i]
                reduceData.extend(one[i + 1:])
                subDataSet.append(reduceData)
    return subDataSet


def chooseBestFeat(dataSet, labels):
    """
    Find the best feature and, for numeric features, its best split threshold

    :param dataSet: data set
    :param labels: feature label list
    :return: index of the best feature, and its best split value
    """
    global bestSplit
    # entropy of the system before splitting
    baseEntropy = calcShannonEntropy(dataSet)
    bestFeat = 0
    baseGainRatio = -1
    numFeats = len(dataSet[0]) - 1
    bestSplitDic = {}
    # iterate over every feature
    for i in range(numFeats):
        # collect all values of this feature
        featVals = [example[i] for example in dataSet]
        uniVals = sorted(set(featVals))
        if type(featVals[0]).__name__ == 'float' or type(featVals[0]).__name__ == 'int':
            # candidate split thresholds: midpoints between adjacent values
            splitList = []
            for j in range(len(uniVals) - 1):
                splitList.append((uniVals[j] + uniVals[j + 1]) / 2.0)
            # compute the gain ratio of each candidate threshold, keep the best
            for j in range(len(splitList)):
                # entropy under this split
                newEntropy = 0.0
                splitInfo = 0.0
                value = splitList[j]
                # partition into the two sub data sets
                subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
                subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
                # entropy of the system after the split
                prob0 = float(len(subDataSet0)) / len(dataSet)
                newEntropy -= prob0 * calcShannonEntropy(subDataSet0)
                prob1 = float(len(subDataSet1)) / len(dataSet)
                newEntropy -= prob1 * calcShannonEntropy(subDataSet1)
                # penalty term (split information)
                splitInfo -= prob0 * log(prob0, 2)
                splitInfo -= prob1 * log(prob1, 2)
                # information gain ratio
                gainRatio = float(baseEntropy - newEntropy) / splitInfo
                if gainRatio > baseGainRatio:
                    baseGainRatio = gainRatio
                    bestSplit = j
                    bestFeat = i
            bestSplitDic[labels[i]] = splitList[bestSplit]
        else:
            splitInfo = 0.0
            newEntropy = 0.0
            for value in uniVals:
                # split the data set on this feature value
                subDataSet = splitDataSet(dataSet, i, value)
                # entropy of the system after the split
                prob = float(len(subDataSet)) / len(dataSet)
                newEntropy -= prob * calcShannonEntropy(subDataSet)
                # penalty term (split information)
                splitInfo -= prob * log(prob, 2)
            # information gain ratio
            gainRatio = float(baseEntropy - newEntropy) / splitInfo
            if gainRatio > baseGainRatio:
                bestFeat = i
                baseGainRatio = gainRatio
    bestFeatValue = None
    if type(dataSet[0][bestFeat]).__name__ == 'float' or type(dataSet[0][bestFeat]).__name__ == 'int':
        bestFeatValue = bestSplitDic[labels[bestFeat]]
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        bestFeatValue = labels[bestFeat]
    return bestFeat, bestFeatValue


def createTree(dataSet, labels):
    """
    Build the decision tree recursively

    :param dataSet: data set
    :param labels: feature label list
    :return: decision tree as a nested dict
    """
    classList = [example[-1] for example in dataSet]
    if len(set(classList)) == 1:
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityClass(dataSet)
    # find the best feature and split value for the current node
    bestFeat, bestFeatLabel = chooseBestFeat(dataSet, labels)
    myTree = {labels[bestFeat]: {}}
    subLabels = labels[:bestFeat]
    subLabels.extend(labels[bestFeat + 1:])
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        featVals = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featVals)
        # build one subtree per distinct feature value
        for value in uniqueVals:
            # data set with this feature column removed
            reduceDataSet = splitDataSet(dataSet, bestFeat, value)
            myTree[labels[bestFeat]][value] = createTree(reduceDataSet, subLabels)
    if type(dataSet[0][bestFeat]).__name__ == 'int' or type(dataSet[0][bestFeat]).__name__ == 'float':
        value = bestFeatLabel
        # partition the data set around the threshold
        greaterDataSet = splitContinuousDataSet(dataSet, bestFeat, value, 0)
        smallerDataSet = splitContinuousDataSet(dataSet, bestFeat, value, 1)
        # build the left and right subtrees recursively
        myTree[labels[bestFeat]]['>' + str(value)] = createTree(greaterDataSet, subLabels)
        myTree[labels[bestFeat]]['<=' + str(value)] = createTree(smallerDataSet, subLabels)
    return myTree


if __name__ == '__main__':
    dataSet, labels = createDataSet()
    print(createTree(dataSet, labels))
#}}}

Running this returns:

{
  '日期': {
    '>728.0': 'yes',
    '<=728.0': {
      '寒冷': {
        'false': {
          '氣溫': {
            'high': {
              '氣候': {
                'hot': {
                  '天氣': {
                    'sunny': 'no',
                    'overcast': 'yes'
                  }
                },
                'mild': 'no'
              }
            },
            'normal': 'yes'
          }
        },
        'true': {
          '氣溫': {
            'high': {
              '氣候': {
                'hot': 'no',
                'mild': {
                  '天氣': {
                    'sunny': 'no',
                    'overcast': 'yes'
                  }
                }
              }
            },
            'normal': {
              '氣候': {
                'mild': 'yes',
                'cool': {
                  '天氣': {
                    'sunny': 'no',
                    'overcast': 'yes'
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
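Since the title also promises making decisions with the tree, here is a minimal sketch of how one might walk this nested dict to classify a new sample. The classify helper below is our own illustration rather than part of the original code; it assumes the '>value' / '<=value' key convention that createTree uses for numeric features.

def classify(tree, labels, sample):
    """Classify one sample by walking the nested dict built by createTree."""
    if not isinstance(tree, dict):       # reached a leaf: it is the class label
        return tree
    feat = next(iter(tree))              # feature tested at this node
    branches = tree[feat]
    value = sample[labels.index(feat)]
    if isinstance(value, (int, float)):  # numeric: keys look like '>728.0' / '<=728.0'
        threshold = float(next(iter(branches)).lstrip('><='))
        key = ('>' if value > threshold else '<=') + str(threshold)
    else:                                # categorical: keys are the raw values
        key = value
    return classify(branches[key], labels, sample)


# e.g. classify(myTree, labels, [715, 'cool', 'rain', 'normal', 'false'])
# follows 日期 <= 728.0, 寒冷 = 'false', 氣溫 = 'normal' and returns 'yes'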

3. Visualizing the Decision Tree

The JSON-like output above is hard to read. Can we draw the decision tree as an actual tree structure instead?
We can use the matplotlib module to draw the tree:

# -*- coding: UTF-8 -*-
# {{{
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties


def getNumLeafs(myTree):
    """
    Get the number of leaf nodes in the decision tree

    :param myTree: decision tree
    :return: number of leaf nodes
    """
    numLeafs = 0  # initialize the leaf count
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]  # get the next level of the dict
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict is an internal node; anything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """
    Get the depth (number of levels) of the decision tree

    :param myTree: decision tree
    :return: depth of the tree
    """
    maxDepth = 0  # initialize the depth
    firstStr = next(iter(myTree))  # in Python 3, myTree.keys() returns a dict_keys view rather than a list,
    # so myTree.keys()[0] no longer works; use next(iter(myTree)) or list(myTree.keys())[0] instead
    secondDict = myTree[firstStr]  # get the next level of the dict
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict is an internal node; anything else is a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth  # update the maximum depth
    return maxDepth
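
# A quick sanity check (our own toy example, not from the original post):
# a hand-written two-level tree reports 3 leaves and depth 2, e.g.
#     toy = {'氣溫': {'high': {'寒冷': {'true': 'no', 'false': 'yes'}}, 'normal': 'yes'}}
#     getNumLeafs(toy)   # -> 3
#     getTreeDepth(toy)  # -> 2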


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    Draw a node

    :param nodeTxt: node label text
    :param centerPt: text position
    :param parentPt: arrow start position (the parent node)
    :param nodeType: node box style
    :return:
    """
    arrow_args = dict(arrowstyle="<-")  # arrow style
    font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)  # a font that can render the Chinese labels
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',  # draw the node
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, fontproperties=font)


def plotMidText(cntrPt, parentPt, txtString):
    """
    Label the edge between a node and its parent

    :param cntrPt: current node position
    :param parentPt: parent node position
    :param txtString: label text
    :return:
    """
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]  # midpoint of the edge
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """
    Draw the decision tree

    :param myTree: decision tree dict
    :param parentPt: parent node position
    :param nodeTxt: node label text
    :return:
    """
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # style for internal (decision) nodes
    leafNode = dict(boxstyle="round4", fc="0.8")  # style for leaf nodes
    numLeafs = getNumLeafs(myTree)  # the number of leaves determines the width of the tree
    firstStr = next(iter(myTree))  # feature tested at this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)  # center position of this node
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the edge to the parent
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw the node
    secondDict = myTree[firstStr]  # descend one level to draw the child nodes
    plotTree.yOff = plotTree.yOff - 1.0