1. 程式人生 > >決策樹學習(上)——深度原理剖析及原始碼實現

決策樹學習(上)——深度原理剖析及原始碼實現

引言

本文給大家分享的主題是決策樹(Decision Tree)的原理剖析並附上程式碼實現供大家參考。由於基於決策樹的演算法較多,因此文章分為上下篇。上篇主要剖析決策樹原理、需要掌握的資訊理論知識以及Java原始碼實現等內容。下篇內容包括基於決策樹的ID3、CART以及C4.5等著名演算法的深入比較、理解以及完整程式碼實現。

決策樹是資料探勘以及機器學習領域一個基礎的演算法。在此基礎上產生諸多著名演算法如ID3,CART以及C4.5等。其中C4.5更是被評為資料探勘領域的十大經典演算法。

原理剖析

示例

顧名思義,決策樹是一顆關於決策的樹。舉個簡單的例子稍作解釋,當我們打算去某家餐廳吃飯的時候會有諸多因素影響我們的決定,例如“當前餐廳的顧客多不多?”、“去餐廳的交通路況如何?”、“餐廳型別,中式、法式還是意式?”等等。在這一系列的一步步的思考之後我們會做出最終的決策:去或者不去該餐廳吃飯。當然也可以在考慮某一個或者部分因素之後做出決策,例如對於某些單身汪“是否有異性相約”是一個無比重要的決策因素,那麼對於這些人只考慮一個因素便可以做出最終決策。

我們一起再來看一個例子,銀行在放貸款的時候經常會對借貸人進行綜合考量,最終做出決策是否借貸款給這個人。當然在考核的過程中若有一項重要指標不符合要求,那麼也會立即否決而不考慮其他因素。有如下記錄(實際中往往比這些資料複雜得多):

上表中的資料記錄了銀行以往的借貸歷史中使用者的情況和最後的償還情況。我們可以得出擁有房產的人一般是能償還債務的,而沒有房產的人則需要再考慮其他因素等結論。因此根據這些資料我們可以構造如下決策樹
這裡寫圖片描述
如果此時有一個客戶前來貸款,該客戶沒有房產,單身且年收入只有50K。那麼根據上面的決策樹,銀行可以預測他無法償還債務(圖中藍色虛線),從而否決對其貸款。此外從上面的決策樹,還可以知道是否擁有房產可以在一定程度上決定使用者是否可以償還債務,對借貸業務具有指導意義。

從上面的示例及解釋中我們可以總結出如下結論:決策樹是一種樹形結構,其中每個內部節點表示一個屬性上的測試,每個分支代表一個測試輸出,每個葉節點代表一種類別。

構造決策樹

經過上面的敘述,相信大家已經明白什麼是決策樹以及決策樹的用途,這就解決了我們面對一個新事物時的三要素What,Why and How中的What和Why。那麼如何構造一顆決策樹呢(How)?這是我們的核心問題。在繼續敘述之前,我們需要掌握一些資訊理論的基礎知識。

資訊理論基礎知識

1、熵(entropy)
在資訊理論裡熵叫作資訊量,即熵是對不確定性的度量。從控制論的角度來看,應叫不確定性。資訊理論的創始人夏農在其著作《通訊的數學理論》中提出了建立在概率統計模型上的資訊度量。他把資訊定義為“用來消除不確定性的東西”。在資訊世界,熵越高,則能傳輸越多的資訊,熵越低,則意味著傳輸的資訊越少(什麼鬼,大家可以忽略上述解釋,咬文嚼字什麼的最煩了)。還是舉例說明,假設Kathy在買衣服的時候有顏色,尺寸,款式以及設計年份四種要求,而North只有顏色和尺寸的要求,那麼在購買衣服這個層面上Kathy由於選擇更多因而不確定性因素更大,最終Kathy所獲取的資訊更多,也就是熵更大。所以資訊量=熵=不確定性,通俗易懂。在敘述決策樹時我們用熵表示不純度(Impurity)

根據上面的敘述,可以給出如下熵的數學表示式(定義0*log(0)=0):
這裡寫圖片描述
當然,這個表示式也可以用來量化不純度。依舊舉例解釋一下,以前面的去餐廳吃飯為例說明。假設現在有兩個獨立的決策條件:1.餐廳中顧客數(Patrons),沒有、有一些、滿員;2.餐廳型別,法式、意式、泰式以及快餐廳。我們現在擁有12名顧客(正負樣本各一半)的決策資料,如下圖所示(綠色代表正樣本,紅色代表負樣本):
這裡寫圖片描述
對於這兩種決策條件,在決策之前資料集的熵Entropy=H(N)=-(0.5*log(0.5) + 0.5*log(0.5) )=1。我們說此時的熵值最大,也就是說不純度最小。若資料集中只有一類資料則不純度最小,即資料是“純的”。例如在第一種決策中,當餐廳中沒有顧客(None)的時候,最終的決策是都不去該餐廳(有可能是該餐廳食物太難吃)。此時熵Entropy=H(N)=-(0*log(0) + 1*log(1))=0。我們說此時資料集不純度最低,即資料是純的。

此外,學術界也用基尼係數(Gini):
Gini係數
以及誤差不純度
誤差不純度
來度量不純度。以上三種方式中通常情況下選用熵作為度量不純度的指標。

2、資訊增益(Information Gain)
資料集的一個屬性的資訊增益就是由於使用這個屬性分割樣例而導致的期望熵降低。也就是訓練集D分割之前的資訊熵減去依據某個屬性A分割成若干個子集後的資訊熵。其數學表示式為:
這裡寫圖片描述
舉例說明。在上述決策是否去餐廳吃飯的示例中,原資料集的熵為Entropy=H(6/12, 6/12)=1,在兩種不同條件下的資訊增益分別為:
Gain(Patrons)=1-[2/12H(0,1) + 4/12H(1,0) + 6/12H(2/6, 4/6)]=0.0541
Gani(Type)=1-[2/12H(1/2, 1/2) + 2/12H(1/2, 1/2) + 4/12H(2/4, 2/4) + 4/12H(2/4, 2/4)]=0

因此選用餐廳中顧客數為決策條件能獲得的資訊增益更大。資訊增益越大意味著能將資料集劃分得越簡潔。通俗地解釋,資訊增益越大,在同一條件下子集的熵越小,亦即子集越“純”。這也就是ID3演算法的原理。

3、資訊增益率(Information Gain Ratio)

歷史的程序往往伴隨著新的事物推翻舊的事物。ID3於1975年發明,而在1993年被更好的C4.5演算法取代。

首先給出資訊增益率的數學表示式:
這裡寫圖片描述
其中SplitInformation的數學表示式如下所示,其意義為根據屬性A劃分的各子集所需要的資訊量——。(有些晦澀,稍後舉例說明)
這裡寫圖片描述

之所以資訊增益率作為劃分資料的一種方式出現是由於資訊增益 具有傾向於選擇劃分值多的屬性的缺陷。舉一個極端例子說明,在上述餐廳示例中,若以餐廳類別為決策條件,並且有12個類別,假設最終每個類別中均只有一個潛在客戶的決策。那麼此時每個子集中的熵都為0,資訊增益增益最大,這樣訓練出來的決策樹往往會導致過擬合

我們以上述餐廳示例中的第一種情況為例,計算資訊增益率。前面我們已經計算過Gain(Patrons)=0.0541,下面計算
SplitInformation(S,Patrons)=-(2/12*log(2/12) + 4/12*log(4/12) + 6/12*log(6/12))=0.7887,
GainRatio(S,Patrons)=Gain(Patrons)/SplitInformation(S,Patrons)=0.0686。

更進一步,資訊增益率是如何避免資訊增益 中由於優先選擇值多的屬性而導致過擬合現象的出現?上面我們討論過以餐廳型別為條件並假設有12種不同餐廳,每類餐廳最後僅有一人做出決策的情況。這種情況下資訊增益最大。我們在以顧客數目型別為條件進行決策時,假設僅有兩種顧客數目(而不是上述三種類型)——沒有顧客(None)以及滿員(Full)(並假設資料集均分),並假設此種情況下資訊增益略小於餐廳型別的資訊增益。那麼在算資訊增益率的時前者SplitInformation(S,Types)=log(12),後者的SplitInformation(S,Patrons)=1,從而可能出現資訊增益雖然前者大,但是資訊增益率後者大的情況,這樣便可以避免過擬合的出現。C4.5與ID3演算法的不同在於C4.5使用資訊增益率,而ID3使用資訊增益。

當然資訊增益率也有其不完美的一面,當某個屬性的子集所佔資料集的比重非常大的時候,會出現SplitInformation接近0而資訊增益率異常大的情況。針對這種種情況可以採取某些處理方法,比如先計算每個屬性的增益,然後僅對那些增益高過平均值的屬性應用增益比率測試。

再論構造決策樹

奧卡姆剃刀定律(Occam’s Razor, Ockham’sRazor)又稱“奧康的剃刀”,是由14世紀邏輯學家、聖方濟各會修士奧卡姆的威廉(William of Occam,約1285年至1349年)提出。這個原理稱為“如無必要,勿增實體”,即“簡單有效原理”。正如他在《箴言書注》2卷15題說“切勿浪費較多東西去做,用較少的東西,同樣可以做好的事情。”

經過初探資訊理論 以及掌握一系列數學公式之後,我們繼續討論如何構造決策樹。決策樹的構造是一個貪心的、遞迴的、自頂向下的過程。演算法選用能將資料集劃分最“純”的節點作為當前節點的子節點(關於熵以及不純度的知識在上面小節中已經討論)。決策樹構造步驟如下:

  1. 開始,所有記錄看作一個節點
  2. 遍歷每個變數的每一種分割方式,找到最好的分割點
  3. 分割成兩個節點N1和N2
  4. 對N1和N2分別繼續執行2-3步,直到每個節點足夠“純”為止

當每個節點中的記錄數小於一個閾值的時候演算法停止。需要注意的是當閾值過小,例如為1的時候往往會導致過擬合現象。

剪枝 是解決過擬合的一個有效方法。當樹訓練得過於茂盛的時候會出現在測試集上的效果比訓練集上差不少的現象,即過擬合。可以採用如下兩種剪枝策略:
- 前置裁剪 在構建決策樹的過程時,提前停止。那麼,會將切分節點的條件設定的很苛刻,導致決策樹很短小。結果就是決策樹無法達到最優。實踐證明這種策略無法得到較好的結果。
- 後置裁剪 決策樹構建好後,然後才開始裁剪。採用兩種方法:1)用單一葉節點代替整個子樹,葉節點的分類採用子樹中最主要的分類;2)將一個子樹完全替代另外一顆子樹。當然後置裁剪也同樣存在問題,即計算效率,某些節點在計算後被裁剪會導致計算資源浪費,效率偏低。

至此,已經將決策樹的原理、構造方法和理解決策樹所需要的資訊理論等知識敘述完。有關基於決策樹演算法的講解和程式碼實現將在《決策樹學習(下)》中為大家呈現。

決策樹優缺點

通過上面的討論在這裡給出決策樹的優缺點。
優點:

  • 決策過程接近人的思維習慣。
  • 模型容易解釋,比線性模型具有更好的解釋性。
  • 能清楚地使用圖形化描述模型。
  • 處理定型特徵比較容易。

缺點:

  • 一般來說,決策樹學習方法的準確率不如其他的模型。針對這種情況存在一些解決方案,在後面的文章中為大家講解。
  • 不支援線上學習。當有新樣本來的時候,需要重建決策樹。
  • 容易產生過擬合現象。

決策樹實現Java版

import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

public class DecisionTree {
    public static void main(String[] args) throws Exception {
        String[] attrNames = new String[] { "AGE", "INCOME", "STUDENT",
                "CREDIT_RATING" };

        // 讀取樣本集
        Map<Object, List<Sample>> samples = readSamples(attrNames);
        // 生成決策樹
        Object decisionTree = generateDecisionTree(samples, attrNames);
        // 輸出決策樹
        outputDecisionTree(decisionTree, 0, null);
    }

    /**
     * 讀取已分類的樣本集,返回Map:分類 -> 屬於該分類的樣本的列表
     */
    static Map<Object, List<Sample>> readSamples(String[] attrNames) {
        // 樣本屬性及其所屬分類(陣列中的最後一個元素為樣本所屬分類)
        Object[][] rawData = new Object[][] {
                { "<30  ", "High  ", "No ", "Fair     ", "0" },
                { "<30  ", "High  ", "No ", "Excellent", "0" },
                { "30-40", "High  ", "No ", "Fair     ", "1" },
                { ">40  ", "Medium", "No ", "Fair     ", "1" },
                { ">40  ", "Low   ", "Yes", "Fair     ", "1" },
                { ">40  ", "Low   ", "Yes", "Excellent", "0" },
                { "30-40", "Low   ", "Yes", "Excellent", "1" },
                { "<30  ", "Medium", "No ", "Fair     ", "0" },
                { "<30  ", "Low   ", "Yes", "Fair     ", "1" },
                { ">40  ", "Medium", "Yes", "Fair     ", "1" },
                { "<30  ", "Medium", "Yes", "Excellent", "1" },
                { "30-40", "Medium", "No ", "Excellent", "1" },
                { "30-40", "High  ", "Yes", "Fair     ", "1" },
                { ">40  ", "Medium", "No ", "Excellent", "0" } };

        // 讀取樣本屬性及其所屬分類,構造表示樣本的Sample物件,並按分類劃分樣本集
        Map<Object, List<Sample>> ret = new HashMap<Object, List<Sample>>();
        for (Object[] row : rawData) {
            Sample sample = new Sample();
            int i = 0;
            for (int n = row.length - 1; i < n; i++)
                sample.setAttribute(attrNames[i], row[i]);
            sample.setCategory(row[i]);
            List<Sample> samples = ret.get(row[i]);
            if (samples == null) {
                samples = new LinkedList<Sample>();
                ret.put(row[i], samples);
            }
            samples.add(sample);
        }
        return ret;
    }

    /**
     * 構造決策樹
     */
    static Object generateDecisionTree(
            Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {
        // 如果只有一個樣本,將該樣本所屬分類作為新樣本的分類
        if (categoryToSamples.size() == 1)
            return categoryToSamples.keySet().iterator().next();

        // 如果沒有供決策的屬性,則將樣本集中具有最多樣本的分類作為新樣本的分類,即投票選舉出分類
        if (attrNames.length == 0) {
            int max = 0;
            Object maxCategory = null;
            for (Entry<Object, List<Sample>> entry : categoryToSamples
                    .entrySet()) {
                int cur = entry.getValue().size();
                if (cur > max) {
                    max = cur;
                    maxCategory = entry.getKey();
                }
            }
            return maxCategory;
        }

        // 選取測試屬性
        Object[] rst = chooseBestTestAttribute(categoryToSamples, attrNames);
        // 決策樹根結點,分支屬性為選取的測試屬性
        Tree tree = new Tree(attrNames[(Integer) rst[0]]);

        // 已用過的測試屬性不應再次被選為測試屬性
        String[] subA = new String[attrNames.length - 1];
        for (int i = 0, j = 0; i < attrNames.length; i++)
            if (i != (Integer) rst[0])
                subA[j++] = attrNames[i];

        // 根據分支屬性生成分支
        @SuppressWarnings("unchecked")
        Map<Object, Map<Object, List<Sample>>> splits =
        /* NEW LINE */(Map<Object, Map<Object, List<Sample>>>) rst[2];
        for (Entry<Object, Map<Object, List<Sample>>> entry : splits.entrySet()) {
            Object attrValue = entry.getKey();
            Map<Object, List<Sample>> split = entry.getValue();
            Object child = generateDecisionTree(split, subA);
            tree.setChild(attrValue, child);
        }
        return tree;
    }

    /**
     * 選取最優測試屬性。最優是指如果根據選取的測試屬性分支,則從各分支確定新樣本
     * 的分類需要的資訊量之和最小,這等價於確定新樣本的測試屬性獲得的資訊增益最大
     * 返回陣列:選取的屬性下標、資訊量之和、Map(屬性值->(分類->樣本列表))
     */
    static Object[] chooseBestTestAttribute(
            Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {
        int minIndex = -1; // 最優屬性下標
        double minValue = Double.MAX_VALUE; // 最小資訊量
        Map<Object, Map<Object, List<Sample>>> minSplits = null; // 最優分支方案

        // 對每一個屬性,計算將其作為測試屬性的情況下在各分支確定新樣本的分類需要的資訊量之和,選取最小為最優
        for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) {
            int allCount = 0; // 統計樣本總數的計數器
            // 按當前屬性構建Map:屬性值->(分類->樣本列表)
            Map<Object, Map<Object, List<Sample>>> curSplits =
            /* NEW LINE */new HashMap<Object, Map<Object, List<Sample>>>();
            for (Entry<Object, List<Sample>> entry : categoryToSamples
                    .entrySet()) {
                Object category = entry.getKey();
                List<Sample> samples = entry.getValue();
                for (Sample sample : samples) {
                    Object attrValue = sample
                            .getAttribute(attrNames[attrIndex]);
                    Map<Object, List<Sample>> split = curSplits.get(attrValue);
                    if (split == null) {
                        split = new HashMap<Object, List<Sample>>();
                        curSplits.put(attrValue, split);
                    }
                    List<Sample> splitSamples = split.get(category);
                    if (splitSamples == null) {
                        splitSamples = new LinkedList<Sample>();
                        split.put(category, splitSamples);
                    }
                    splitSamples.add(sample);
                }
                allCount += samples.size();
            }

            // 計算將當前屬性作為測試屬性的情況下在各分支確定新樣本的分類需要的資訊量之和
            double curValue = 0.0; // 計數器:累加各分支
            for (Map<Object, List<Sample>> splits : curSplits.values()) {
                double perSplitCount = 0;
                for (List<Sample> list : splits.values())
                    perSplitCount += list.size(); // 累計當前分支樣本數
                double perSplitValue = 0.0; // 計數器:當前分支
                for (List<Sample> list : splits.values()) {
                    double p = list.size() / perSplitCount;
                    perSplitValue -= p * (Math.log(p) / Math.log(2));
                }
                curValue += (perSplitCount / allCount) * perSplitValue;
            }
            // 選取最小為最優
            if (minValue > curValue) {
                minIndex = attrIndex;
                minValue = curValue;
                minSplits = curSplits;
            }
        }
        return new Object[] { minIndex, minValue, minSplits };
    }

    /**
     * 將決策樹輸出到標準輸出
     */
    static void outputDecisionTree(Object obj, int level, Object from) {
        for (int i = 0; i < level; i++)
            System.out.print("|-----");
        if (from != null)
            System.out.printf("(%s):", from);
        if (obj instanceof Tree) {
            Tree tree = (Tree) obj;
            String attrName = tree.getAttribute();
            System.out.printf("[%s = ?]\n", attrName);
            for (Object attrValue : tree.getAttributeValues()) {
                Object child = tree.getChild(attrValue);
                outputDecisionTree(child, level + 1, attrName + " = "
                        + attrValue);
            }
        } else {
            System.out.printf("[CATEGORY = %s]\n", obj);
        }
    }

    /**
     * 樣本,包含多個屬性和一個指明樣本所屬分類的分類值
     */
    static class Sample {

        private Map<String, Object> attributes = new HashMap<String, Object>();

        private Object category;

        public Object getAttribute(String name) {
            return attributes.get(name);
        }

        public void setAttribute(String name, Object value) {
            attributes.put(name, value);
        }

        public Object getCategory() {
            return category;
        }

        public void setCategory(Object category) {
            this.category = category;
        }

        public String toString() {
            return attributes.toString();
        }

    }

    /**
     * 決策樹(非葉結點),決策樹中的每個非葉結點都引導了一棵決策樹
     * 每個非葉結點包含一個分支屬性和多個分支,分支屬性的每個值對應一個分支,該分支引導了一棵子決策樹
     */
    static class Tree {

        private String attribute;

        private Map<Object, Object> children = new HashMap<Object, Object>();

        public Tree(String attribute) {
            this.attribute = attribute;
        }

        public String getAttribute() {
            return attribute;
        }

        public Object getChild(Object attrValue) {
            return children.get(attrValue);
        }

        public void setChild(Object attrValue, Object child) {
            children.put(attrValue, child);
        }

        public Set<Object> getAttributeValues() {
            return children.keySet();
        }
    }
}

參考文獻及推薦閱讀