1. 程式人生 > >CART之迴歸樹python程式碼實現

CART之迴歸樹python程式碼實現

一、CART ( Classification And Regression Tree) 分類迴歸樹

1、基尼指數:

在分類問題中,假設有K 個類,樣本點屬於第k 類的概率為Pk ,則概率分佈的基尼指數定義為:
Gini(P)=k=1KPk(1Pk)=1k=1KPk2

在CART 分類問題中,基尼指數作為特徵選擇的依據:選擇基尼指數最小的特徵及切分點做為最優特徵和最優切分點。

2、在迴歸問題中,特徵選擇及最佳劃分特徵值的依據是:劃分後樣本的均方差之和最小!

二、演算法分析:

CART 主要包括特徵選擇、迴歸樹的生成、剪枝三部分

資料特徵停止劃分的條件:
1、當前資料集中的標籤相同,返回當前的標籤
2、劃分前後的總方差差距很小,資料不劃分,返回的屬性為空,返回的最佳劃分值為當前所有標籤的均值。
3、劃分後的左右兩個資料集的樣本數量較小,返回的屬性為空,返回的最佳劃分值為當前所有標籤的均值。

若滿足上述三個特徵停止劃分的條件,則返回的最佳特徵為空,返回的最佳劃分特徵值會作為葉子結點。

注:CART是一棵二叉樹。 在生成CART迴歸樹過程中,一個特徵可能會被使用不止一次,所以,不存在當前屬性集為空的情況;

1、特徵選擇(依據:總方差最小)

輸入:資料集、op = [m,n]
輸出:最佳特徵、最佳劃分特徵值

m表示剪枝前總方差與剪枝後總方差差值的最小值; n: 資料集劃分為左右兩個子資料集後,子資料集中的樣本的最少數量;

1、判斷資料集中所有的樣本標籤是否相同,是:返回當前標籤;
2、遍歷所有的樣本特徵,遍歷每一個特徵的特徵值。計算出每一個特徵值下的資料總方差,找出使總方差最小的特徵、特徵值
3、比較劃分前和劃分後的總方差大小;若劃分後總方差減少較小,則返回的最佳特徵為空,返回的最佳劃分特徵值會為當前資料集標籤的平均值。
4、比較劃分後的左右分支資料集樣本中的數量,若某一分支資料集中樣本少於指定數量op[1],則返回的最佳特徵為空,
返回的最佳劃分特徵值會為當前資料集標籤的平均值。
5、否則,返回使總方差最小的特徵、特徵值

二、迴歸樹的生成函式 createTree
輸入:資料集
輸出:生成迴歸樹
1、得到當前資料集的最佳劃分特徵、最佳劃分特徵值
2、若返回的最佳特徵為空,則返回最佳劃分特徵值(作為葉子節點)
3、宣告一個字典,用於儲存當前的最佳劃分特徵、最佳劃分特徵值
4、執行二元切分;根據最佳劃分特徵、最佳劃分特徵值,將當前的資料劃分為兩部分
5、在左子樹中呼叫createTree 函式, 在右子樹呼叫createTree 函式。
6、返回樹。

注:在生成的迴歸樹模型中,劃分特徵、特徵值、左節點、右節點均有相應的關鍵詞對應。

三、(後)剪枝:(CART 樹一定是二叉樹,所以,如果發生剪枝,肯定是將兩個葉子節點合併)

輸入:樹、測試集
輸出:樹

1、判斷測試集是否為空,是:對樹進行塌陷處理
2、判斷樹的左右分支是否為樹結構,是:根據樹當前的特徵值、劃分值將測試集分為Lset、Rset兩個集合;
3、判斷樹的左分支是否是樹結構:是:在該子集遞迴呼叫剪枝過程;
4、判斷樹的右分支是否是樹結構:是:在該子集遞迴呼叫剪枝過程;
5、判斷當前樹結構的兩個節點是否為葉子節點:
是:
a、根據當前樹結構,測試集劃分為Lset,Rset兩部分;
b、計算沒有合併時的總方差NoMergeError,即:測試集在Lset 和 Rset 的總方差之和;
c、合併後,取葉子節點值為原左右葉子結點的均值。求取測試集在該節點處的總方差MergeError,;
d、比較合併前後總方差的大小;若NoMergeError > MergeError,返回合併後的節點;否則,返回原來的樹結構;
否:
返回樹結構。

程式碼實現:資料集

#-*- coding:utf-8 -*-
from numpy import *
import numpy as np
# 三大步驟:
'''
1、特徵的選擇:標準:總方差最小
2、迴歸樹的生成:停止劃分的標準
3、剪枝:
'''

# 匯入資料集
def loadData(filaName):
    dataSet = []
    fr = open(filaName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        theLine = map(float, curLine)                 # map all elements to float()
        dataSet.append(theLine)
    return dataSet

# 特徵選擇:輸入:       輸出:最佳特徵、最佳劃分值
'''
1、選擇標準
遍歷所有的特徵Fi:遍歷每個特徵的所有特徵值Zi;找到Zi,劃分後總的方差最小
停止劃分的條件:
1、當前資料集中的標籤相同,返回當前的標籤
2、劃分前後的總方差差距很小,資料不劃分,返回的屬性為空,返回的最佳劃分值為當前所有標籤的均值。
3、劃分後的左右兩個資料集的樣本數量較小,返回的屬性為空,返回的最佳劃分值為當前所有標籤的均值。
當劃分的資料集滿足上述條件之一,返回的最佳劃分值作為葉子節點;
當劃分後的資料集不滿足上述要求時,找到最佳劃分的屬性,及最佳劃分特徵值
'''

# 計算總的方差
def GetAllVar(dataSet):
    return var(dataSet[:,-1])*shape(dataSet)[0]

# 根據給定的特徵、特徵值劃分資料集
def dataSplit(dataSet,feature,featNumber):
    dataL =  dataSet[nonzero(dataSet[:,feature] > featNumber)[0],:]
    dataR = dataSet[nonzero(dataSet[:,feature] <= featNumber)[0],:]
    return dataL,dataR

# 特徵劃分
def choseBestFeature(dataSet,op = [1,4]):          # 三個停止條件可否當作是三個預剪枝操作
    if len(set(dataSet[:,-1].T.tolist()[0]))==1:     # 停止條件 1
        regLeaf = mean(dataSet[:,-1])         
        return None,regLeaf                   # 返回標籤的均值作為葉子節點
    Serror = GetAllVar(dataSet)
    BestFeature = -1; BestNumber = 0; lowError = inf
    m,n = shape(dataSet) # m 個樣本, n -1 個特徵
    for i in range(n-1):    # 遍歷每一個特徵值
        for j in set(dataSet[:,i].T.tolist()[0]):
            dataL,dataR = dataSplit(dataSet,i,j)
            if shape(dataR)[0]<op[1] or shape(dataL)[0]<op[1]: continue  # 如果所給的劃分後的資料集中樣本數目甚少,則直接跳出
            tempError = GetAllVar(dataL) + GetAllVar(dataR)
            if tempError < lowError:
                lowError = tempError; BestFeature = i; BestNumber = j
    if Serror - lowError < op[0]:               # 停止條件 2   如果所給的資料劃分前後的差別不大,則停止劃分
        return None,mean(dataSet[:,-1])         
    dataL, dataR = dataSplit(dataSet, BestFeature, BestNumber)
    if shape(dataR)[0] < op[1] or shape(dataL)[0] < op[1]:        # 停止條件 3
        return None, mean(dataSet[:, -1])
    return BestFeature,BestNumber


# 決策樹生成
def createTree(dataSet,op=[1,4]):
    bestFeat,bestNumber = choseBestFeature(dataSet,op)
    if bestFeat==None: return bestNumber
    regTree = {}
    regTree['spInd'] = bestFeat
    regTree['spVal'] = bestNumber
    dataL,dataR = dataSplit(dataSet,bestFeat,bestNumber)
    regTree['left'] = createTree(dataL,op)
    regTree['right'] = createTree(dataR,op)
    return  regTree

# 後剪枝操作
# 用於判斷所給的節點是否是葉子節點
def isTree(Tree):
    return (type(Tree).__name__=='dict' )

# 計算兩個葉子節點的均值
def getMean(Tree):
    if isTree(Tree['left']): Tree['left'] = getMean(Tree['left'])
    if isTree(Tree['right']):Tree['right'] = getMean(Tree['right'])
    return (Tree['left']+ Tree['right'])/2.0

# 後剪枝
def pruneTree(Tree,testData):
    if shape(testData)[0]==0: return getMean(Tree)
    if isTree(Tree['left'])or isTree(Tree['right']):
        dataL,dataR = dataSplit(testData,Tree['spInd'],Tree['spVal'])
    if isTree(Tree['left']):
        Tree['left'] = pruneTree(Tree['left'],dataL)
    if isTree(Tree['right']):
        Tree['right'] = pruneTree(Tree['right'],dataR)
    if not isTree(Tree['left']) and not isTree(Tree['right']):
        dataL,dataR = dataSplit(testData,Tree['spInd'],Tree['spVal'])
        errorNoMerge = sum(power(dataL[:,-1] - Tree['left'],2)) + sum(power(dataR[:,-1] - Tree['right'],2))
        leafMean = getMean(Tree)
        errorMerge = sum(power(testData[:,-1]-  leafMean,2))
        if errorNoMerge > errorMerge:
            print"the leaf merge"
            return leafMean
        else:
            return Tree
    else:
        return Tree

# 預測
def forecastSample(Tree,testData):
    if not isTree(Tree): return float(tree)
    # print"選擇的特徵是:" ,Tree['spInd']
    # print"測試資料的特徵值是:" ,testData[Tree['spInd']]
    if testData[0,Tree['spInd']]>Tree['spVal']:
        if isTree(Tree['left']):
            return forecastSample(Tree['left'],testData)
        else:
            return float(Tree['left'])
    else:
        if isTree(Tree['right']):
            return forecastSample(Tree['right'],testData)
        else:
            return float(Tree['right'])

def TreeForecast(Tree,testData):
    m = shape(testData)[0]
    y_hat = mat(zeros((m,1)))
    for i in range(m):
        y_hat[i,0] = forecastSample(Tree,testData[i])
    return y_hat

if __name__=="__main__":
    print "hello world"
    dataMat = loadData("ex2.txt")
    dataMat = mat(dataMat)
    op = [1,6]    # 引數1:剪枝前總方差與剪枝後總方差差值的最小值;引數2:將資料集劃分為兩個子資料集後,子資料集中的樣本的最少數量;        
    theCreateTree =  createTree(dataMat,op)
   # 測試資料
    dataMat2 = loadData("ex2test.txt")
    dataMat2 = mat(dataMat2)
    #thePruneTree =  pruneTree(theCreateTree, dataMat2)
    #print"剪枝後的後樹:\n",thePruneTree
    y = dataMat2[:, -1]
    y_hat = TreeForecast(theCreateTree,dataMat2)
    print corrcoef(y_hat,y,rowvar=0)[0,1]              # 用預測值與真實值計算相關係數

參考:
《機器學習實戰》
《統計學習方法》 李航