1. 程式人生 > >資料探勘學習筆記-決策樹演算法淺析(含Java實現)

資料探勘學習筆記-決策樹演算法淺析(含Java實現)

目錄

一、通俗理解決策樹演算法原理

二、舉例說明演算法執行過程

三、Java實現

本文基於書籍《資料探勘概念與技術》,由於剛接觸Data Mining,所以可能有理解不到位的情況,記錄學習筆記,提升自己對演算法的理解。

程式碼下方有,如果有金幣的童鞋可以貢獻一下給無恥的我一枚:

程式碼傳送門:http://download.csdn.net/detail/adiaixin123456/9416398

一、通俗理解決策樹演算法

決策樹演算法主要用於分類,分類顧名思義就是將不同的事物進行分類,比如對於銀行貸款的客戶來說,就可以分為是安全的客戶以及存在潛在風險的客戶,我們可以根據使用者的型別來決定是否給予貸款,以及給多少額度。

決策樹是一顆多叉樹,它是一種監督學習(監督學習就是提供的訓練資料中對每條資料都提供了類標號,例如告訴我們哪些人買了電腦,哪些人沒有買)。是通過事務的多個不同屬性來對事務進行分類的,這裡舉一個書上的例子,即電商想對使用者是否會買電腦進行預測分析,我們就要對已經買過電腦和其他沒有買電腦的人的各項屬性來進行分類,發現這個買電腦人有哪些的特性,從而對其他潛在客戶進行預測,是否滿足這些特性,去推薦使用者電腦。可以用來分類的屬性有很多,這裡只給出書上例子中的,包括年齡、收入、是否是學生、信用這四個維度,當然也可以包含其他的屬性,例如職業、居住地、性別等等。這些屬性在使用者畫像中也稱為標籤。

生成一個決策樹主要的步驟1.學習:通過決策樹生成算髮分析訓練資料,生成決策樹。2.資料校驗:通過校驗資料評估這個決策樹正確率如何,如果可以接受,就可以用於新資料的分類。

二、演算法執行過程

我們先給出一個決策樹的例子圖如下,表示了使用者是否購買電腦的分類。一目瞭然,在使用時,通過使用者的屬性的值先比較使用者的年齡然後繼續向子節點比較,最後得出結果。


2.1在描述演算法之前,我們首先要確定幾件事情:

(1).屬性比較的先後順序,即那個屬性先比較、哪些屬性後比較,這裡不同的決策樹方法如ID3使用的屬性的增益,而C4.5作為ID3的改進,選擇了屬性的增益率。大體的方向都差不多,就是選擇出哪個屬性對我們要分類的影響最大,就把它放在前面去比較,比如在買電腦上面,性別對分類的影響不大,比如各佔50%。而不同的年齡中買電腦的比例就很大,比如老年人購買的機率就很小,而中年人的購買的機率就很大,畢竟工作需要,這樣我們就說年齡比性別的增益大,當然這個是有個公式來計算的,我們後面會說。

(2).屬性的型別可以分為連續的或離散的,其中離散的表示屬性的值的取值範圍是可數的,比如標稱(類似列舉)、二元屬性(類似布林型別),連續的比如年齡、收入等,對於年齡,因為可以根據時間的單位進行無限的劃分,比如按照年、月數、天數、秒、毫秒等等,決策樹生成演算法需要離散的屬性,對於連續的屬性要進行離散化,這裡我們需要將年齡轉化為離散的,比如如上操作,將年齡分為青少年、中年和老年,這個需要由專家來人為劃分,或使用其他公式,輸入資料的預處理部分

(3).如果想要一個理想的決策樹,那麼我們就需要大量的真實資料來構建與驗證,否則也許會對決策進行誤導。

2.2演算法描述:

我們使用書上的例子:如下表為客戶的資料,並且已經對其做了類標記,即客戶有木有購買過電腦


書中決策樹的生成演算法如下,可以先跳過:


舉例說明:

1.演算法引數:

(1)帶類標號的資料集D,即上面表中的資料,14條。

(2)使用者分類的屬性集合,該例子中為年齡、收入、是否為學生、信用評級。

(3)找到最好的劃分屬性的方法,即我們之前提過的如果確定哪個屬性優先進行比較,這裡使用ID3的,屬性增益來計算。

2.演算法過程:

1.在屬性的集合中選擇一個最好的劃分屬性,作為根分裂節點,屬性A的增益計算公式為

其中

,期望資訊,又稱為熵

m為類的值種類個數,本例中,m為2(buys_computer只有買或不買兩種)

pi為類每個值出現的概率,p1=9/14,即buys_computer是yes的個數為9,總共資料集個數為14。同理p2=5/14

來表示屬性每個分割槽(分支)的純度,所有的資料都是同一類就最純,InfoA(D)=0

v為屬性值分割槽個數,已屬性age為例,v=3,因為只有youth,middle_aged,senior三種

|Dj|為每個分割槽資料的個數,D1為youth:5,D2=4,D3=5

下面式子中2/5表示在age=youth中有2個買了電腦,3/5,3個沒有買


2.選出來增益最高的屬性後就將該屬性作為節點(本例中為age),將屬性集合刪除age,並對資料集拆分後繼續向子節點計算


3.重複1、2步驟,直到

(1)資料集中所有的資料都屬於一類,設為該結點為葉子,值設定為資料中類的值,如該例子中所有age=middle_aged的都購買了電腦,值都設定為YES

(2)屬性集合為空,設該結點為葉子,值為資料中類的數量最多的分類的值,即對於buys_computer,yes的個數 > no的個數 ? yes : no

本例中最後得出的決策樹就為最開始舉例的那棵,由於資料太少以及屬性選擇的少,導致income這個屬性沒有用到,就已經分好了類:



三、演算法Java實現

3.1目錄結構(忽略的我英文名稱...)

專案的目錄結構分為四個資料夾algorithm,common,data,test
(1)algorithm為演算法,包括DecisionTree(決策樹生成演算法)、IAttrSelector(最佳分裂點屬性選擇演算法介面)、BaseAttrSelector(基礎的屬性選擇演算法實現)
(2)common為公用類,只包含了表示多叉樹的類TreeNode
(3)data為資料,包含了BaseRecord(基礎記錄,這裡只有一個屬性,就是要分類的屬性Boolean的,其他資料庫實體都應該繼承該類)
HummanAttrRecord(描述使用者的屬性類,包括收入、年齡、是否為學生、信用評級)、 
EmAgeLevel(年齡列舉類)、EmCreditRate(信用列舉類)、EmIncome(收入列舉類)。
(4)test為測試類

3.2類檔案列表

package com.adi.datamining.algorithm;

import com.adi.datamining.data.BaseRecord;
import com.adi.datamining.common.TreeNode;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
 * Created by wudi on 2016/1/22.
 */
public class DecisionTree {


    IAttrSelector selector;

    public DecisionTree(IAttrSelector selector) {
        this.selector = selector;
    }

    /**建立決策樹*/
    public TreeNode createTree(List<BaseRecord> records, Set<Field> attrSet) {
        if(null == records || records.size() < 1)
            return null;
        TreeNode node = new TreeNode();
        //1.如果所有的記錄分類屬性值都相同,如果全部相同則直接返回分類屬性值
        if(isAllInSameClass(records)){
            node.setAttrName(String.valueOf(records.get(0).getDecisionAttr()));
            return node;
        }
        //2.如果屬性列表為空,統計記錄集合中正負樣例個數,正>負?true:false
        if(null == attrSet || 0 == attrSet.size()){
            node.setAttrName(String.valueOf(getMostClass(records)));
            return node;
        }
        //3.選擇出來增益最大的屬性
        Field bestField = selector.select(records,attrSet);
        //4.根據最好屬性的值分為多個分支
        List<List<BaseRecord>> splitValues = splitRecords(records, bestField);
        List<TreeNode> children = new ArrayList<TreeNode>(splitValues.size());
        attrSet.remove(bestField);
        //5.遍歷子節點
        for (List<BaseRecord> recordList : splitValues) {
            children.add(createTree(recordList, attrSet));
        }
        node.setTreeNodeList(children);
        node.setAttrName(bestField.getName());
        return node;
    }

    /**根據屬性的值分不同列表*/
    private List<List<BaseRecord>> splitRecords(List<BaseRecord> records, Field field) {
        List<List<BaseRecord>> result = new ArrayList<List<BaseRecord>>();
        try {
            field.setAccessible(true);
        outerLoop :
            for(BaseRecord record : records) {
                Object value = field.get(record);
                for(List<BaseRecord> recordList : result) {
                    if(field.get(recordList.get(0)).equals(value)) {
                        recordList.add(record);
                        continue outerLoop;
                    }
                }
                List<BaseRecord> recordList = new ArrayList<BaseRecord>();
                recordList.add(record);
                result.add(recordList);
            }
        } catch (Exception ex) {
            System.out.println("method access exception");
        }

        return result;
    }

    /**根據列表中分類的正負樣例個數決定葉子節點為true or false*/
    private Boolean getMostClass(List<BaseRecord> records) {
        int positCount = 0;
        int negatCount = 0;
        for(BaseRecord record : records) {
            if(record.getDecisionAttr())
                ++positCount;
            else
                ++negatCount;
        }
        return positCount > negatCount ? true : false;
    }

    /**判斷所有記錄是否具有相同的分類值*/
    private boolean isAllInSameClass(List<BaseRecord> records) {
        Boolean buyComp = records.get(0).getDecisionAttr();
        for(BaseRecord record : records) {
            if(!buyComp.equals(record.getDecisionAttr()))
                return false;
        }
        return true;
    }

}
package com.adi.datamining.algorithm;

import com.adi.datamining.data.BaseRecord;

import java.lang.reflect.Field;
import java.util.List;
import java.util.Set;

/**
 * Created by wudi on 2016/1/23.
 */
public interface IAttrSelector {
    public Field select(List<BaseRecord> records, Set<Field> atrrs);
}
package com.adi.datamining.algorithm;

import com.adi.datamining.data.BaseRecord;

import java.lang.reflect.Field;
import java.util.*;

/**
 * Created by wudi10 on 2016/1/23.
 */
public class BaseAttrSelector implements IAttrSelector{
    /**通過記錄集合與記錄的屬性集合,挑選出屬性中增益度最大的屬性*/
    @Override
    public Field select(List<BaseRecord> records, Set<Field> atrrs){
        Field bestField = null;
        Double highestScore = 0D;
        Double setInfo = entropy(records);
        for(Field field : atrrs) {
            Double gainScore = setInfo - infoScore( records, field);
                if(gainScore > highestScore) {
                highestScore = gainScore;
                bestField = field;
            }
        }
        return bestField;
    }
    /**根據記錄列表求關於所求類的熵,此方法中要分的類是DcisionAtrr*/
    private Double entropy(List<BaseRecord> records) {
        Double positCount = 0D;
        Double negatCount = 0D;
        for(BaseRecord record : records) {
            if(record.getDecisionAttr())
                ++positCount;
            else
                ++negatCount;
        }
        return - positCount/records.size()* log2N(positCount / records.size())
                - negatCount/records.size()* log2N(negatCount / records.size());

    }

    /**log2(N), log 以2為底N的對數*/
    private Double log2N(Double d) {
        return Math.log(d) / Math.log(2.0);
    }

    /**求某個屬性對於分類DecisionAttr的期望分數,公式見<資料探勘概念與技術>中決策樹那節*/
    private Double infoScore(List<BaseRecord> records, Field field) {
        Double infoScore = 0D;
        try {
            //1.求該屬性每個值對於分類的正負樣例個數,即有多少是true,多少個false;
            Map<Object,List<Integer>> count4Values = new HashMap<Object,List<Integer>>();//key:存放該屬性不同值,value:長度為2,存放該屬性值對分類正負樣例數
            Integer size = records.size();
            field.setAccessible(true);
            for(BaseRecord record : records) {
                Object attrValue = field.get(record);
                List<Integer> countList = count4Values.get(attrValue);
                if(countList == null) {
                    countList = new ArrayList<Integer>(2);
                    countList.add(0,0);
                    countList.add(1,0);
                }
                if(record.getDecisionAttr()){
                    countList.set(0,countList.get(0) + 1);
                } else {
                    countList.set(1,countList.get(1) + 1);
                }
                count4Values.put(attrValue, countList);
            }

            //2.遍歷map算出期望值
            for(Object key : count4Values.keySet()) {
                List<Integer> countList = count4Values.get(key);
                double positCount = countList.get(0);
                double negatCount = countList.get(1);
                if(positCount == 0 || negatCount == 0) //對於正負樣例個數為0的情況,視為無效,對分類影響最大,分數為0;
                    continue;
                double valueCount = positCount + negatCount;
                infoScore += valueCount/size * ( - (positCount/valueCount) * log2N(positCount / valueCount)
                        - (negatCount/valueCount) * log2N(negatCount/valueCount));
            }

        } catch (Exception ex) {
            System.out.println("method access exception");
        }
        return infoScore;

    }

}
package com.adi.datamining.common;

import java.util.List;

/**
 * Created by wudi on 2016/1/22.
 * 多叉樹
 */
public class TreeNode {
    private String attrName;
    private List<TreeNode> treeNodeList;

    public TreeNode(){}

    public TreeNode(String attrName, List<TreeNode> treeNodeList) {
        this.attrName = attrName;
        this.treeNodeList = treeNodeList;
    }

    public String getAttrName() {
        return attrName;
    }

    public void setAttrName(String attrName) {
        this.attrName = attrName;
    }

    public List<TreeNode> getTreeNodeList() {
        return treeNodeList;
    }

    public void setTreeNodeList(List<TreeNode> treeNodeList) {
        this.treeNodeList = treeNodeList;
    }

    public void print(int level) {
        if(null == this)
            return;
        for (int i=0; i<level;++i)
            System.out.print("-");
        System.out.println(this.attrName);
        ++level;
        if (null != this.getTreeNodeList())
            for (TreeNode node : this.getTreeNodeList()) {
                node.print(level);
            }
    }
}
package com.adi.datamining.data;

/**
 * Created by wudi on 2016/1/23.
 */
public class BaseRecord {

    private Boolean decisionAttr;

    public BaseRecord(Boolean decisionAttr) {
        this.decisionAttr = decisionAttr;
    }

    public Boolean getDecisionAttr() {
        return decisionAttr;
    }

    public void setDecisionAttr(Boolean decisionAttr) {
        this.decisionAttr = decisionAttr;
    }




}
package com.adi.datamining.data;

/**
 * Created by wudi on 2016/1/22.
 */
public class HumanAttrRecord extends BaseRecord{
    private EmAgeLevel age;
    private EmIncome income;
    private Boolean isStudent;
    private EmCreditRate creditRate;


    public HumanAttrRecord(EmAgeLevel age, EmIncome income, Boolean isStudent, EmCreditRate creditRate, Boolean decisionAttr) {
        super(decisionAttr);
        this.age = age;
        this.income = income;
        this.isStudent = isStudent;
        this.creditRate = creditRate;

    }

    public EmAgeLevel getAge() {
        return age;
    }

    public void setAge(EmAgeLevel age) {
        this.age = age;
    }

    public EmIncome getIncome() {
        return income;
    }

    public void setIncome(EmIncome income) {
        this.income = income;
    }

    public Boolean getIsStudent() {
        return isStudent;
    }

    public void setIsStudent(Boolean isStudent) {
        this.isStudent = isStudent;
    }

    public EmCreditRate getCreditRate() {
        return creditRate;
    }

    public void setCreditRate(EmCreditRate creditRate) {
        this.creditRate = creditRate;
    }

}
package com.adi.datamining.data;

/**
 * Created by wudi10 on 2016/1/22.
 */
public enum EmAgeLevel {

    SENIOR(1, "高齡人"),
    MIDDLE_AGED(2,"中齡人"),
    YOUTH(3,"年輕人");

    private final Integer level;
    private final String desc;
    private EmAgeLevel(Integer level, String desc) {this.level = level;this.desc = desc;}

    public Integer getLevel(){return this.level;}
}

package com.adi.datamining.data;

/**
 * Created by wudi10 on 2016/1/22.
 */

public enum  EmCreditRate {

    EXCELLENT(1, "優秀"),
    FAIR(2,"正常");

    private final Integer level;
    private final String desc;
    private EmCreditRate(Integer level, String desc) {this.level = level;this.desc = desc;}

    public Integer getLevel(){return this.level;}

}
package com.adi.datamining.data;

/**
 * Created by wudi10 on 2016/1/22.
 */
public enum  EmIncome {

    HIGH(1, "高收入"),
    MEDIUM(2,"中收入"),
    LOW(3,"低收入");

    private final Integer level;
    private final String desc;
    private EmIncome(Integer level, String desc) {this.level = level;this.desc = desc;}

    public Integer getLevel(){return this.level;}

}

package test;

import com.adi.datamining.algorithm.BaseAttrSelector;
import com.adi.datamining.algorithm.DecisionTree;
import com.adi.datamining.algorithm.IAttrSelector;
import com.adi.datamining.common.TreeNode;
import com.adi.datamining.data.*;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Created by wudi on 2016/1/23.
 */
public class Test {
    public static void main(String[] arr) {
        List<BaseRecord> records = new ArrayList<BaseRecord>();
        HumanAttrRecord record0 = new HumanAttrRecord(EmAgeLevel.YOUTH, EmIncome.HIGH,false, EmCreditRate.FAIR,false);
        HumanAttrRecord record1 = new HumanAttrRecord(EmAgeLevel.YOUTH, EmIncome.HIGH,false, EmCreditRate.EXCELLENT,false);
        HumanAttrRecord record2 = new HumanAttrRecord(EmAgeLevel.MIDDLE_AGED, EmIncome.HIGH,false, EmCreditRate.FAIR,true);
        HumanAttrRecord record3 = new HumanAttrRecord(EmAgeLevel.SENIOR, EmIncome.MEDIUM,false, EmCreditRate.FAIR,true);
        HumanAttrRecord record4 = new HumanAttrRecord(EmAgeLevel.SENIOR, EmIncome.LOW,true, EmCreditRate.FAIR,true);
        HumanAttrRecord record5 = new HumanAttrRecord(EmAgeLevel.SENIOR, EmIncome.LOW,true, EmCreditRate.EXCELLENT,false);
        HumanAttrRecord record6 = new HumanAttrRecord(EmAgeLevel.MIDDLE_AGED, EmIncome.LOW,true, EmCreditRate.EXCELLENT,true);
        HumanAttrRecord record7 = new HumanAttrRecord(EmAgeLevel.YOUTH, EmIncome.MEDIUM,false, EmCreditRate.FAIR,false);
        HumanAttrRecord record8 = new HumanAttrRecord(EmAgeLevel.YOUTH, EmIncome.LOW,true, EmCreditRate.FAIR,true);
        HumanAttrRecord record9 = new HumanAttrRecord(EmAgeLevel.SENIOR, EmIncome.MEDIUM,true, EmCreditRate.FAIR,true);
        HumanAttrRecord record10 = new HumanAttrRecord(EmAgeLevel.YOUTH, EmIncome.MEDIUM,true, EmCreditRate.EXCELLENT,true);
        HumanAttrRecord record11 = new HumanAttrRecord(EmAgeLevel.MIDDLE_AGED, EmIncome.MEDIUM,false, EmCreditRate.EXCELLENT,true);
        HumanAttrRecord record12 = new HumanAttrRecord(EmAgeLevel.MIDDLE_AGED, EmIncome.HIGH,true, EmCreditRate.FAIR,true);
       /* HumanAttrRecord record13 = new HumanAttrRecord(EmAgeLevel.SENIOR, EmIncome.LOW,false, EmCreditRate.EXCELLENT,false);
        HumanAttrRecord record14 = new HumanAttrRecord(EmAgeLevel.YOUTH, EmIncome.MEDIUM,false, EmCreditRate.FAIR,false);
        HumanAttrRecord record15 = new HumanAttrRecord(EmAgeLevel.MIDDLE_AGED, EmIncome.MEDIUM,true, EmCreditRate.EXCELLENT,true);
        HumanAttrRecord record16 = new HumanAttrRecord(EmAgeLevel.SENIOR, EmIncome.LOW,false, EmCreditRate.EXCELLENT,true);
        HumanAttrRecord record17 = new HumanAttrRecord(EmAgeLevel.YOUTH, EmIncome.HIGH,true, EmCreditRate.EXCELLENT,true);
        HumanAttrRecord record18 = new HumanAttrRecord(EmAgeLevel.YOUTH, EmIncome.MEDIUM,false, EmCreditRate.EXCELLENT,false);
        HumanAttrRecord record19 = new HumanAttrRecord(EmAgeLevel.SENIOR, EmIncome.LOW,false, EmCreditRate.FAIR,false);
*/
        records.add(record0);
        records.add(record1);
        records.add(record2);
        records.add(record3);
        records.add(record4);
        records.add(record5);
        records.add(record6);
        records.add(record7);
        records.add(record8);
        records.add(record9);
        records.add(record10);
        records.add(record11);
        records.add(record12);
  /*      records.add(record13);
        records.add(record14);
        records.add(record15);
        records.add(record16);
        records.add(record17);
        records.add(record18);
        records.add(record19);*/


        Set<Field> fieldSet = new HashSet<Field>();
        Field[] fields = HumanAttrRecord.class.getDeclaredFields();
        for (Field field : fields) {
            if(field.getName().equals("decisionAttr")) continue;;
            fieldSet.add(field);
        }

        IAttrSelector selector = new BaseAttrSelector();
        DecisionTree decisionTree = new DecisionTree(selector);
        TreeNode root = decisionTree.createTree(records,fieldSet);
        if(null != root) {
            root.print(0);
        }
    }
}

3.3執行Test即可得到結果如下:(請忽略我的顯示方式....)