1. 程式人生 > >決策樹(二)分析與實踐

決策樹(二)分析與實踐

目錄

1 分析

1.1 背景:

1.2 定義

1.3 原理:

CART如何選擇分裂的屬性?

如何進行樹的剪枝來防止過擬合

對於含有空值的資料,此時應該怎麼構建樹。

2.實踐:(《機器學習實戰》第九章程式碼解析)

CART演算法的實現(運用到預剪枝)

後剪枝演算法實現

參考


1 分析

1.1 背景:

       線性迴歸的模型一般都要擬合所有的樣本點,但當資料擁有眾多特徵,並且特徵之間的關係十分的複雜,這時候往往是非線性的問題,很難構建全域性模型。
       方法:將資料集切分成很多份易建模的的資料,再線性迴歸(就像微分一樣的思想),一次切分不行就兩次不斷遞迴,這時候用樹結構就很合適。
        樹結構為什麼不用ID3演算法呢?因為ID3一般用來處理離散值,不能直接連續型特徵。所以我們採用二元切分法來處理連續型特徵,如果特徵值大於定值就走左子樹,小於就走右子樹。

1.2 定義

        基於上述問題,有了CART,分類迴歸樹,顧名思義既能分類又能迴歸,當CART是分類樹時,採用GINI值作為節點分裂的依據;當CART是迴歸樹時,採用樣本的最小方差作為節點分裂的依據;採用了二元切分,所以是一棵二叉樹。
       分類樹的作用是通過一個物件的特徵來預測該物件所屬的類別(打標籤),而回歸樹的目的是根據一個物件的資訊預測該物件的屬性數值。舉個栗子:一個數據集,包含以下資訊:看電視時間,
婚姻情況已婚/未婚),職業年齡;如果我們想預測一個人是否已婚,那麼構建的CART將是分類樹;如果想預測一個人的年齡,那麼構建的將是迴歸樹。

 

1.3 原理:

  • CART如何選擇分裂的屬性?

分裂的目的是為了能夠讓資料變純,使決策樹輸出的結果更接近真實值。那麼CART是如何評價節點的純度呢?如果是分類樹,CART採用GINI值衡量節點純度;如果是迴歸樹,採用樣本方差衡量節點純度。節點越不純,節點分類或者預測的效果就越差。

分類樹:用基尼指數來選擇最優特徵,同時決定該特徵的最優二值切分點

基尼指數公式和含義:

Gini(p)=\sum_{k=1}^{K}p_k(1-p_k)=1-\sum_{k=1}^{K}p_k^{2},對於二分類簡單的有Gini(p)=2p(1-p)

如果樣本點屬於和不屬於第k類的概率p_k1-p_k相差越近,比如五五開的話,基尼指數越大,結點純度越低。

對樣本集使用該公式:

對給定樣本集合D,其基尼指數為:Gini(D)=1-\sum_{k=1}^{K}\left ( \frac{\left | C_k \right |}{\left | D \right |} \right )^{2}

在特徵A下,集合D的基尼指數:Gini(D,A)=\frac{\left | D_1 \right |}{\left | D \right |}Gini(D_1)+\frac{\left | D_2 \right |}{\left | D \right |}Gini(D_2)

根據該基尼指數來確定最小的特徵及其對應切分點 

迴歸樹:用迴歸方差,方差越大,表示該節點的資料越分散,預測的效果就越差。如果一個節點的所有資料都相同,那麼方差就為0,此時可以很肯定得認為該節點的輸出值;如果節點的資料相差很大,那麼輸出的值有很大的可能與實際值相差較大。

因此,無論是分類樹還是迴歸樹,CART都要選擇使子節點的GINI值或者回歸方差最小的屬性作為分裂的方案。

  • 如何進行樹的剪枝來防止過擬合

從樣本預留出一部分資料用作“驗證集”以進行效能評估。

預剪枝

在決策樹生成過程中,對每個結點在劃分前先進性估計分,若當前節點的劃分不能帶來決策樹泛化能力提升,則停止劃分,將當前節點標記為葉結點;

後剪枝

現從訓練集生成完整的決策樹,然後自底向上地對非葉節點進行考察,弱將該節點對應的子樹替代能帶來決策樹泛化能力提升,則把當前子樹替換為葉子結點。

CART採用CCP(代價複雜度)剪枝方法。代價複雜度選擇節點表面誤差率增益值最小的非葉子節點,刪除該非葉子節點的左右子節點,若有多個非葉子節點的表面誤差率增益值相同小,則選擇非葉子節點中子節點數最多的非葉子節點進行剪枝。

  • 對於含有空值的資料,此時應該怎麼構建樹。

第一,如何在屬性值缺失的情況下進行屬性劃分選擇選擇劃分的特徵

不將缺失值的樣本代入選擇判斷的公式計算(資訊增益、增益率、基尼指數)之中,只在計算完後乘以一個有值的樣本比例即可。比如訓練集有10個樣本,在屬性 a 上,有兩個樣本缺失值,那麼計算該屬性劃分的資訊增益時,我們可以忽略這兩個缺失值的樣本來計算資訊增益,然後在計算結果上乘以8/10即可。

第二,若一個樣本在劃分屬性上的值為空,它應該被分在哪個子結點中歸類樣本

若樣本 x 在劃分屬性 a 上取值未知,則將 x 劃入所有子結點,但是對劃入不同子結點中的 x 賦予不同的權值(不同子結點上的不同權值一般體現為該子結點所包含的資料佔父結點資料集合的比例)

 

2.實踐:(《機器學習實戰》第九章程式碼解析)

  • CART演算法的實現(運用到預剪枝)

##CART演算法的實現程式碼
#匯入資料
def loadDataSet(fileName):
    dataMat = []                #假設最後一列是目標值
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float,curLine)) #將每行對映成浮點數
        dataMat.append(fltLine)
    return dataMat

#切分得到兩個子集
def binSplitDataSet(dataSet, feature, value):
    """
    將資料集合切分得到兩個子集
    :param dataSet: 資料集合
    :param feature: 待切分的特徵
    :param value: 該特徵的某個值
    :return: 返回兩個子集
    """
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]#資料集中第feature列的值大於value的分為一組
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]#資料集中第feature列的值小於等於value的分為一組
    return mat0,mat1

#迴歸樹的葉節點生成函式
def regLeaf(dataSet):#returns the value used for each leaf
    return mean(dataSet[:,-1])#在迴歸樹中,返回目標量的均值

#誤差估計函式
def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]#總方差=方差*資料集中樣本的個數

#迴歸樹的切分函式
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """
    找到資料的最佳二元切分方式
    :param dataSet: 資料集
    :param leafType: 對建立葉節點的函式的引用
    :param errType: 對誤差估計函式的引用
    :param ops: 是一個使用者定義的引數構成的元組,用於控制函式的停止時機
    :return: 返回特徵編號和切分特徵值
    """
    tolS = ops[0]#容許的誤差下降值
    tolN = ops[1]#切分的最少樣本數
    #如果所有目標變數都是相同的值則退出
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #tolist()是將陣列或矩陣轉換為列表,set() 函式建立一個無序不重複元素集
        return None, leafType(dataSet) #返回None並同時產生葉節點
    m,n = shape(dataSet)
    #最佳切分也就是使得切分後能達到最低誤差的切分
    S = errType(dataSet)#誤差
    bestS = inf#正無窮
    bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)#切分得到兩個資料集
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue#如果某個子集的大小小於使用者定義的引數tolN,則跳出本次迴圈,繼續下一輪迴圈
            newS = errType(mat0) + errType(mat1)#新誤差
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #如果下降(S-bestS)小於閾值tolS,則不要切分而直接建立葉節點
    if (S - bestS) < tolS:
        return None, leafType(dataSet) #返回None並同時產生葉節點
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)#切分得到兩個資料集
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #如果切分出的資料集的大小小於使用者定義的引數tolN
        return None, leafType(dataSet)#返回None並同時產生葉節點
    return bestIndex,bestValue#返回切分特性和特徵值

#構建迴歸樹
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#假設資料集是NumPy Mat,那麼我們可以陣列過濾
    """
    構建樹
    :param dataSet: 資料集
    :param leafType: 建立葉節點的函式
    :param errType: 誤差計算函式
    :param ops: 是一個使用者定義的引數構成的元組,用於控制函式的停止時機
    :return: 存放樹的資料結構的字典
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#選擇最佳分割,二元切分
    if feat == None: return val #如果切分達到停止條件,返回特徵值
    retTree = {}#字典
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)#切分得到兩子集
    retTree['left'] = createTree(lSet, leafType, errType, ops)#左子樹
    retTree['right'] = createTree(rSet, leafType, errType, ops)#右子樹
    return retTree#存放樹的資料結構的字典
  • 後剪枝演算法實現:

def isTree(obj):#用來判斷當前處理的節點是否是葉結點
    return (type(obj).__name__=='dict')

def getMean(tree):#遞迴從上到下遍歷,找到兩個葉結點就計算平均值返回(對數進行塌陷處理)
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0

def prune(tree, testData):#剪枝函式
    """
    剪枝函式
    引數:
        tree -- 待剪枝的樹
        testData -- 測試資料
    返回:
        treeMean -- 合併的結果
        或
        tree -- 不需要剪枝
    """
    # 如果沒有測試資料,就直接把整棵樹合併
    if shape(testData)[0] == 0: return getMean(tree)
    # 如果沒有樹可以合併,則分割節點
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    # 遞迴分割左右子樹
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet)
    # 如果兩棵樹都是葉子節點,則判斷是否要合併
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        # 不合並的誤差
        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
            sum(power(rSet[:,-1] - tree['right'],2))
        treeMean = (tree['left']+tree['right'])/2.0
        # 合併誤差
        errorMerge = sum(power(testData[:,-1] - treeMean,2))
        # 如果合併後,誤差減小,則執行合併
        if errorMerge < errorNoMerge: 
            print("merging")
            return treeMean
        # 反之,不執行合併
        else: return tree
    else: return tree

參考

【1】https://www.jianshu.com/p/d80fbec52f09
【2】https://www.cnblogs.com/yonghao/p/5135386.html