1. 程式人生 > >寶具滑 / JS簡單實現決策樹(ID3演算法)

寶具滑 / JS簡單實現決策樹(ID3演算法)

<script> 
// 文章: https://www.jianshu.com/p/2b50a98cd75c
    /**
     * Binary decision-tree classifier that chooses each split by information
     * gain (ID3-style).
     *
     * @param {Object} [config] Training configuration. When a non-null,
     *   non-array object is given, the tree is trained immediately through
     *   `training(config)`. Any other value yields an untrained instance whose
     *   helper methods (countUniqueValues, entropy, split, ...) remain usable.
     */
    function DecisionTree(config) {
        // `typeof null` is "object", so null has to be excluded explicitly.
        if (config !== null && typeof config === "object" && !Array.isArray(config)) {
            this.training(config);
        }
    }
    DecisionTree.prototype = {
        /**
         * Split predicates, looked up by name while building the tree.
         * '==' is used for non-numeric pivots and is intentionally a loose
         * comparison; '>=' is used for numeric pivots.
         */
        _predicates: {
            '==': function (a, b) { return a == b; },
            '>=': function (a, b) { return a >= b; }
        },

        /**
         * Counts how often each value of `attr` occurs in `items`.
         *
         * @param {Iterable<Object>} items Records to scan.
         * @param {string} attr Attribute name to read from every record.
         * @returns {Object<string, number>} Map from value (coerced to a
         *   property key) to its number of occurrences.
         */
        countUniqueValues(items, attr) {
            const counter = {};
            for (const item of items) {
                const value = item[attr];
                // An own-property check keeps values such as "constructor" or
                // "toString" from being confused with inherited members.
                if (!Object.prototype.hasOwnProperty.call(counter, value)) {
                    counter[value] = 0;
                }
                counter[value] += 1;
            }
            return counter;
        },

        /**
         * Returns the key holding the largest value, e.g. {a: 9, b: 2} -> "a".
         * On ties the key enumerated first wins.
         *
         * @param {Object<string, number>} counter
         * @returns {string|undefined} `undefined` when `counter` has no keys.
         */
        getMaxKey(counter) {
            let bestKey;
            for (const key in counter) {
                // Compare against undefined rather than truthiness so that an
                // empty-string key can be a valid current maximum.
                if (bestKey === undefined || counter[key] > counter[bestKey]) {
                    bestKey = key;
                }
            }
            return bestKey;
        },

        /**
         * Finds the most frequent value of `attr` within `items`.
         *
         * @param {Iterable<Object>} items
         * @param {string} attr
         * @returns {string|undefined} The value as a string key, or
         *   `undefined` for an empty collection.
         */
        mostFrequentValue(items, attr) {
            return this.getMaxKey(this.countUniqueValues(items, attr));
        },

        /**
         * Partitions `items` by evaluating `predicate(item[attr], pivot)`.
         *
         * @param {Iterable<Object>} items
         * @param {string} attr Attribute whose value is tested.
         * @param {function(*, *): boolean} predicate
         * @param {*} pivot Second argument handed to the predicate.
         * @returns {{match: Object[], notMatch: Object[]}} Records for which
         *   the predicate was truthy / falsy, in their original order.
         */
        split(items, attr, predicate, pivot) {
            const data = {
                match: [],
                notMatch: []
            };
            for (const item of items) {
                if (predicate(item[attr], pivot)) {
                    data.match.push(item);
                } else {
                    data.notMatch.push(item);
                }
            }
            return data;
        },

        /**
         * Shannon entropy (base 2) of the distribution of `attr` over `items`:
         * H(S) = -sum(P(x) * log2(P(x))).
         *
         * @param {Object[]} items Must expose `length`.
         * @param {string} attr
         * @returns {number} 0 for an empty or single-valued collection.
         */
        entropy(items, attr) {
            const counter = this.countUniqueValues(items, attr);
            let entropy = 0;
            for (const key in counter) {
                const p = counter[key] / items.length;
                entropy += -p * Math.log2(p);
            }
            return entropy;
        },

        /**
         * Recursively builds a (sub)tree for `config.trainingSet`.
         *
         * A leaf `{category}` is produced when the depth budget is 0, the set
         * has at most `minItemsCount` records, its entropy does not exceed
         * `entropyThrehold`, or no split yields a positive information gain.
         * Otherwise the split with the highest gain becomes an inner node.
         *
         * `config` is never modified; every recursive call receives its own
         * copy, so both branches get the same remaining depth.
         *
         * @param {Object} config
         * @param {Object[]} config.trainingSet Records to learn from.
         * @param {number} config.minItemsCount Sets of this size or smaller become leaves.
         * @param {string} config.categoryAttr Attribute holding the class label.
         * @param {number} config.entropyThrehold Entropy at or below which a set becomes a leaf.
         * @param {number} config.maxTreeDepth Remaining depth budget.
         * @param {string[]} [config.ignoredAttributes] Attributes never used for splitting.
         * @returns {Object} Either `{category}` or an inner node with
         *   `attribute`, `predicate`, `predicateName`, `pivot`, `match`,
         *   `notMatch`, `matchedCount` and `notMatchedCount`.
         */
        buildDecisionTree(config) {
            const trainingSet = config.trainingSet;
            const minItemsCount = config.minItemsCount;
            const categoryAttr = config.categoryAttr;
            const entropyThrehold = config.entropyThrehold;
            const maxTreeDepth = config.maxTreeDepth;
            const ignoredAttributes = config.ignoredAttributes || [];
            const makeLeaf = () => ({ category: this.mostFrequentValue(trainingSet, categoryAttr) });

            // Depth budget exhausted or too few records: stop growing.
            if ((maxTreeDepth == 0) || (trainingSet.length <= minItemsCount)) {
                return makeLeaf();
            }
            // H(S): entropy of the whole set; a (nearly) pure set becomes a leaf.
            const initialEntropy = this.entropy(trainingSet, categoryAttr);
            if (initialEntropy <= entropyThrehold) {
                return makeLeaf();
            }

            const checkedSplits = new Set(); // attribute/predicate/pivot combinations already evaluated
            let bestSplit = { gain: 0 };
            for (const item of trainingSet) {
                for (const attr in item) {
                    // Never split on the label itself or on ignored attributes.
                    if ((attr == categoryAttr) || (ignoredAttributes.indexOf(attr) >= 0)) continue;
                    const pivot = item[attr];
                    // Numbers are split by threshold, everything else by equality.
                    const predicateName = (typeof pivot === 'number') ? '>=' : '==';
                    const splitKey = attr + predicateName + pivot;
                    if (checkedSplits.has(splitKey)) continue;
                    checkedSplits.add(splitKey);

                    const predicate = this._predicates[predicateName];
                    const currSplit = this.split(trainingSet, attr, predicate, pivot);
                    // A split that leaves one side empty carries no information;
                    // skipping it also prevents floating-point rounding in the
                    // gain formula from reporting a spurious positive gain.
                    if (currSplit.match.length === 0 || currSplit.notMatch.length === 0) continue;

                    const matchEntropy = this.entropy(currSplit.match, categoryAttr);
                    const notMatchEntropy = this.entropy(currSplit.notMatch, categoryAttr);
                    // IG(A, S) = H(S) - sum(P(t) * H(t)) over t in {match, notMatch},
                    // where P(t) = |t| / |S|.
                    const iGain = initialEntropy - ((matchEntropy * currSplit.match.length
                        + notMatchEntropy * currSplit.notMatch.length) / trainingSet.length);
                    // Strict '>' keeps the first split found among equal gains.
                    if (iGain > bestSplit.gain) {
                        bestSplit = {
                            match: currSplit.match,
                            notMatch: currSplit.notMatch,
                            predicateName: predicateName,
                            predicate: predicate,
                            attribute: attr,
                            pivot: pivot,
                            gain: iGain
                        };
                    }
                }
            }

            // No split improves on the current set.
            if (!bestSplit.gain) {
                return makeLeaf();
            }

            // Recurse with fresh config copies so the depth consumed by one
            // branch cannot leak into its sibling or into the caller's object.
            const childDepth = maxTreeDepth - 1;
            const matchSubTree = this.buildDecisionTree(
                Object.assign({}, config, { trainingSet: bestSplit.match, maxTreeDepth: childDepth }));
            const notMatchSubTree = this.buildDecisionTree(
                Object.assign({}, config, { trainingSet: bestSplit.notMatch, maxTreeDepth: childDepth }));
            return {
                attribute: bestSplit.attribute,
                predicate: bestSplit.predicate,
                predicateName: bestSplit.predicateName,
                pivot: bestSplit.pivot,
                match: matchSubTree,
                notMatch: notMatchSubTree,
                matchedCount: bestSplit.match.length,
                notMatchedCount: bestSplit.notMatch.length
            };
        },

        /**
         * Trains the tree and stores it in `this.root`.
         *
         * @param {Object} config
         * @param {Object[]} config.trainingSet Records to learn from (required).
         * @param {string[]} [config.ignoredAttributes=[]] Attributes to skip, e.g. names or ids.
         * @param {string} [config.categoryAttr='category'] Attribute holding the class label.
         * @param {number} [config.minItemsCount=1] Sets of this size or smaller become leaves.
         * @param {number} [config.entropyThrehold=0.01] Entropy at or below which a set becomes a leaf.
         * @param {number} [config.maxTreeDepth=70] Maximum depth of the tree.
         * @throws {TypeError} When `config` or `config.trainingSet` is missing.
         */
        training(config) {
            if (config == null || config.trainingSet == null) {
                throw new TypeError('DecisionTree.training: config.trainingSet is required');
            }
            // Fall back only when the option is absent; `||` would also discard
            // an explicit 0 (e.g. maxTreeDepth: 0 or entropyThrehold: 0).
            const pick = function (value, fallback) {
                return (value === undefined || value === null) ? fallback : value;
            };
            this.root = this.buildDecisionTree({
                trainingSet: config.trainingSet,
                ignoredAttributes: config.ignoredAttributes || [],
                categoryAttr: config.categoryAttr || 'category',
                minItemsCount: pick(config.minItemsCount, 1),
                entropyThrehold: pick(config.entropyThrehold, 0.01),
                maxTreeDepth: pick(config.maxTreeDepth, 70)
            });
        },

        /**
         * Classifies a record by walking the trained tree from the root.
         *
         * @param {Object} data Record providing the attributes used by the tree.
         * @returns {string|undefined} Category stored in the leaf that was reached.
         * @throws {TypeError} When the tree has not been trained yet.
         */
        predict(data) {
            let tree = this.root;
            if (tree == null) {
                throw new TypeError('DecisionTree.predict: the tree has not been trained');
            }
            // Leaves are recognised by owning a `category` property, not by its
            // truthiness, so falsy categories such as "" are returned correctly.
            while (!Object.prototype.hasOwnProperty.call(tree, 'category')) {
                if (tree.predicate(data[tree.attribute], tree.pivot)) {
                    tree = tree.match;
                } else {
                    tree = tree.notMatch;
                }
            }
            return tree.category;
        }
    };
</script>
<script>
    // Sample records; "見面" is the class label, "姓名" is ignored during training.
    var data = [
        { "姓名": "餘夏", "年齡": 29, "長相": "帥", "體型": "瘦", "收入": "高", "見面": "見" },
        { "姓名": "豆豆", "年齡": 25, "長相": "帥", "體型": "瘦", "收入": "高", "見面": "見" },
        { "姓名": "帥常榮", "年齡": 26, "長相": "帥", "體型": "胖", "收入": "高", "見面": "見" },
        { "姓名": "王濤", "年齡": 22, "長相": "帥", "體型": "瘦", "收入": "高", "見面": "見" },
        { "姓名": "李東", "年齡": 23, "長相": "帥", "體型": "瘦", "收入": "高", "見面": "見" },
        { "姓名": "王五五", "年齡": 23, "長相": "帥", "體型": "瘦", "收入": "低", "見面": "見" },
        { "姓名": "王小濤", "年齡": 22, "長相": "帥", "體型": "瘦", "收入": "低", "見面": "見" },
        { "姓名": "李繽", "年齡": 21, "長相": "帥", "體型": "胖", "收入": "高", "見面": "見" },
        { "姓名": "劉明", "年齡": 21, "長相": "帥", "體型": "胖", "收入": "低", "見面": "不見" },
        { "姓名": "紅鶴", "年齡": 21, "長相": "不帥", "體型": "胖", "收入": "高", "見面": "不見" },
        { "姓名": "李理", "年齡": 32, "長相": "帥", "體型": "瘦", "收入": "高", "見面": "不見" },
        { "姓名": "周州", "年齡": 31, "長相": "帥", "體型": "瘦", "收入": "高", "見面": "不見" },
        { "姓名": "李樂", "年齡": 27, "長相": "不帥", "體型": "胖", "收入": "高", "見面": "不見" },
        { "姓名": "韓明", "年齡": 24, "長相": "不帥", "體型": "瘦", "收入": "高", "見面": "不見" },
        { "姓名": "小呂", "年齡": 28, "長相": "帥", "體型": "瘦", "收入": "低", "見面": "不見" },
        { "姓名": "李四", "年齡": 25, "長相": "帥", "體型": "瘦", "收入": "低", "見面": "不見" },
        { "姓名": "王鵬", "年齡": 30, "長相": "帥", "體型": "瘦", "收入": "低", "見面": "不見" },
    ];
    // Untrained instance: the helper methods can be exercised before training.
    var decisionTree = new DecisionTree();

    // Smoke-test each helper, logging one line per attribute.
    console.log("函式 countUniqueValues 測試:");
    for (const attr of ["長相", "年齡", "收入"]) {
        console.log("   " + attr, decisionTree.countUniqueValues(data, attr));
    }

    console.log("函式 entropy 測試:");
    for (const attr of ["長相", "年齡", "收入"]) {
        console.log("   " + attr, decisionTree.entropy(data, attr));
    }

    console.log("函式 mostFrequentValue 測試:");
    for (const attr of ["年齡", "長相", "收入"]) {
        console.log("   " + attr, decisionTree.mostFrequentValue(data, attr));
    }

    console.log("函式 split 測試:");
    for (const [attr, predicate, pivot] of [
        ["長相", (a, b) => a == b, "不帥"],
        ["年齡", (a, b) => a >= b, 30],
        ["年齡", (a, b) => a < b, 25],
    ]) {
        console.log("   " + attr, decisionTree.split(data, attr, predicate, pivot));
    }

    // Train on the sample records, then classify a few unseen ones.
    decisionTree.training({
        trainingSet: data,
        categoryAttr: '見面',
        ignoredAttributes: ['姓名']
    });
    for (var comic of [
        { "姓名": "劉建1", "年齡": 21, "長相": "帥", "體型": "瘦", "收入": "高" },
        { "姓名": "劉建2", "年齡": 22, "長相": "不帥", "體型": "瘦", "收入": "高" },
        { "姓名": "劉建3", "年齡": 27, "長相": "帥", "體型": "瘦", "收入": "低" },
        { "姓名": "劉建4", "年齡": 30, "長相": "帥", "體型": "瘦", "收入": "高" },
        { "姓名": "劉建5", "年齡": 29, "長相": "帥", "體型": "胖", "收入": "高" },
        { "姓名": "劉建6", "年齡": 29, "長相": "帥", "體型": "胖", "收入": "低" },
        { "姓名": "劉建7", "年齡": 40, "長相": "帥", "體型": "瘦", "收入": "低" },
    ]) {
        console.log(comic,  decisionTree.predict(comic));
    }
</script>