西瓜書 課後習題4.3 基於資訊熵決策樹,連續和離散屬性,並驗證模型
阿新 • 發佈:2018-11-27
from math import log
import operator
import csv
# NOTE(review): the original also imported matplotlib.pyplot and numpy,
# but neither name is used anywhere in this script, so both were dropped.


def readDataset(filename):
    """Read the watermelon CSV dataset.

    :param filename: path to a CSV file whose column 0 is a sample id,
        columns 1..8 are attributes and column 9 is the class label
    :return: (dataset, labels) — dataset is a list of rows
        (8 attribute strings followed by the label string), labels is
        the list of the 8 attribute names taken from the header row
    """
    with open(filename) as f:
        reader = csv.reader(f)
        header_row = next(reader)
        labels = header_row[1:9]
        dataset = []
        for line in reader:
            # Keep attributes + class label, drop the id column.
            dataset.append(line[1:10])
    return dataset, labels


def infoEnt(dataset):
    """Return the base-2 information entropy of the class column.

    :param dataset: list of rows; the class label is the last element
    :return: entropy as a float (0.0 for a pure or empty set)
    """
    numdata = len(dataset)
    counts = {}
    for featVec in dataset:
        label = featVec[-1]
        counts[label] = counts.get(label, 0) + 1
    ent = 0.0
    for cnt in counts.values():
        prop = float(cnt) / numdata
        ent -= prop * log(prop, 2)
    return ent


def bestFeatureSplit(dataset):
    """Choose the attribute (and split point, if continuous) with the
    largest information gain.

    :param dataset: list of rows, class label in the last column
    :return: (best attribute index, split point); the split point is
        None when the chosen attribute is discrete
    """
    numFeature = len(dataset[0]) - 1
    baseInfoEnt = infoEnt(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    bestSplitPoint = None
    for i in range(numFeature):
        featList = [example[i] for example in dataset]
        # Heuristic: a value made only of digits/'.'/'-' marks a
        # continuous (numeric) attribute.
        if all(c in "0123456789.-" for c in featList[0]):
            # Continuous attribute: candidate split points are the
            # midpoints of consecutive sorted values.  Sort the FLOAT
            # values — a string sort would misorder e.g. '10' < '2'.
            values = sorted(float(feat) for feat in featList)
            mediumPoints = [(values[k] + values[k + 1]) / 2
                            for k in range(len(values) - 1)]
            for point in mediumPoints:
                # BUGFIX: entropy must be recomputed for each candidate
                # point; the original accumulated newEnt across points,
                # corrupting every gain after the first.
                newEnt = 0.0
                for part in range(2):
                    subDataset = splitDataset(dataset, i, point, True, part)
                    prop = len(subDataset) / float(len(dataset))
                    newEnt += prop * infoEnt(subDataset)
                infoGain = baseInfoEnt - newEnt
                if infoGain > bestInfoGain:
                    bestInfoGain = infoGain
                    bestFeature = i
                    bestSplitPoint = point
        else:
            # Discrete attribute: one partition per distinct value.
            newEnt = 0.0
            for value in set(featList):
                # BUGFIX: pass continuous=False explicitly; the original
                # reused a flag left True by an earlier continuous
                # attribute, splitting discrete columns numerically.
                subDataset = splitDataset(dataset, i, value, False)
                prop = len(subDataset) / float(len(dataset))
                newEnt += prop * infoEnt(subDataset)
            infoGain = baseInfoEnt - newEnt
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
                bestSplitPoint = None
    return bestFeature, bestSplitPoint


def splitDataset(dataset, axis, value, continuous, part=0):
    """Return the subset of rows selected by a split, with the split
    column removed.

    :param dataset: list of rows
    :param axis: index of the attribute to split on
    :param value: attribute value (discrete) or threshold (continuous)
    :param continuous: True when the attribute is numeric
    :param part: for continuous splits, 0 keeps rows <= value and
        1 keeps rows > value; ignored for discrete splits
    :return: the matching rows without column ``axis``
    """
    restDataset = []
    for featVec in dataset:
        if continuous:
            keep = ((part == 0 and float(featVec[axis]) <= value) or
                    (part == 1 and float(featVec[axis]) > value))
        else:
            keep = featVec[axis] == value
        if keep:
            restDataset.append(featVec[:axis] + featVec[axis + 1:])
    return restDataset


def majorClass(classList):
    """Return the majority class label among the samples of a leaf.

    :param classList: list of class labels
    :return: the most frequent label
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataset, labels, datasetFull, labelsFull):
    """Recursively build an information-gain decision tree.

    :param dataset: remaining training rows (class label in last column)
    :param labels: attribute names aligned with dataset's columns;
        mutated — the chosen attribute's name is deleted
    :param datasetFull: the full, unsplit training set
    :param labelsFull: the full attribute-name list
    :return: a nested dict tree, or a class-label string for a leaf
    """
    classList = [example[-1] for example in dataset]
    # All samples share one class -> pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No attributes left -> majority-vote leaf.
    if len(dataset[0]) == 1:
        return majorClass(classList)
    bestFeat, bestSplitPoint = bestFeatureSplit(dataset)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Collect every value this attribute takes in the FULL dataset so
    # values absent from the current subset still get a branch
    # (e.g. 色澤:淺白 in watermelon dataset 2.0).
    bestFeatIndex = labelsFull.index(bestFeatLabel)
    uniqueValFull = set(example[bestFeatIndex] for example in datasetFull)
    # splitDataset removes the chosen column, so the matching label must
    # go too.  BUGFIX: the original kept the label for continuous
    # attributes, misaligning labels with columns in recursive calls.
    featValues = [example[bestFeat] for example in dataset]
    del labels[bestFeat]
    if bestSplitPoint is None:
        # Discrete split: one branch per value seen in this subset.
        uniqueVal = set(featValues)
        for value in uniqueVal:
            # Copy labels so sibling branches stay independent.
            subLabels = labels[:]
            myTree[bestFeatLabel][value] = createTree(
                splitDataset(dataset, bestFeat, value, False),
                subLabels, datasetFull, labelsFull)
        # Values present only in the full data become majority leaves.
        for value in uniqueValFull - uniqueVal:
            myTree[bestFeatLabel][value] = majorClass(classList)
    else:
        # Continuous split: exactly two branches, '<=point' and '>point'.
        for part, tag in ((0, '<='), (1, '>')):
            subLabels = labels[:]
            myTree[bestFeatLabel][tag + str(bestSplitPoint)] = createTree(
                splitDataset(dataset, bestFeat, bestSplitPoint, True, part),
                subLabels, datasetFull, labelsFull)
    return myTree


def decideTreePredict(decideTree, featList, testData):
    """Classify one sample with the decision tree.

    :param decideTree: tree produced by createTree
    :param featList: full attribute-name list (same order as testData)
    :param testData: one sample row
    :return: predicted class label, or None when no branch matches
    """
    firstFeat = list(decideTree.keys())[0]
    secDict = decideTree[firstFeat]
    featIndex = featList.index(firstFeat)
    decideLabel = None
    for key in secDict.keys():
        # Continuous branches are keyed '<=x' / '>x'; anything else is
        # a discrete attribute value compared verbatim.
        if key[0] == '<':
            matched = float(testData[featIndex]) <= float(key[2:])
        elif key[0] == '>':
            matched = float(testData[featIndex]) > float(key[1:])
        else:
            matched = testData[featIndex] == key
        if matched:
            subTree = secDict[key]
            if isinstance(subTree, dict):
                decideLabel = decideTreePredict(subTree, featList, testData)
            else:
                decideLabel = subTree
    return decideLabel


if __name__ == '__main__':
    filename = 'C:\\Users\\14399\\Desktop\\西瓜3.0.csv'
    dataset, labels = readDataset(filename)
    datasetFull = dataset[:]
    labelsFull = labels[:]
    myTree = createTree(dataset, labels, datasetFull, labelsFull)
    print(myTree)
    # Validation reuses the training set itself, so accuracy is 100%.
    count = 0
    for testData in dataset:
        if decideTreePredict(myTree, labelsFull, testData) == testData[-1]:
            count += 1
    print(count)
生成結果:{'紋理': {'模糊': '否', '清晰': {'根蒂': {'硬挺': '否', '蜷縮': '是', '稍蜷': {'密度': {'<=0.3815': '否', '>0.3815': '是'}}}}, '稍糊': {'觸感': {'軟粘': '是', '硬滑': '否'}}}} ( 與書中結果略有不同,但不影響正確率。)
西瓜3.0資料集:連結:https://pan.baidu.com/s/1RXTUG9gP1Jn9HKFCiEzOlA 密碼:3h6n
參考: https://blog.csdn.net/u014514939/article/details/79299619 (含畫樹演算法)