1. 程式人生 > 西瓜書課後習題 4.3:基於資訊熵的決策樹(連續與離散屬性),並驗證模型

西瓜書課後習題 4.3:基於資訊熵的決策樹(支援連續與離散屬性),並驗證模型

import matplotlib.pyplot as plt
import numpy as np
from math import log
import operator
import csv


def readDataset(filename):
    '''
    Load the training table from a CSV file.

    The first row is the header.  Column 0 is skipped (presumably a sample
    id -- confirm against the data file); header columns 1-8 are returned as
    the feature names and columns 1-9 of every following row become one
    sample.  Values are kept as the strings produced by ``csv.reader``.
    Completely blank lines are ignored.

    :param filename: path of the CSV file
    :return: ``(dataset, labels)`` -- the list of sample rows and the list
             of feature names
    :raises StopIteration: if the file has no header row
    '''
    # newline='' lets the csv module do its own line handling, so newlines
    # embedded in quoted fields survive unchanged.
    with open(filename, newline='') as f:
        reader = csv.reader(f)
        header_row = next(reader)
        labels = header_row[1:9]
        # csv.reader yields [] for a blank line; such rows carry no sample.
        dataset = [line[1:10] for line in reader if line]
    return dataset, labels


def infoEnt(dataset):
    '''
    Compute the base-2 Shannon entropy of the class labels in a dataset.

    The class label of a sample is its last element.

    :param dataset: list of samples
    :return: the entropy; ``0`` when the dataset is empty
    '''
    total = len(dataset)
    tally = {}
    for sample in dataset:
        cls = sample[-1]
        tally[cls] = tally.get(cls, 0) + 1
    entropy = 0
    for hits in tally.values():
        ratio = float(hits) / total
        entropy -= ratio * log(ratio, 2)
    return entropy


def _isNumericText(text):
    # True when text consists only of digit/sign/point characters AND
    # actually parses as a float (rules out '', '-', '1.2.3', ...).
    if not text or not all(c in "0123456789.-" for c in text):
        return False
    try:
        float(text)
    except ValueError:
        return False
    return True


def bestFeatureSplit(dataset):
    '''
    Pick the feature (and, for a continuous feature, the threshold) with the
    highest information gain.

    The last element of every sample is the class label; all other elements
    are features.  A feature is treated as continuous when every one of its
    values is numeric text; its candidate thresholds are the midpoints
    between consecutive distinct values.  Any other feature is discrete and
    is split into one subset per distinct value.

    Only a strictly larger gain replaces the current best, so on ties the
    earliest candidate wins.

    :param dataset: non-empty list of samples
    :return: ``(index, splitPoint)``; ``splitPoint`` is ``None`` for a
             discrete feature.  ``(-1, None)`` when no split has a positive
             gain.
    '''
    numFeature = len(dataset[0]) - 1
    total = float(len(dataset))
    baseInfoEnt = infoEnt(dataset)
    bestInfoGain = 0
    bestFeature = -1
    bestSplitPoint = None

    def conditionalEnt(subsets):
        # Entropy after a split: subset entropies weighted by relative size.
        return sum(len(subset) / total * infoEnt(subset) for subset in subsets)

    for i in range(numFeature):
        featList = [example[i] for example in dataset]
        if all(_isNumericText(feat) for feat in featList):  # continuous feature
            # Convert before sorting so the order is numeric, not lexical,
            # and drop duplicates so every midpoint is a real boundary.
            values = sorted({float(feat) for feat in featList})
            for low, high in zip(values, values[1:]):
                point = (low + high) / 2
                # The entropy is computed from scratch for every threshold.
                newEnt = conditionalEnt(
                    splitDataset(dataset, i, point, True, part) for part in range(2))
                infoGain = baseInfoEnt - newEnt
                if infoGain > bestInfoGain:
                    bestInfoGain = infoGain
                    bestFeature = i
                    bestSplitPoint = point
        else:  # discrete feature
            # Sorted so the floating-point summation order is reproducible.
            newEnt = conditionalEnt(
                splitDataset(dataset, i, value, False) for value in sorted(set(featList)))
            infoGain = baseInfoEnt - newEnt
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
                bestSplitPoint = None
    return bestFeature, bestSplitPoint


def splitDataset(dataset, axis, value, continuous, part=0):
    '''
    Select the samples matching a split and drop the split column from them.

    :param dataset: list of samples (each a list)
    :param axis: index of the feature to split on
    :param value: for a discrete split, the feature value to match; for a
                  continuous split, the numeric threshold
    :param continuous: ``True`` for a threshold split, otherwise an
                       equality split
    :param part: continuous splits only -- ``0`` keeps samples whose feature
                 is ``<= value``, ``1`` keeps those ``> value``; any other
                 number selects nothing
    :return: new list of the selected samples, each without column ``axis``
    '''
    def withoutAxis(sample):
        return sample[:axis] + sample[axis + 1:]

    # Compared against True explicitly (not by truthiness) so that non-bool
    # flags are treated exactly as before.
    if continuous == True:
        if part == 0:
            return [withoutAxis(s) for s in dataset if float(s[axis]) <= value]
        if part == 1:
            return [withoutAxis(s) for s in dataset if float(s[axis]) > value]
        return []
    return [withoutAxis(s) for s in dataset if s[axis] == value]


def majorClass(classList):
    '''
    Majority vote over the class labels that reached a leaf.

    :param classList: list of class labels
    :return: the most frequent label; among equally frequent labels, the one
             that appears first in ``classList``
    '''
    tally = {}
    for vote in classList:
        tally[vote] = tally.get(vote, 0) + 1
    # The sort is stable, so tied labels keep their first-seen order.
    ranked = sorted(tally, key=tally.get, reverse=True)
    return ranked[0]


def createTree(dataset, labels, datasetFull, labelsFull):
    '''
    Recursively build a decision tree as nested dictionaries.

    An inner node is ``{featureName: {branchKey: subtree_or_class}}``.
    Branch keys are the feature values for a discrete feature, and the two
    strings ``'<=' + str(point)`` / ``'>' + str(point)`` for a continuous
    one.  A leaf is a class label.

    :param dataset: non-empty list of samples for this node (class label
                    last)
    :param labels: feature names matching the columns of ``dataset``; the
                   list is not modified
    :param datasetFull: the complete training set, used to find discrete
                        values that do not occur in ``dataset``
    :param labelsFull: feature names matching the columns of ``datasetFull``
    :return: the subtree (dict) or a class label
    '''
    classList = [example[-1] for example in dataset]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # all samples agree
    if len(dataset[0]) == 1:
        return majorClass(classList)  # no feature left to split on
    bestFeat, bestSplitPoint = bestFeatureSplit(dataset)
    if bestFeat == -1:
        # No split improves on the current node; indexing labels with -1
        # would silently pick the wrong column, so stop with a leaf.
        return majorClass(classList)
    bestFeatLabel = labels[bestFeat]
    # splitDataset drops the split column for discrete AND continuous
    # splits, so the name is dropped in both cases to keep names and
    # columns aligned.  A fresh list leaves the caller's labels untouched.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    branches = {}
    if bestSplitPoint is None:  # discrete node
        uniqueVal = {example[bestFeat] for example in dataset}
        for value in uniqueVal:
            branches[value] = createTree(
                splitDataset(dataset, bestFeat, value, False),
                subLabels, datasetFull, labelsFull)
        # Values seen in the full data but absent here get the majority
        # class of this node, so the tree has a branch for every known value.
        bestFeatIndex = labelsFull.index(bestFeatLabel)
        uniqueValFull = {example[bestFeatIndex] for example in datasetFull}
        for value in uniqueValFull - uniqueVal:
            branches[value] = majorClass(classList)
    else:  # continuous node
        branches['<=' + str(bestSplitPoint)] = createTree(
            splitDataset(dataset, bestFeat, bestSplitPoint, True, 0),
            subLabels, datasetFull, labelsFull)
        branches['>' + str(bestSplitPoint)] = createTree(
            splitDataset(dataset, bestFeat, bestSplitPoint, True, 1),
            subLabels, datasetFull, labelsFull)
    return {bestFeatLabel: branches}


def decideTreePredict(decideTree, featList, testData):
    '''
    Classify one sample by walking a tree of nested dictionaries.

    A node is ``{featureName: {branchKey: subtree_or_class}}``.  A branch
    key starting with ``'<='`` or ``'>'`` is a numeric threshold test (the
    rest of the key is the threshold); any other key is matched against the
    sample's value by equality.

    :param decideTree: the tree (dict)
    :param featList: feature names, in the column order of ``testData``
    :param testData: the sample to classify
    :return: the predicted class label, or ``None`` when no branch matches
    '''
    firstFeat = list(decideTree)[0]
    secDict = decideTree[firstFeat]
    featValue = testData[featList.index(firstFeat)]
    for key, subtree in secDict.items():
        # startswith (rather than key[0]) also copes with an empty-string
        # branch value.
        if key.startswith('<='):
            matched = float(featValue) <= float(key[2:])
        elif key.startswith('>'):
            matched = float(featValue) > float(key[1:])
        else:
            matched = featValue == key
        if matched:
            if isinstance(subtree, dict):
                return decideTreePredict(subtree, featList, testData)
            return subtree
    return None


if __name__ == '__main__':
    import sys

    # The CSV path may be given as the first command-line argument; without
    # one, the original author's location is used.
    if len(sys.argv) > 1:
        filename = sys.argv[1]
    else:
        filename = 'C:\\Users\\14399\\Desktop\\西瓜3.0.csv'
    dataset, labels = readDataset(filename)
    # Separate copies of the complete data and feature names: tree building
    # and prediction both need the full, correctly ordered name list.
    datasetFull = dataset[:]
    labelsFull = labels[:]
    myTree = createTree(dataset, labels, datasetFull, labelsFull)
    print(myTree)
    # Validate on the training data itself and print the number of samples
    # classified correctly.
    count = sum(
        1 for testData in dataset
        if decideTreePredict(myTree, labelsFull, testData) == testData[-1])
    print(count)

生成結果:{'紋理': {'模糊': '否', '清晰': {'根蒂': {'硬挺': '否', '蜷縮': '是', '稍蜷': {'密度': {'<=0.3815': '否', '>0.3815': '是'}}}}, '稍糊': {'觸感': {'軟粘': '是', '硬滑': '否'}}}}(與書中的決策樹略有不同,但不影響在訓練集上的正確率。)

西瓜3.0資料集連結:https://pan.baidu.com/s/1RXTUG9gP1Jn9HKFCiEzOlA (密碼:3h6n)

參考:

- https://blog.csdn.net/u014514939/article/details/79299619 (含畫樹演算法)
- https://blog.csdn.net/csqazwsxedc/article/details/65697652
- https://blog.csdn.net/icefire_tyh/article/details/54575527