1. 程式人生 > >決策樹系列(五)——CART

決策樹系列(五)——CART

CART,又名分類迴歸樹,是在ID3的基礎上進行優化的決策樹,學習CART記住以下幾個關鍵點:

(1)CART既能是分類樹,又能是分類樹;

(2)當CART是分類樹時,採用GINI值作為節點分裂的依據;當CART是迴歸樹時,採用樣本的最小方差作為節點分裂的依據;

(3)CART是一棵二叉樹。

接下來將以一個實際的例子對CART進行介紹:

                                                                    表1 原始資料表

看電視時間

婚姻情況

職業

年齡

3

未婚

學生

12

4

未婚

學生

18

2

已婚

老師

26

5

已婚

上班族

47

2.5

已婚

上班族

36

3.5

未婚

老師

29

4

已婚

學生

21

從以下的思路理解CART

分類樹?迴歸樹?

      分類樹的作用是通過一個物件的特徵來預測該物件所屬的類別,而回歸樹的目的是根據一個物件的資訊預測該物件的屬性,並以數值表示。

      CART既能是分類樹,又能是決策樹,如上表所示,如果我們想預測一個人是否已婚,那麼構建的CART將是分類樹;如果想預測一個人的年齡,那麼構建的將是迴歸樹。

分類樹和迴歸樹是怎麼做決策的?假設我們構建了兩棵決策樹分別預測使用者是否已婚和實際的年齡,如圖1和圖2所示:

                                                               

                                      圖1 預測婚姻情況決策樹                                               圖2 預測年齡的決策樹

       圖1表示一棵分類樹,其葉子節點的輸出結果為一個實際的類別,在這個例子裡是婚姻的情況(已婚或者未婚),選擇葉子節點中數量佔比最大的類別作為輸出的類別;

       圖2是一棵迴歸樹,預測使用者的實際年齡,是一個具體的輸出值。怎樣得到這個輸出值?一般情況下選擇使用中值、平均值或者眾數進行表示,圖2使用節點年齡資料的平均值作為輸出值。

CART如何選擇分裂的屬性?

      分裂的目的是為了能夠讓資料變純,使決策樹輸出的結果更接近真實值。那麼CART是如何評價節點的純度呢?如果是分類樹,CART採用GINI值衡量節點純度;如果是迴歸樹,採用樣本方差衡量節點純度。節點越不純,節點分類或者預測的效果就越差。

GINI值的計算公式:

                                                                                    

      節點越不純,GINI值越大。以二分類為例,如果節點的所有資料只有一個類別,則 ,如果兩類數量相同,則 。

迴歸方差計算公式:

                                                                       

      方差越大,表示該節點的資料越分散,預測的效果就越差。如果一個節點的所有資料都相同,那麼方差就為0,此時可以很肯定得認為該節點的輸出值;如果節點的資料相差很大,那麼輸出的值有很大的可能與實際值相差較大。

      因此,無論是分類樹還是迴歸樹,CART都要選擇使子節點的GINI值或者回歸方差最小的屬性作為分裂的方案。即最小化(分類樹):

                         

或者(迴歸樹):

                                                                                             

CART如何分裂成一棵二叉樹?

     節點的分裂分為兩種情況,連續型的資料和離散型的資料。

     CART對連續型屬性的處理與C4.5差不多,通過最小化分裂後的GINI值或者樣本方差尋找最優分割點,將節點一分為二,在這裡不再敘述,詳細請看C4.5

     對於離散型屬性,理論上有多少個離散值就應該分裂成多少個節點。但CART是一棵二叉樹,每一次分裂只會產生兩個節點,怎麼辦呢?很簡單,只要將其中一個離散值獨立作為一個節點,其他的離散值生成另外一個節點即可。這種分裂方案有多少個離散值就有多少種劃分的方法,舉一個簡單的例子:如果某離散屬性一個有三個離散值X,Y,Z,則該屬性的分裂方法有{X}、{Y,Z},{Y}、{X,Z},{Z}、{X,Y},分別計算每種劃分方法的基尼值或者樣本方差確定最優的方法。

     以屬性“職業”為例,一共有三個離散值,“學生”、“老師”、“上班族”。該屬性有三種劃分的方案,分別為{“學生”}、{“老師”、“上班族”},{“老師”}、{“學生”、“上班族”},{“上班族”}、{“學生”、“老師”},分別計算三種劃分方案的子節點GINI值或者樣本方差,選擇最優的劃分方法,如下圖所示:

第一種劃分方法:{“學生”}、{“老師”、“上班族”}

                                   

預測是否已婚(分類):

                    

預測年齡(迴歸):

            

第二種劃分方法:{“老師”}、{“學生”、“上班族”}

                                     

預測是否已婚(分類):

                    

預測年齡(迴歸):

            

第三種劃分方法:{“上班族”}、{“學生”、“老師”}

                                      

 預測是否已婚(分類):

                    

預測年齡(迴歸):

            

綜上,如果想預測是否已婚,則選擇{“上班族”}、{“學生”、“老師”}的劃分方法,如果想預測年齡,則選擇{“老師”}、{“學生”、“上班族”}的劃分方法。

如何剪枝?

      CART採用CCP(代價複雜度)剪枝方法。代價複雜度選擇節點表面誤差率增益值最小的非葉子節點,刪除該非葉子節點的左右子節點,若有多個非葉子節點的表面誤差率增益值相同小,則選擇非葉子節點中子節點數最多的非葉子節點進行剪枝。

可描述如下:

令決策樹的非葉子節點為

a)計算所有非葉子節點的表面誤差率增益值 

b)選擇表面誤差率增益值最小的非葉子節點(若多個非葉子節點具有相同小的表面誤差率增益值,選擇節點數最多的非葉子節點)。

c)對進行剪枝

表面誤差率增益值的計算公式:

                               

其中:

表示葉子節點的誤差代價, , 為節點的錯誤率, 為節點資料量的佔比;

表示子樹的誤差代價, , 為子節點i的錯誤率, 表示節點i的資料節點佔比;

表示子樹節點個數。

算例:

下圖是其中一顆子樹,設決策樹的總資料量為40。

                                                                    

該子樹的表面誤差率增益值可以計算如下:

 

求出該子樹的表面錯誤覆蓋率為 ,只要求出其他子樹的表面誤差率增益值就可以對決策樹進行剪枝。

程式實際以及原始碼

流程圖:

                                                        

(1)資料處理

         對原始的資料進行數字化處理,並以二維資料的形式儲存,每一行表示一條記錄,前n-1列表示屬性,最後一列表示分類的標籤。

         如表1的資料可以轉化為表2:

                                                                           表2 初始化後的資料

看電視時間

婚姻情況

職業

年齡

3

未婚

學生

12

4

未婚

學生

18

2

已婚

老師

26

5

已婚

上班族

47

2.5

已婚

上班族

36

3.5

未婚

老師

29

4

已婚

學生

21

      其中,對於“婚姻情況”屬性,數字{1,2}分別表示{未婚,已婚 };對於“職業”屬性{1,2,3, }分別表示{學生、老師、上班族};

程式碼如下所示:

         static double[][] allData;                              //儲存進行訓練的資料

    static List<String>[] featureValues;                    //離散屬性對應的離散值

featureValues是連結串列陣列,陣列的長度為屬性的個數,陣列的每個元素為該屬性的離散值連結串列。

(2)兩個類:節點類和分裂資訊

a)節點類Node

      該類表示一個節點,屬性包括節點選擇的分裂屬性、節點的輸出類、孩子節點、深度等。注意,與ID3中相比,新增了兩個屬性:leafWrong和leafNode_Count分別表示葉子節點的總分類誤差和葉子節點的個數,主要是為了方便剪枝。

 樹的節點

class Node
{
    /// <summary>
    /// 每一個節點的分裂值
    /// </summary>
    public List<String> features { get; set; }
    /// <summary>
    /// 分裂屬性的型別{離散、連續}
    /// </summary>
    public String feature_Type { get; set; }
    /// <summary>
    /// 分裂屬性的下標
    /// </summary>
    public String SplitFeature { get; set; }
    //List<int> nums = new List<int>();                       //行序號
    /// <summary>
    /// 每一個類對應的數目
    /// </summary>
    public double[] ClassCount { get; set; }
    //int[] isUsed = new int[0];                              //屬性的使用情況 1:已用 2:未用
    /// <summary>
    /// 孩子節點
    /// </summary>
    public List<Node> childNodes { get; set; }
    Node Parent = null;
    /// <summary>
    /// 該節點佔比最大的類別
    /// </summary>
    public String finalResult { get; set; }
    /// <summary>
    /// 樹的深度
    /// </summary>
    public int deep { get; set; }
    /// <summary>
    /// 最大的類下標
    /// </summary>
    public int result { get; set; }
    /// <summary>
    /// 子節點誤差
    /// </summary>
    public int leafWrong { get; set; }
    /// <summary>
    /// 子節點數目
    /// </summary>
    public int leafNode_Count { get; set; }
    /// <summary>
    /// 資料量
    /// </summary>
    public int rowCount { get; set; }

    public void setClassCount(double[] count)
    {
        this.ClassCount = count;
        double max = ClassCount[0];
        int result = 0;
        for (int i = 1; i < ClassCount.Length; i++)
        {
            if (max < ClassCount[i])
            {
                max = ClassCount[i];
                result = i;
            }
        }
        this.result = result;
    }
    public double getErrorCount()
    {
        return rowCount - ClassCount[result];
    }
}

樹的節點

b)分裂資訊類,該類儲存節點進行分裂的資訊,包括各個子節點的行座標、子節點各個類的數目、該節點分裂的屬性、屬性的型別等。

 分裂資訊

class SplitInfo
    {
        /// <summary>
        /// 分裂的屬性下標
        /// </summary>
        public int splitIndex { get; set; }
        /// <summary>
        /// 資料型別
        /// </summary>
        public int type { get; set; }
        /// <summary>
        /// 分裂屬性的取值
        /// </summary>
        public List<String> features { get; set; }
        /// <summary>
        /// 各個節點的行座標連結串列
        /// </summary>
        public List<int>[] temp { get; set; }
        /// <summary>
        /// 每個節點各類的數目
        /// </summary>
        public double[][] class_Count { get; set; }
    }

分裂資訊

主方法findBestSplit(Node node,List<int> nums,int[] isUsed),該方法對節點進行分裂

其中:

node表示即將進行分裂的節點;

nums表示節點資料對一個的行座標列表;

isUsed表示到該節點位置所有屬性的使用情況;

findBestSplit的這個方法主要有以下幾個組成部分:

1)節點分裂停止的判定

節點分裂條件如上文所述,原始碼如下:

 停止分裂的條件

public static bool ifEnd(Node node, double shang,int[] isUsed)
        {
            try
            {
                double[] count = node.ClassCount;
                int rowCount = node.rowCount;
                int maxResult = 0;
                double maxRate = 0;
                #region 數達到某一深度
                int deep = node.deep;
                if (deep >= 10)
                {
                    maxResult = node.result + 1;
                    node.feature_Type="result";
                    node.features=new List<String>() { maxResult + "" 

};
                    node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);
                    node.leafNode_Count=1;
                    return true;
                }
                #endregion
                #region 純度(其實跟後面的有點重了,記得要修改)
                //maxResult = 1;
                //for (int i = 1; i < count.Length; i++)
                //{
                //    if (count[i] / rowCount >= 0.95)
                //    {
                //        node.feature_Type="result";
                //        node.features=new List<String> { "" + (i + 

1) };
                //        node.leafNode_Count=1;
                //        node.leafWrong=rowCount - Convert.ToInt32

(count[i]);
                //        return true;
                //    }
                //}
                #endregion
                #region 熵為0
                if (shang == 0)
                {
                    maxRate = count[0] / rowCount;
                    maxResult = 1;
                    for (int i = 1; i < count.Length; i++)
                    {
                        if (count[i] / rowCount >= maxRate)
                        {
                            maxRate = count[i] / rowCount;
                            maxResult = i + 1;
                        }
                    }
                    node.feature_Type="result";
                    node.features=new List<String> { maxResult + "" 

};
                    node.leafWrong=rowCount - Convert.ToInt32(count

[maxResult - 1]);
                    node.leafNode_Count=1;
                    return true;
                }
                #endregion
                #region 屬性已經分完
                //int[] isUsed = node.getUsed();
                bool flag = true;
                for (int i = 0; i < isUsed.Length - 1; i++)
                {
                    if (isUsed[i] == 0)
                    {
                        flag = false;
                        break;
                    }
                }
                if (flag)
                {
                    maxRate = count[0] / rowCount;
                    maxResult = 1;
                    for (int i = 1; i < count.Length; i++)
                    {
                        if (count[i] / rowCount >= maxRate)
                        {
                            maxRate = count[i] / rowCount;
                            maxResult = i + 1;
                        }
                    }
                    node.feature_Type=("result");
                    node.features=(new List<String> { "" + 

(maxResult) });
                    node.leafWrong=(rowCount - Convert.ToInt32(count

[maxResult - 1]));
                    node.leafNode_Count=(1);
                    return true;
                }
                #endregion
                #region 幾點數少於100
                if (rowCount < Limit_Node)
                {
                    maxRate = count[0] / rowCount;
                    maxResult = 1;
                    for (int i = 1; i < count.Length; i++)
                    {
                        if (count[i] / rowCount >= maxRate)
                        {
                            maxRate = count[i] / rowCount;
                            maxResult = i + 1;
                        }
                    }
                    node.feature_Type="result";
                    node.features=new List<String> { "" + (maxResult) 

};
                    node.leafWrong=rowCount - Convert.ToInt32(count

[maxResult - 1]);
                    node.leafNode_Count=1;
                    return true;
                }
                #endregion
                return false;
            }
            catch (Exception e)
            {
                return false;
            }
        }

停止分裂的條件

2)尋找最優的分裂屬性

尋找最優的分裂屬性需要計算每一個分裂屬性分裂後的GINI值或者樣本方差,計算公式上文已給出,其中GINI值的計算程式碼如下:

 GINI值計算

public static double getGini(double[] counts, int countAll)
        {
            double Gini = 1;
            for (int i = 0; i < counts.Length; i++)
            {
                Gini = Gini - Math.Pow(counts[i] / countAll, 2);
            }
            return Gini;
        }

GINI值計算

3)進行分裂,同時對子節點進行迭代處理

其實就是遞迴的過程,對每一個子節點執行findBestSplit方法進行分裂。

findBestSplit原始碼:

 節點選擇屬性和分裂

public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)
        {
            try
            {
                //判斷是否繼續分裂
                double totalShang = getGini(node.ClassCount, node.rowCount);
                if (ifEnd(node, totalShang, isUsed))
                {
                    return node;
                }
                #region 變數宣告
                SplitInfo info = new SplitInfo();
                info.initial();
                int RowCount = nums.Count;                  //樣本總數
                double jubuMax = 1;                         //區域性最大熵
                int splitPoint = 0;                         //分裂的點
                double splitValue = 0;                      //分裂的值
                #endregion
                for (int i = 0; i < isUsed.Length - 1; i++)
                {
                    if (isUsed[i] == 1)
                    {
                        continue;
                    }
                    #region 離散變數
                    if (type[i] == 0)
                    {
                        double[][] allCount = new double[allNum[i]][];
                        for (int j = 0; j < allCount.Length; j++)
                        {
                            allCount[j] = new double[classCount];
                        }
                        int[] countAllFeature = new int[allNum[i]];
                        List<int>[] temp = new List<int>[allNum[i]];
                        double[] allClassCount = node.ClassCount;     //所有類別的數量
                        for (int j = 0; j < temp.Length; j++)
                        {
                            temp[j] = new List<int>();
                        }
                        for (int j = 0; j < nums.Count; j++)
                        {
                            int index = Convert.ToInt32(allData[nums[j]][i]);
                            temp[index - 1].Add(nums[j]);
                            countAllFeature[index - 1]++;
                            allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
                        }
                        double allShang = 1;
                        int choose = 0;

                        double[][] jubuCount = new double[2][];
                        for (int k = 0; k < allCount.Length; k++)
                        {
                            if (temp[k].Count == 0)
                                continue;
                            double JubuShang = 0;
                            double[][] tempCount = new double[2][];
                            tempCount[0] = allCount[k];
                            tempCount[1] = new double[allCount[0].Length];
                            for (int j = 0; j < tempCount[1].Length; j++)
                            {
                                tempCount[1][j] = allClassCount[j] - allCount[k][j];
                            }
                            JubuShang = JubuShang + getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;
                            int nodecount = RowCount - countAllFeature[k];
                            JubuShang = JubuShang + getGini(tempCount[1], nodecount) * nodecount / RowCount;
                            if (JubuShang < allShang)
                            {
                                allShang = JubuShang;
                                jubuCount = tempCount;
                                choose = k;
                            }
                        }                        
                        if (allShang < jubuMax)
                        {
                            info.type = 0;
                            jubuMax = allShang;
                            info.class_Count = jubuCount;
                            info.temp[0] = temp[choose];
                            info.temp[1] = new List<int>();
                            info.features = new List<string>();
                            info.features.Add((choose + 1) + "");
                            info.features.Add("");
                            for (int j = 0; j < temp.Length; j++)
                            {
                                if (j == choose)
                                    continue;
                                for (int k = 0; k < temp[j].Count; k++)
                                {
                                    info.temp[1].Add(temp[j][k]);
                                }
                                if (temp[j].Count != 0)
                                {
                                    info.features[1] = info.features[1] + (j + 1) + ",";
                                }
                            }
                            info.splitIndex = i;
                        }
                    }
                    #endregion
                    #region 連續變數
                    else
                    {
                        double[] leftCunt = new double[classCount];   

          //做節點各個類別的數量
                        double[] rightCount = new double[classCount]; 

          //右節點各個類別的數量
                        double[] count1 = new double[classCount];     

          //子集1的統計量
                        double[] count2 = new double

[node.ClassCount.Length];   //子集2的統計量
                        for (int j = 0; j < node.ClassCount.Length; 

j++)
                        {
                            count2[j] = node.ClassCount[j];
                        }
                        int all1 = 0;                                 

          //子集1的樣本量
                        int all2 = nums.Count;                        

          //子集2的樣本量
                        double lastValue = 0;                         

         //上一個記錄的類別
                        double currentValue = 0;                      

         //當前類別
                        double lastPoint = 0;                         

          //上一個點的值
                        double currentPoint = 0;                      

          //當前點的值
                        double[] values = new double[nums.Count];
                        for (int j = 0; j < values.Length; j++)
                        {
                            values[j] = allData[nums[j]][i];
                        }
                        QSort(values, nums, 0, nums.Count - 1);
                        double lianxuMax = 1;                         

          //連續型屬性的最大熵
                        #region 尋找最佳的分割點
                        for (int j = 0; j < nums.Count - 1; j++)
                        {
                            currentValue = allData[nums[j]][lieshu - 

1];
                            currentPoint = (allData[nums[j]][i]);
                            if (j == 0)
                            {
                                lastValue = currentValue;
                                lastPoint = currentPoint;
                            }
                            if (currentValue != lastValue && 

currentPoint != lastPoint)
                            {
                                double shang1 = getGini(count1, 

all1);
                                double shang2 = getGini(count2, 

all2);
                                double allShang = shang1 * all1 / 

(all1 + all2) + shang2 * all2 / (all1 + all2);
                                //allShang = (totalShang - allShang);
                                if (lianxuMax > allShang)
                                {
                                    lianxuMax = allShang;
                                    for (int k = 0; k < 

count1.Length; k++)
                                    {
                                        leftCunt[k] = count1[k];
                                        rightCount[k] = count2[k];
                                    }
                                    splitPoint = j;
                                    splitValue = (currentPoint + 

lastPoint) / 2;
                                }
                            }
                            all1++;
                            count1[Convert.ToInt32(currentValue) - 

1]++;
                            count2[Convert.ToInt32(currentValue) - 

1]--;
                            all2--;
                            lastValue = currentValue;
                            lastPoint = currentPoint;
                        }
                        #endregion
                        #region 如果超過了區域性值,重設
                        if (lianxuMax < jubuMax)
                        {
                            info.type = 1;
                            info.splitIndex = i;
                            info.features=new List<string>()

{splitValue+""};
                            //finalPoint = splitPoint;
                            jubuMax = lianxuMax;
                            info.temp[0] = new List<int>();
                            info.temp[1] = new List<int>();
                            for (int k = 0; k < splitPoint; k++)
                            {
                                info.temp[0].Add(nums[k]);
                            }
                            for (int k = splitPoint; k < nums.Count; 

k++)
                            {
                                info.temp[1].Add(nums[k]);
                            }
                            info.class_Count[0] = new double

[leftCunt.Length];
                            info.class_Count[1] = new double

[leftCunt.Length];
                            for (int k = 0; k < leftCunt.Length; k++)
                            {
                                info.class_Count[0][k] = leftCunt[k];
                                info.class_Count[1][k] = rightCount

[k];
                            }
                        }
                        #endregion
                    }
                    #endregion
                }
                #region 沒有尋找到最佳的分裂點,則設定為葉節點
                if (info.splitIndex == -1)
                {
                    double[] finalCount = node.ClassCount;
                    double max = finalCount[0];
                    int result = 1;
                    for (int i = 1; i < finalCount.Length; i++)
                    {
                        if (finalCount[i] > max)
                        {
                            max = finalCount[i];
                            result = (i + 1);
                        }
                    }
                    node.feature_Type="result";
                    node.features=new List<String> { "" + result };
                    return node;
                }
                #endregion
                #region 分裂
                int deep = node.deep;
                node.SplitFeature = ("" + info.splitIndex);
                List<Node> childNode = new List<Node>();
                int[][] used = new int[2][];
                used[0] = new int[isUsed.Length];
                used[1] = new int[isUsed.Length];
                for (int i = 0; i < isUsed.Length; i++)
                {
                    used[0][i] = isUsed[i];
                    used[1][i] = isUsed[i];
                }
                if (info.type == 0)
                {
                    used[0][info.splitIndex] = 1;
                    node.feature_Type = ("離散");
                }
                else
                {
                    //used[info.splitIndex] = 0;
                    node.feature_Type = ("連續");
                }
                List<int>[] rowIndex = info.temp;
                List<String> features = info.features;
                Node node1 = new Node();
                Node node2 = new Node();
                node1.setClassCount(info.class_Count[0]);
                node2.setClassCount(info.class_Count[1]);
                node1.rowCount = info.temp[0].Count;
                node2.rowCount = info.temp[1].Count;
                node1.deep = deep + 1;
                node2.deep = deep + 1;
                node1 = findBestSplit(node1, info.temp[0],used[0]);
                node2 = findBestSplit(node2, info.temp[1], used[1]);
                node.leafNode_Count = (node1.leafNode_Count

+node2.leafNode_Count);
                node.leafWrong = (node1.leafWrong+node2.leafWrong);
                node.features = (features);
                childNode.Add(node1);
                childNode.Add(node2);
                node.childNodes = childNode;
                #endregion
                return node;
            }
            catch (Exception e)
            {
                Console.WriteLine(e.StackTrace);
                return node;
            }
        }

節點選擇屬性和分裂

(4)剪枝

代價複雜度剪枝方法(CCP):

 CCP代價複雜度剪枝

public static void getSeries(Node node)
        {
            Stack<Node> nodeStack = new Stack<Node>();
            if (node != null)
            {
                nodeStack.Push(node);
            }
            if (node.feature_Type == "result")
                return;
            List<Node> childs = node.childNodes;
            for (int i = 0; i < childs.Count; i++)
            {
                getSeries(node);
            }
        }

CCP代價複雜度剪枝

CART全部核心程式碼:

 CART核心程式碼

/// <summary>
        /// 判斷是否還需要分裂
        /// </summary>
        /// <param name="node"></param>
        /// <returns></returns>
        public static bool ifEnd(Node node, double shang,int[] isUsed)
        {
            try
            {
                double[] count = node.ClassCount;
                int rowCount = node.rowCount;
                int maxResult = 0;
                double maxRate = 0;
                #region 數達到某一深度
                int deep = node.deep;
                if (deep >= 10)
                {
                    maxResult = node.result + 1;
                    node.feature_Type="result";
                    node.features=new List<String>() { maxResult + "" 

};
                    node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);
                    node.leafNode_Count=1;
                    return true;
                }
                #endregion
                #region 純度(其實跟後面的有點重了,記得要修改)
                //maxResult = 1;
                //for (int i = 1; i < count.Length; i++)
                //{
                //    if (count[i] / rowCount >= 0.95)
                //    {
                //        node.feature_Type="result";
                //        node.features=new List<String> { "" + (i + 

1) };
                //        node.leafNode_Count=1;
                //        node.leafWrong=rowCount - Convert.ToInt32

(count[i]);
                //        return true;
                //    }
                //}
                #endregion
                #region 熵為0
                if (shang == 0)
                {
                    maxRate = count[0] / rowCount;
                    maxResult = 1;
                    for (int i = 1; i < count.Length; i++)
                    {
                        if (count[i] / rowCount >= maxRate)
                        {
                            maxRate = count[i] / rowCount;
                            maxResult = i + 1;
                        }
                    }
                    node.feature_Type="result";
                    node.features=new List<String> { maxResult + "" 

};
                    node.leafWrong=rowCount - Convert.ToInt32(count

[maxResult - 1]);
                    node.leafNode_Count=1;
                    return true;
                }
                #endregion
                #region 屬性已經分完
                //int[] isUsed = node.getUsed();
                bool flag = true;
                for (int i = 0; i < isUsed.Length - 1; i++)
                {
                    if (isUsed[i] == 0)
                    {
                        flag = false;
                        break;
                    }
                }
                if (flag)
                {
                    maxRate = count[0] / rowCount;
                    maxResult = 1;
                    for (int i = 1; i < count.Length; i++)
                    {
                        if (count[i] / rowCount >= maxRate)
                        {
                            maxRate = count[i] / rowCount;
                            maxResult = i + 1;
                        }
                    }
                    node.feature_Type=("result");
                    node.features=(new List<String> { "" + 

(maxResult) });
                    node.leafWrong=(rowCount - Convert.ToInt32(count

[maxResult - 1]));
                    node.leafNode_Count=(1);
                    return true;
                }
                #endregion
                #region 幾點數少於100
                if (rowCount < Limit_Node)
                {
                    maxRate = count[0] / rowCount;
                    maxResult = 1;
                    for (int i = 1; i < count.Length; i++)
                    {
                        if (count[i] / rowCount >= maxRate)
                        {
                            maxRate = count[i] / rowCount;
                            maxResult = i + 1;
                        }
                    }
                    node.feature_Type="result";
                    node.features=new List<String> { "" + (maxResult) 

};
                    node.leafWrong=rowCount - Convert.ToInt32(count

[maxResult - 1]);
                    node.leafNode_Count=1;
                    return true;
                }
                #endregion
                return false;
            }
            catch (Exception e)
            {
                return false;
            }
        }
        #region 排序演算法
        public static void InsertSort(double[] values, List<int> arr, 

int StartIndex, int endIndex)
        {
            for (int i = StartIndex + 1; i <= endIndex; i++)
            {
                int key = arr[i];
                double init = values[i];
                int j = i - 1;
                while (j >= StartIndex && values[j] > init)
                {
                    arr[j + 1] = arr[j];
                    values[j + 1] = values[j];
                    j--;
                }
                arr[j + 1] = key;
                values[j + 1] = init;
            }
        }
        static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
        {
            int mid = low + ((high - low) >> 1);//計算陣列中間的元素的下標  

            //使用三數取中法選擇樞軸  
            if (values[mid] > values[high])//目標: arr[mid] <= arr[high]  
            {
                swap(values, arr, mid, high);
            }
            if (values[low] > values[high])//目標: arr[low] <= arr[high]  
            {
                swap(values, arr, low, high);
            }
            if (values[mid] > values[low]) //目標: arr[low] >= arr[mid]  
            {
                swap(values, arr, mid, low);
            }
            //此時,arr[mid] <= arr[low] <= arr[high]  
            return low;
            //low的位置上儲存這三個位置中間的值  
            //分割時可以直接使用low位置的元素作為樞軸,而不用改變分割函數了  
        }
        static void swap(double[] values, List<int> arr, int t1, int t2)
        {
            double temp = values[t1];
            values[t1] = values[t2];
            values[t2] = temp;
            int key = arr[t1];
            arr[t1] = arr[t2];
            arr[t2] = key;
        }
        static void QSort(double[] values, List<int> arr, int low, int high)
        {
            int first = low;
            int last = high;

            int left = low;
            int right = high;

            int leftLen = 0;
            int rightLen = 0;

            if (high - low + 1 < 10)
            {
                InsertSort(values, arr, low, high);
                return;
            }

            //一次分割 
            int key = SelectPivotMedianOfThree(values, arr, low, 

high);//使用三數取中法選擇樞軸 
            double inti = values[key];
            int currentKey = arr[key];

            while (low < high)
            {
                while (high > low && values[high] >= inti)
                {
                    if (values[high] == inti)//處理相等元素  
                    {
                        swap(values, arr, right, high);
                        right--;
                        rightLen++;
                    }
                    high--;
                }
                arr[low] = arr[high];
                values[low] = values[high];
                while (high > low && values[low] <= inti)
                {
                    if (values[low] == inti)
                    {
                        swap(values, arr, left, low);
                        left++;
                        leftLen++;
                    }
                    low++;
                }
                arr[high] = arr[low];
                values[high] = values[low];
            }
            arr[low] = currentKey;
            values[low] = values[key];
            //一次快排結束  
            //把與樞軸key相同的元素移到樞軸最終位置周圍  
            int i = low - 1;
            int j = first;
            while (j < left && values[i] != inti)
            {
                swap(values, arr, i, j);
                i--;
                j++;
            }
            i = low + 1;
            j = last;
            while (j > right && values[i] != inti)
            {
                swap(values, arr, i, j);
                i++;
                j--;
            }
            QSort(values, arr, first, low - 1 - leftLen);
            QSort(values, arr, low + 1 + rightLen, last);
        }
        #endregion
        /// <summary>
        /// 尋找最佳的分裂點
        /// </summary>
        /// <param name="num"></param>
        /// <param name="node"></param>
        public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)
        {
            try
            {
                //判斷是否繼續分裂
                double totalShang = getGini(node.ClassCount, node.rowCount);
                if (ifEnd(node, totalShang, isUsed))
                {
                    return node;
                }
                #region 變數宣告
                SplitInfo info = new SplitInfo();
                info.initial();
                int RowCount = nums.Count;                  //樣本總數
                double jubuMax = 1;                         //區域性最大熵
                int splitPoint = 0;                         //分裂的點
                double splitValue = 0;                      //分裂的值
                #endregion
                for (int i = 0; i < isUsed.Length - 1; i++)
                {
                    if (isUsed[i] == 1)
                    {
                        continue;
                    }
                    #region 離散變數
                    if (type[i] == 0)
                    {
                        double[][] allCount = new double[allNum[i]][];
                        for (int j = 0; j < allCount.Length; j++)
                        {
                            allCount[j] = new double[classCount];
                        }
                        int[] countAllFeature = new int[allNum[i]];
                        List<int>[] temp = new List<int>[allNum[i]];
                        double[] allClassCount = node.ClassCount;     //所有類別的數量
                        for (int j = 0; j < temp.Length; j++)
                        {
                            temp[j] = new List<int>();
                        }
                        for (int j = 0; j < nums.Count; j++)
                        {
                            int index = Convert.ToInt32(allData[nums[j]][i]);
                            temp[index - 1].Add(nums[j]);
                            countAllFeature[index - 1]++;
                            allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
                        }
                        double allShang = 1;
                        int choose = 0;

                        double[][] jubuCount = new double[2][];
                        for (int k = 0; k < allCount.Length; k++)
                        {
                            if (temp[k].Count == 0)
                                continue;
                            double JubuShang = 0;
                            double[][] tempCount = new double[2][];
                            tempCount[0] = allCount[k];
                            tempCount[1] = new double[allCount[0].Length];
                            for (int j = 0; j < tempCount[1].Length; j++)
                            {
                                tempCount[1][j] = allClassCount[j] - allCount[k][j];
                            }
                            JubuShang = JubuShang + getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;
                            int nodecount = RowCount - countAllFeature[k];
                            JubuShang = JubuShang + getGini(tempCount[1], nodecount) * nodecount / RowCount;
                            if (JubuShang < allShang)
                            {
                                allShang = JubuShang;
                                jubuCount = tempCount;
                                choose = k;
                            }
                        }                        
                        if (allShang < jubuMax)
                        {
                            info.type = 0;
                            jubuMax = allShang;
                            info.class_Count = jubuCount;
                            info.temp[0] = temp[choose];
                            info.temp[1] = new List<int>();
                            info.features = new List<string>();
                            info.features.Add((choose + 1) + "");
                            info.features.Add("");
                            for (int j = 0; j < temp.Length; j++)
                            {
                                if (j == choose)
                                    continue;
                                for (int k = 0; k < temp[j].Count; k++)
                                {
                                    info.temp[1].Add(temp[j][k]);
                                }
                                if (temp[j].Count != 0)
                                {
                                    info.features[1] = info.features[1] + (j + 1) + ",";
                                }
                            }
                            info.splitIndex = i;
                        }
                    }
                    #endregion
                    #region 連續變數
                    else
                    {
                        double[] leftCunt = new double[classCount];   

          //做節點各個類別的數量
                        double[] rightCount = new double[classCount]; 

          //右節點各個類別的數量
                        double[] count1 = new double[classCount];     

          //子集1的統計量
                        double[] count2 = new double

[node.ClassCount.Length];   //子集2的統計量
                        for (int j = 0; j < node.ClassCount.Length; 

j++)
                        {
                            count2[j] = node.ClassCount[j];
                        }
                        int all1 = 0;                                 

          //子集1的樣本量
                        int all2 = nums.Count;                        

          //子集2的樣本量
                        double lastValue = 0;                         

         //上一個記錄的類別
                        double currentValue = 0;                      

         //當前類別
                        double lastPoint = 0;                         

          //上一個點的值
                        double currentPoint = 0;                      

          //當前點的值
                        double[] values = new double[nums.Count];
                        for (int j = 0; j < values.Length; j++)
                        {
                            values[j] = allData[nums[j]][i];
                        }
                        QSort(values, nums, 0, nums.Count - 1);
                        double lianxuMax = 1;                         

          //連續型屬性的最大熵
                        #region 尋找最佳的分割點
                        for (int j = 0; j < nums.Count - 1; j++)
                        {
                            currentValue = allData[nums[j]][lieshu - 

1];
                            currentPoint = (allData[nums[j]][i]);
                            if (j == 0)
                            {
                                lastValue = currentValue;
                                lastPoint = currentPoint;
                            }
                            if (currentValue != lastValue && 

currentPoint != lastPoint)
                            {
                                double shang1 = getGini(count1, 

all1);
                                double shang2 = getGini(count2, 

all2);
                                double allShang = shang1 * all1 / 

(all1 + all2) + shang2 * all2 / (all1 + all2);
                                //allShang = (totalShang - allShang);
                                if (lianxuMax > allShang)
                                {
                                    lianxuMax = allShang;
                                    for (int k = 0; k < 

count1.Length; k++)
                                    {
                                        leftCunt[k] = count1[k];
                                        rightCount[k] = count2[k];
                                    }
                                    splitPoint = j;
                                    splitValue = (currentPoint + 

lastPoint) / 2;
                                }
                            }
                            all1++;
                            count1[Convert.ToInt32(currentValue) - 

1]++;
                            count2[Convert.ToInt32(currentValue) - 

1]--;
                            all2--;
                            lastValue = currentValue;
                            lastPoint = currentPoint;
                        }
                        #endregion
                        #region 如果超過了區域性值,重設
                        if (lianxuMax < jubuMax)
                        {
                            info.type = 1;
                            info.splitIndex = i;
                            info.features=new List<string>()

{splitValue+""};
                            //finalPoint = splitPoint;
                            jubuMax = lianxuMax;
                            info.temp[0] = new List<int>();
                            info.temp[1] = new List<int>();
                            for (int k = 0; k < splitPoint; k++)
                            {
                                info.temp[0].Add(nums[k]);
                            }
                            for (int k = splitPoint; k < nums.Count; 

k++)
                            {
                                info.temp[1].Add(nums[k]);
                            }
                            info.class_Count[0] = new double

[leftCunt.Length];
                            info.class_Count[1] = new double

[leftCunt.Length];
                            for (int k = 0; k < leftCunt.Length; k++)
                            {
                                info.class_Count[0][k] = leftCunt[k];
                                info.class_Count[1][k] = rightCount

[k];
                            }
                        }
                        #endregion
                    }
                    #endregion
                }
                #region 沒有尋找到最佳的分裂點,則設定為葉節點
                if (info.splitIndex == -1)
                {
                    double[] finalCount = node.ClassCount;
                    double max = finalCount[0];
                    int result = 1;
                    for (int i = 1; i < finalCount.Length; i++)
                    {
                        if (finalCount[i] > max)
                        {
                            max = finalCount[i];
                            result = (i + 1);
                        }
                    }
                    node.feature_Type="result";
                    node.features=new List<String> { "" + result };
                    return node;
                }
                #endregion
                #region 分裂
                int deep = node.deep;
                node.SplitFeature = ("" + info.spli