1. 程式人生 > >【機器學習筆記27】CART演算法-迴歸樹和分類樹

【機器學習筆記27】CART演算法-迴歸樹和分類樹

基本概念

分類和迴歸樹(classification and regression tree, CART) 是應用廣泛的決策樹學習方法,由特徵選擇、樹的生成和剪枝組成,既可以用做分類也可以用作迴歸。

迴歸樹
迴歸樹的定義

假設X和Y分別作為輸入和輸出變數,那麼存在訓練集 D={(x1,y1),(x2,y2)...(xN,yN)}D = \{(x_1, y_1), (x_2, y_2) ... (x_N, y_N) \} 一個迴歸樹對應其輸入空間(特徵)的劃分和這個劃分上的輸入值。 數學定義: 存在M個分類,每個分類的單元為R

MR_M,且該單元的輸出為cMc_M,我們有迴歸樹模型f(x)=m=1McMI(xRM)f(x)=\sum\limits_{m=1}^{M}c_MI(x \in R_M) ,這裡II代表那個分類的集合。

於是我們可以用平方誤差(yif(xi))2\sum(y_i - f(x_i))^2來表示迴歸樹訓練時的誤差。

迴歸樹的生成

第一步: 選擇切分變數j和切分點s,即選用特徵j和特徵上的一個閾值s將輸入資料劃分成一個二叉決策樹。這一步需要遍歷所有的特徵及其閾值,獲取最優解,數學表示如下:

1)目標是取兩個集合:R

1(j,s)={xx(j)<s}R2(j,s)={xxj>s}R_1(j,s)=\{x|x^{(j)}<s\} \quad R_2(j,s)=\{x|x^{j}>s\} 2)求解誤差最小值 minj,s[minc1xR1(yic1)2+minc2xR2(yic2)2]\min\limits_{j,s}[\min\limits_{c_1}\sum\limits_{x \in R_1}(y_i - c1)^2 + \min\limits_{c_2}\sum\limits_{x \in R_2}(y_i - c2)^2]
第二步:這裡設定的輸出值為訓練資料輸出的均方差 $c_1 = ave(y_i | x_i \in R_1) \quad c_2 = ave(y_i | x_i \in R_2) $

第三步: 對已經劃分的兩個子區域,再次重複步驟1,2,知道滿足停止條件,生成有M個區域的決策樹。

分類樹
基尼指數

分類問題中假設有K個類,樣本點屬於k類的概率為pkp_k,則概率分佈的基尼係數為: Gini(p)=k=1Kpk(1pk)=1k=1Kpk2Gini(p)=\sum\limits_{k=1}^{K}p_k(1 - p_k) = 1 - \sum\limits_{k=1}^{K}p_k^2

對於給定的樣本集合D,其基尼指數為: Gini(D)=1k=1K(CkD)2Gini(D)=1-\sum\limits_{k=1}^{K}(\dfrac{|C_k|}{|D|})^2,這裡CkC_k表示D中屬於k類的數量。

基尼係數表示了集合D的不確定性,Gini(D, A)表示D作了A分割成D1D2D_1 \quad D_2後的不確定性,基尼係數越大,不確定性越大。 其中Gini(D,A)=D1DGini(D1)+D2DGini(D2)Gini(D, A)=\dfrac{D_1}{D}Gini(D_1) + \dfrac{D_2}{D}Gini(D_2)

分類樹的訓練

這個和迴歸樹的方式基本一樣,只是損失定義從最小二乘轉變為基尼指數。

第一步: 對輸入資料D的每一個特徵和可能的取值,做分類A並計算Gini(D, A) 第二步: 取Gini(D, A)最小的一組資料,作為本次的最有特徵和最優切分點 第三步: 對上述劃分的結果,迭代步驟1,2 第四步: 生成CART決策樹(分類)

程式碼實現(迴歸樹,基於sklearn)
# -*- coding: utf-8 -*-
import numpy  as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor

def _test_cart_regresssion():

    index  = np.linspace(0, 10, 200);

    delta  = np.random.rand()

    index.shape = (200, 1)

   
    y_test = 0.3*index + np.sin(index) + delta

    clf = DecisionTreeRegressor(max_depth=4) 
    """
    這裡的深度為4,相當於迴歸的時候分了4次,劃分為2^4=16個區域
    """

    clf.fit(index, y_test)

    y_pred = clf.predict(index)

    plt.scatter(index, y_test,  color='black',marker='o')
    plt.plot(index, y_pred, color='blue', linewidth=2)

    plt.xticks(())
    plt.yticks(())

    plt.show()


    pass;


"""
說明:

CART演算法迴歸樹及決策樹的例子《CART演算法》

作者:fredric

日期:2018-11-4

"""
if __name__ == "__main__":

    _test_cart_regresssion()

在這裡插入圖片描述