決策樹系列(五)——CART
CART,又名分類迴歸樹,是在ID3的基礎上進行優化的決策樹,學習CART記住以下幾個關鍵點:
(1)CART既能是分類樹,又能是分類樹;
(2)當CART是分類樹時,採用GINI值作為節點分裂的依據;當CART是迴歸樹時,採用樣本的最小方差作為節點分裂的依據;
(3)CART是一棵二叉樹。
接下來將以一個實際的例子對CART進行介紹:
表1 原始資料表
看電視時間 |
婚姻情況 |
職業 |
年齡 |
3 |
未婚 |
學生 |
12 |
4 |
未婚 |
學生 |
18 |
2 |
已婚 |
老師 |
26 |
5 |
已婚 |
上班族 |
47 |
2.5 |
已婚 |
上班族 |
36 |
3.5 |
未婚 |
老師 |
29 |
4 |
已婚 |
學生 |
21 |
從以下的思路理解CART:
分類樹?迴歸樹?
分類樹的作用是通過一個物件的特徵來預測該物件所屬的類別,而回歸樹的目的是根據一個物件的資訊預測該物件的屬性,並以數值表示。
CART既能是分類樹,又能是決策樹,如上表所示,如果我們想預測一個人是否已婚,那麼構建的CART將是分類樹;如果想預測一個人的年齡,那麼構建的將是迴歸樹。
分類樹和迴歸樹是怎麼做決策的?假設我們構建了兩棵決策樹分別預測使用者是否已婚和實際的年齡,如圖1和圖2所示:
圖1 預測婚姻情況決策樹 圖2 預測年齡的決策樹
圖1表示一棵分類樹,其葉子節點的輸出結果為一個實際的類別,在這個例子裡是婚姻的情況(已婚或者未婚),選擇葉子節點中數量佔比最大的類別作為輸出的類別;
圖2是一棵迴歸樹,預測使用者的實際年齡,是一個具體的輸出值。怎樣得到這個輸出值?一般情況下選擇使用中值、平均值或者眾數進行表示,圖2使用節點年齡資料的平均值作為輸出值。
CART如何選擇分裂的屬性?
分裂的目的是為了能夠讓資料變純,使決策樹輸出的結果更接近真實值。那麼CART是如何評價節點的純度呢?如果是分類樹,CART採用GINI值衡量節點純度;如果是迴歸樹,採用樣本方差衡量節點純度。節點越不純,節點分類或者預測的效果就越差。
GINI值的計算公式:
節點越不純,GINI值越大。以二分類為例,如果節點的所有資料只有一個類別,則 ,如果兩類數量相同,則 。
迴歸方差計算公式:
方差越大,表示該節點的資料越分散,預測的效果就越差。如果一個節點的所有資料都相同,那麼方差就為0,此時可以很肯定得認為該節點的輸出值;如果節點的資料相差很大,那麼輸出的值有很大的可能與實際值相差較大。
因此,無論是分類樹還是迴歸樹,CART都要選擇使子節點的GINI值或者回歸方差最小的屬性作為分裂的方案。即最小化(分類樹):
或者(迴歸樹):
CART如何分裂成一棵二叉樹?
節點的分裂分為兩種情況,連續型的資料和離散型的資料。
CART對連續型屬性的處理與C4.5差不多,通過最小化分裂後的GINI值或者樣本方差尋找最優分割點,將節點一分為二,在這裡不再敘述,詳細請看C4.5。
對於離散型屬性,理論上有多少個離散值就應該分裂成多少個節點。但CART是一棵二叉樹,每一次分裂只會產生兩個節點,怎麼辦呢?很簡單,只要將其中一個離散值獨立作為一個節點,其他的離散值生成另外一個節點即可。這種分裂方案有多少個離散值就有多少種劃分的方法,舉一個簡單的例子:如果某離散屬性一個有三個離散值X,Y,Z,則該屬性的分裂方法有{X}、{Y,Z},{Y}、{X,Z},{Z}、{X,Y},分別計算每種劃分方法的基尼值或者樣本方差確定最優的方法。
以屬性“職業”為例,一共有三個離散值,“學生”、“老師”、“上班族”。該屬性有三種劃分的方案,分別為{“學生”}、{“老師”、“上班族”},{“老師”}、{“學生”、“上班族”},{“上班族”}、{“學生”、“老師”},分別計算三種劃分方案的子節點GINI值或者樣本方差,選擇最優的劃分方法,如下圖所示:
第一種劃分方法:{“學生”}、{“老師”、“上班族”}
預測是否已婚(分類):
預測年齡(迴歸):
第二種劃分方法:{“老師”}、{“學生”、“上班族”}
預測是否已婚(分類):
預測年齡(迴歸):
第三種劃分方法:{“上班族”}、{“學生”、“老師”}
預測是否已婚(分類):
預測年齡(迴歸):
綜上,如果想預測是否已婚,則選擇{“上班族”}、{“學生”、“老師”}的劃分方法,如果想預測年齡,則選擇{“老師”}、{“學生”、“上班族”}的劃分方法。
如何剪枝?
CART採用CCP(代價複雜度)剪枝方法。代價複雜度選擇節點表面誤差率增益值最小的非葉子節點,刪除該非葉子節點的左右子節點,若有多個非葉子節點的表面誤差率增益值相同小,則選擇非葉子節點中子節點數最多的非葉子節點進行剪枝。
可描述如下:
令決策樹的非葉子節點為。
a)計算所有非葉子節點的表面誤差率增益值
b)選擇表面誤差率增益值最小的非葉子節點(若多個非葉子節點具有相同小的表面誤差率增益值,選擇節點數最多的非葉子節點)。
c)對進行剪枝
表面誤差率增益值的計算公式:
其中:
表示葉子節點的誤差代價, , 為節點的錯誤率, 為節點資料量的佔比;
表示子樹的誤差代價, , 為子節點i的錯誤率, 表示節點i的資料節點佔比;
表示子樹節點個數。
算例:
下圖是其中一顆子樹,設決策樹的總資料量為40。
該子樹的表面誤差率增益值可以計算如下:
求出該子樹的表面錯誤覆蓋率為 ,只要求出其他子樹的表面誤差率增益值就可以對決策樹進行剪枝。
程式實際以及原始碼
流程圖:
(1)資料處理
對原始的資料進行數字化處理,並以二維資料的形式儲存,每一行表示一條記錄,前n-1列表示屬性,最後一列表示分類的標籤。
如表1的資料可以轉化為表2:
表2 初始化後的資料
看電視時間 |
婚姻情況 |
職業 |
年齡 |
3 |
未婚 |
學生 |
12 |
4 |
未婚 |
學生 |
18 |
2 |
已婚 |
老師 |
26 |
5 |
已婚 |
上班族 |
47 |
2.5 |
已婚 |
上班族 |
36 |
3.5 |
未婚 |
老師 |
29 |
4 |
已婚 |
學生 |
21 |
其中,對於“婚姻情況”屬性,數字{1,2}分別表示{未婚,已婚 };對於“職業”屬性{1,2,3, }分別表示{學生、老師、上班族};
程式碼如下所示:
static double[][] allData; //儲存進行訓練的資料
static List<String>[] featureValues; //離散屬性對應的離散值
featureValues是連結串列陣列,陣列的長度為屬性的個數,陣列的每個元素為該屬性的離散值連結串列。
(2)兩個類:節點類和分裂資訊
a)節點類Node
該類表示一個節點,屬性包括節點選擇的分裂屬性、節點的輸出類、孩子節點、深度等。注意,與ID3中相比,新增了兩個屬性:leafWrong和leafNode_Count分別表示葉子節點的總分類誤差和葉子節點的個數,主要是為了方便剪枝。
樹的節點
class Node
{
/// <summary>
/// 每一個節點的分裂值
/// </summary>
public List<String> features { get; set; }
/// <summary>
/// 分裂屬性的型別{離散、連續}
/// </summary>
public String feature_Type { get; set; }
/// <summary>
/// 分裂屬性的下標
/// </summary>
public String SplitFeature { get; set; }
//List<int> nums = new List<int>(); //行序號
/// <summary>
/// 每一個類對應的數目
/// </summary>
public double[] ClassCount { get; set; }
//int[] isUsed = new int[0]; //屬性的使用情況 1:已用 2:未用
/// <summary>
/// 孩子節點
/// </summary>
public List<Node> childNodes { get; set; }
Node Parent = null;
/// <summary>
/// 該節點佔比最大的類別
/// </summary>
public String finalResult { get; set; }
/// <summary>
/// 樹的深度
/// </summary>
public int deep { get; set; }
/// <summary>
/// 最大的類下標
/// </summary>
public int result { get; set; }
/// <summary>
/// 子節點誤差
/// </summary>
public int leafWrong { get; set; }
/// <summary>
/// 子節點數目
/// </summary>
public int leafNode_Count { get; set; }
/// <summary>
/// 資料量
/// </summary>
public int rowCount { get; set; }
public void setClassCount(double[] count)
{
this.ClassCount = count;
double max = ClassCount[0];
int result = 0;
for (int i = 1; i < ClassCount.Length; i++)
{
if (max < ClassCount[i])
{
max = ClassCount[i];
result = i;
}
}
this.result = result;
}
public double getErrorCount()
{
return rowCount - ClassCount[result];
}
}
樹的節點
b)分裂資訊類,該類儲存節點進行分裂的資訊,包括各個子節點的行座標、子節點各個類的數目、該節點分裂的屬性、屬性的型別等。
分裂資訊
class SplitInfo
{
/// <summary>
/// 分裂的屬性下標
/// </summary>
public int splitIndex { get; set; }
/// <summary>
/// 資料型別
/// </summary>
public int type { get; set; }
/// <summary>
/// 分裂屬性的取值
/// </summary>
public List<String> features { get; set; }
/// <summary>
/// 各個節點的行座標連結串列
/// </summary>
public List<int>[] temp { get; set; }
/// <summary>
/// 每個節點各類的數目
/// </summary>
public double[][] class_Count { get; set; }
}
分裂資訊
主方法findBestSplit(Node node,List<int> nums,int[] isUsed),該方法對節點進行分裂
其中:
node表示即將進行分裂的節點;
nums表示節點資料對一個的行座標列表;
isUsed表示到該節點位置所有屬性的使用情況;
findBestSplit的這個方法主要有以下幾個組成部分:
1)節點分裂停止的判定
節點分裂條件如上文所述,原始碼如下:
停止分裂的條件
public static bool ifEnd(Node node, double shang,int[] isUsed)
{
try
{
double[] count = node.ClassCount;
int rowCount = node.rowCount;
int maxResult = 0;
double maxRate = 0;
#region 數達到某一深度
int deep = node.deep;
if (deep >= 10)
{
maxResult = node.result + 1;
node.feature_Type="result";
node.features=new List<String>() { maxResult + ""
};
node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);
node.leafNode_Count=1;
return true;
}
#endregion
#region 純度(其實跟後面的有點重了,記得要修改)
//maxResult = 1;
//for (int i = 1; i < count.Length; i++)
//{
// if (count[i] / rowCount >= 0.95)
// {
// node.feature_Type="result";
// node.features=new List<String> { "" + (i +
1) };
// node.leafNode_Count=1;
// node.leafWrong=rowCount - Convert.ToInt32
(count[i]);
// return true;
// }
//}
#endregion
#region 熵為0
if (shang == 0)
{
maxRate = count[0] / rowCount;
maxResult = 1;
for (int i = 1; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + 1;
}
}
node.feature_Type="result";
node.features=new List<String> { maxResult + ""
};
node.leafWrong=rowCount - Convert.ToInt32(count
[maxResult - 1]);
node.leafNode_Count=1;
return true;
}
#endregion
#region 屬性已經分完
//int[] isUsed = node.getUsed();
bool flag = true;
for (int i = 0; i < isUsed.Length - 1; i++)
{
if (isUsed[i] == 0)
{
flag = false;
break;
}
}
if (flag)
{
maxRate = count[0] / rowCount;
maxResult = 1;
for (int i = 1; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + 1;
}
}
node.feature_Type=("result");
node.features=(new List<String> { "" +
(maxResult) });
node.leafWrong=(rowCount - Convert.ToInt32(count
[maxResult - 1]));
node.leafNode_Count=(1);
return true;
}
#endregion
#region 幾點數少於100
if (rowCount < Limit_Node)
{
maxRate = count[0] / rowCount;
maxResult = 1;
for (int i = 1; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + 1;
}
}
node.feature_Type="result";
node.features=new List<String> { "" + (maxResult)
};
node.leafWrong=rowCount - Convert.ToInt32(count
[maxResult - 1]);
node.leafNode_Count=1;
return true;
}
#endregion
return false;
}
catch (Exception e)
{
return false;
}
}
停止分裂的條件
2)尋找最優的分裂屬性
尋找最優的分裂屬性需要計算每一個分裂屬性分裂後的GINI值或者樣本方差,計算公式上文已給出,其中GINI值的計算程式碼如下:
GINI值計算
public static double getGini(double[] counts, int countAll)
{
double Gini = 1;
for (int i = 0; i < counts.Length; i++)
{
Gini = Gini - Math.Pow(counts[i] / countAll, 2);
}
return Gini;
}
GINI值計算
3)進行分裂,同時對子節點進行迭代處理
其實就是遞迴的過程,對每一個子節點執行findBestSplit方法進行分裂。
findBestSplit原始碼:
節點選擇屬性和分裂
public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)
{
try
{
//判斷是否繼續分裂
double totalShang = getGini(node.ClassCount, node.rowCount);
if (ifEnd(node, totalShang, isUsed))
{
return node;
}
#region 變數宣告
SplitInfo info = new SplitInfo();
info.initial();
int RowCount = nums.Count; //樣本總數
double jubuMax = 1; //區域性最大熵
int splitPoint = 0; //分裂的點
double splitValue = 0; //分裂的值
#endregion
for (int i = 0; i < isUsed.Length - 1; i++)
{
if (isUsed[i] == 1)
{
continue;
}
#region 離散變數
if (type[i] == 0)
{
double[][] allCount = new double[allNum[i]][];
for (int j = 0; j < allCount.Length; j++)
{
allCount[j] = new double[classCount];
}
int[] countAllFeature = new int[allNum[i]];
List<int>[] temp = new List<int>[allNum[i]];
double[] allClassCount = node.ClassCount; //所有類別的數量
for (int j = 0; j < temp.Length; j++)
{
temp[j] = new List<int>();
}
for (int j = 0; j < nums.Count; j++)
{
int index = Convert.ToInt32(allData[nums[j]][i]);
temp[index - 1].Add(nums[j]);
countAllFeature[index - 1]++;
allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
}
double allShang = 1;
int choose = 0;
double[][] jubuCount = new double[2][];
for (int k = 0; k < allCount.Length; k++)
{
if (temp[k].Count == 0)
continue;
double JubuShang = 0;
double[][] tempCount = new double[2][];
tempCount[0] = allCount[k];
tempCount[1] = new double[allCount[0].Length];
for (int j = 0; j < tempCount[1].Length; j++)
{
tempCount[1][j] = allClassCount[j] - allCount[k][j];
}
JubuShang = JubuShang + getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;
int nodecount = RowCount - countAllFeature[k];
JubuShang = JubuShang + getGini(tempCount[1], nodecount) * nodecount / RowCount;
if (JubuShang < allShang)
{
allShang = JubuShang;
jubuCount = tempCount;
choose = k;
}
}
if (allShang < jubuMax)
{
info.type = 0;
jubuMax = allShang;
info.class_Count = jubuCount;
info.temp[0] = temp[choose];
info.temp[1] = new List<int>();
info.features = new List<string>();
info.features.Add((choose + 1) + "");
info.features.Add("");
for (int j = 0; j < temp.Length; j++)
{
if (j == choose)
continue;
for (int k = 0; k < temp[j].Count; k++)
{
info.temp[1].Add(temp[j][k]);
}
if (temp[j].Count != 0)
{
info.features[1] = info.features[1] + (j + 1) + ",";
}
}
info.splitIndex = i;
}
}
#endregion
#region 連續變數
else
{
double[] leftCunt = new double[classCount];
//做節點各個類別的數量
double[] rightCount = new double[classCount];
//右節點各個類別的數量
double[] count1 = new double[classCount];
//子集1的統計量
double[] count2 = new double
[node.ClassCount.Length]; //子集2的統計量
for (int j = 0; j < node.ClassCount.Length;
j++)
{
count2[j] = node.ClassCount[j];
}
int all1 = 0;
//子集1的樣本量
int all2 = nums.Count;
//子集2的樣本量
double lastValue = 0;
//上一個記錄的類別
double currentValue = 0;
//當前類別
double lastPoint = 0;
//上一個點的值
double currentPoint = 0;
//當前點的值
double[] values = new double[nums.Count];
for (int j = 0; j < values.Length; j++)
{
values[j] = allData[nums[j]][i];
}
QSort(values, nums, 0, nums.Count - 1);
double lianxuMax = 1;
//連續型屬性的最大熵
#region 尋找最佳的分割點
for (int j = 0; j < nums.Count - 1; j++)
{
currentValue = allData[nums[j]][lieshu -
1];
currentPoint = (allData[nums[j]][i]);
if (j == 0)
{
lastValue = currentValue;
lastPoint = currentPoint;
}
if (currentValue != lastValue &&
currentPoint != lastPoint)
{
double shang1 = getGini(count1,
all1);
double shang2 = getGini(count2,
all2);
double allShang = shang1 * all1 /
(all1 + all2) + shang2 * all2 / (all1 + all2);
//allShang = (totalShang - allShang);
if (lianxuMax > allShang)
{
lianxuMax = allShang;
for (int k = 0; k <
count1.Length; k++)
{
leftCunt[k] = count1[k];
rightCount[k] = count2[k];
}
splitPoint = j;
splitValue = (currentPoint +
lastPoint) / 2;
}
}
all1++;
count1[Convert.ToInt32(currentValue) -
1]++;
count2[Convert.ToInt32(currentValue) -
1]--;
all2--;
lastValue = currentValue;
lastPoint = currentPoint;
}
#endregion
#region 如果超過了區域性值,重設
if (lianxuMax < jubuMax)
{
info.type = 1;
info.splitIndex = i;
info.features=new List<string>()
{splitValue+""};
//finalPoint = splitPoint;
jubuMax = lianxuMax;
info.temp[0] = new List<int>();
info.temp[1] = new List<int>();
for (int k = 0; k < splitPoint; k++)
{
info.temp[0].Add(nums[k]);
}
for (int k = splitPoint; k < nums.Count;
k++)
{
info.temp[1].Add(nums[k]);
}
info.class_Count[0] = new double
[leftCunt.Length];
info.class_Count[1] = new double
[leftCunt.Length];
for (int k = 0; k < leftCunt.Length; k++)
{
info.class_Count[0][k] = leftCunt[k];
info.class_Count[1][k] = rightCount
[k];
}
}
#endregion
}
#endregion
}
#region 沒有尋找到最佳的分裂點,則設定為葉節點
if (info.splitIndex == -1)
{
double[] finalCount = node.ClassCount;
double max = finalCount[0];
int result = 1;
for (int i = 1; i < finalCount.Length; i++)
{
if (finalCount[i] > max)
{
max = finalCount[i];
result = (i + 1);
}
}
node.feature_Type="result";
node.features=new List<String> { "" + result };
return node;
}
#endregion
#region 分裂
int deep = node.deep;
node.SplitFeature = ("" + info.splitIndex);
List<Node> childNode = new List<Node>();
int[][] used = new int[2][];
used[0] = new int[isUsed.Length];
used[1] = new int[isUsed.Length];
for (int i = 0; i < isUsed.Length; i++)
{
used[0][i] = isUsed[i];
used[1][i] = isUsed[i];
}
if (info.type == 0)
{
used[0][info.splitIndex] = 1;
node.feature_Type = ("離散");
}
else
{
//used[info.splitIndex] = 0;
node.feature_Type = ("連續");
}
List<int>[] rowIndex = info.temp;
List<String> features = info.features;
Node node1 = new Node();
Node node2 = new Node();
node1.setClassCount(info.class_Count[0]);
node2.setClassCount(info.class_Count[1]);
node1.rowCount = info.temp[0].Count;
node2.rowCount = info.temp[1].Count;
node1.deep = deep + 1;
node2.deep = deep + 1;
node1 = findBestSplit(node1, info.temp[0],used[0]);
node2 = findBestSplit(node2, info.temp[1], used[1]);
node.leafNode_Count = (node1.leafNode_Count
+node2.leafNode_Count);
node.leafWrong = (node1.leafWrong+node2.leafWrong);
node.features = (features);
childNode.Add(node1);
childNode.Add(node2);
node.childNodes = childNode;
#endregion
return node;
}
catch (Exception e)
{
Console.WriteLine(e.StackTrace);
return node;
}
}
節點選擇屬性和分裂
(4)剪枝
代價複雜度剪枝方法(CCP):
CCP代價複雜度剪枝
public static void getSeries(Node node)
{
Stack<Node> nodeStack = new Stack<Node>();
if (node != null)
{
nodeStack.Push(node);
}
if (node.feature_Type == "result")
return;
List<Node> childs = node.childNodes;
for (int i = 0; i < childs.Count; i++)
{
getSeries(node);
}
}
CCP代價複雜度剪枝
CART全部核心程式碼:
CART核心程式碼
/// <summary>
/// 判斷是否還需要分裂
/// </summary>
/// <param name="node"></param>
/// <returns></returns>
public static bool ifEnd(Node node, double shang,int[] isUsed)
{
try
{
double[] count = node.ClassCount;
int rowCount = node.rowCount;
int maxResult = 0;
double maxRate = 0;
#region 數達到某一深度
int deep = node.deep;
if (deep >= 10)
{
maxResult = node.result + 1;
node.feature_Type="result";
node.features=new List<String>() { maxResult + ""
};
node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);
node.leafNode_Count=1;
return true;
}
#endregion
#region 純度(其實跟後面的有點重了,記得要修改)
//maxResult = 1;
//for (int i = 1; i < count.Length; i++)
//{
// if (count[i] / rowCount >= 0.95)
// {
// node.feature_Type="result";
// node.features=new List<String> { "" + (i +
1) };
// node.leafNode_Count=1;
// node.leafWrong=rowCount - Convert.ToInt32
(count[i]);
// return true;
// }
//}
#endregion
#region 熵為0
if (shang == 0)
{
maxRate = count[0] / rowCount;
maxResult = 1;
for (int i = 1; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + 1;
}
}
node.feature_Type="result";
node.features=new List<String> { maxResult + ""
};
node.leafWrong=rowCount - Convert.ToInt32(count
[maxResult - 1]);
node.leafNode_Count=1;
return true;
}
#endregion
#region 屬性已經分完
//int[] isUsed = node.getUsed();
bool flag = true;
for (int i = 0; i < isUsed.Length - 1; i++)
{
if (isUsed[i] == 0)
{
flag = false;
break;
}
}
if (flag)
{
maxRate = count[0] / rowCount;
maxResult = 1;
for (int i = 1; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + 1;
}
}
node.feature_Type=("result");
node.features=(new List<String> { "" +
(maxResult) });
node.leafWrong=(rowCount - Convert.ToInt32(count
[maxResult - 1]));
node.leafNode_Count=(1);
return true;
}
#endregion
#region 幾點數少於100
if (rowCount < Limit_Node)
{
maxRate = count[0] / rowCount;
maxResult = 1;
for (int i = 1; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + 1;
}
}
node.feature_Type="result";
node.features=new List<String> { "" + (maxResult)
};
node.leafWrong=rowCount - Convert.ToInt32(count
[maxResult - 1]);
node.leafNode_Count=1;
return true;
}
#endregion
return false;
}
catch (Exception e)
{
return false;
}
}
#region 排序演算法
public static void InsertSort(double[] values, List<int> arr,
int StartIndex, int endIndex)
{
for (int i = StartIndex + 1; i <= endIndex; i++)
{
int key = arr[i];
double init = values[i];
int j = i - 1;
while (j >= StartIndex && values[j] > init)
{
arr[j + 1] = arr[j];
values[j + 1] = values[j];
j--;
}
arr[j + 1] = key;
values[j + 1] = init;
}
}
static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
{
int mid = low + ((high - low) >> 1);//計算陣列中間的元素的下標
//使用三數取中法選擇樞軸
if (values[mid] > values[high])//目標: arr[mid] <= arr[high]
{
swap(values, arr, mid, high);
}
if (values[low] > values[high])//目標: arr[low] <= arr[high]
{
swap(values, arr, low, high);
}
if (values[mid] > values[low]) //目標: arr[low] >= arr[mid]
{
swap(values, arr, mid, low);
}
//此時,arr[mid] <= arr[low] <= arr[high]
return low;
//low的位置上儲存這三個位置中間的值
//分割時可以直接使用low位置的元素作為樞軸,而不用改變分割函數了
}
static void swap(double[] values, List<int> arr, int t1, int t2)
{
double temp = values[t1];
values[t1] = values[t2];
values[t2] = temp;
int key = arr[t1];
arr[t1] = arr[t2];
arr[t2] = key;
}
static void QSort(double[] values, List<int> arr, int low, int high)
{
int first = low;
int last = high;
int left = low;
int right = high;
int leftLen = 0;
int rightLen = 0;
if (high - low + 1 < 10)
{
InsertSort(values, arr, low, high);
return;
}
//一次分割
int key = SelectPivotMedianOfThree(values, arr, low,
high);//使用三數取中法選擇樞軸
double inti = values[key];
int currentKey = arr[key];
while (low < high)
{
while (high > low && values[high] >= inti)
{
if (values[high] == inti)//處理相等元素
{
swap(values, arr, right, high);
right--;
rightLen++;
}
high--;
}
arr[low] = arr[high];
values[low] = values[high];
while (high > low && values[low] <= inti)
{
if (values[low] == inti)
{
swap(values, arr, left, low);
left++;
leftLen++;
}
low++;
}
arr[high] = arr[low];
values[high] = values[low];
}
arr[low] = currentKey;
values[low] = values[key];
//一次快排結束
//把與樞軸key相同的元素移到樞軸最終位置周圍
int i = low - 1;
int j = first;
while (j < left && values[i] != inti)
{
swap(values, arr, i, j);
i--;
j++;
}
i = low + 1;
j = last;
while (j > right && values[i] != inti)
{
swap(values, arr, i, j);
i++;
j--;
}
QSort(values, arr, first, low - 1 - leftLen);
QSort(values, arr, low + 1 + rightLen, last);
}
#endregion
/// <summary>
/// 尋找最佳的分裂點
/// </summary>
/// <param name="num"></param>
/// <param name="node"></param>
public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)
{
try
{
//判斷是否繼續分裂
double totalShang = getGini(node.ClassCount, node.rowCount);
if (ifEnd(node, totalShang, isUsed))
{
return node;
}
#region 變數宣告
SplitInfo info = new SplitInfo();
info.initial();
int RowCount = nums.Count; //樣本總數
double jubuMax = 1; //區域性最大熵
int splitPoint = 0; //分裂的點
double splitValue = 0; //分裂的值
#endregion
for (int i = 0; i < isUsed.Length - 1; i++)
{
if (isUsed[i] == 1)
{
continue;
}
#region 離散變數
if (type[i] == 0)
{
double[][] allCount = new double[allNum[i]][];
for (int j = 0; j < allCount.Length; j++)
{
allCount[j] = new double[classCount];
}
int[] countAllFeature = new int[allNum[i]];
List<int>[] temp = new List<int>[allNum[i]];
double[] allClassCount = node.ClassCount; //所有類別的數量
for (int j = 0; j < temp.Length; j++)
{
temp[j] = new List<int>();
}
for (int j = 0; j < nums.Count; j++)
{
int index = Convert.ToInt32(allData[nums[j]][i]);
temp[index - 1].Add(nums[j]);
countAllFeature[index - 1]++;
allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
}
double allShang = 1;
int choose = 0;
double[][] jubuCount = new double[2][];
for (int k = 0; k < allCount.Length; k++)
{
if (temp[k].Count == 0)
continue;
double JubuShang = 0;
double[][] tempCount = new double[2][];
tempCount[0] = allCount[k];
tempCount[1] = new double[allCount[0].Length];
for (int j = 0; j < tempCount[1].Length; j++)
{
tempCount[1][j] = allClassCount[j] - allCount[k][j];
}
JubuShang = JubuShang + getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;
int nodecount = RowCount - countAllFeature[k];
JubuShang = JubuShang + getGini(tempCount[1], nodecount) * nodecount / RowCount;
if (JubuShang < allShang)
{
allShang = JubuShang;
jubuCount = tempCount;
choose = k;
}
}
if (allShang < jubuMax)
{
info.type = 0;
jubuMax = allShang;
info.class_Count = jubuCount;
info.temp[0] = temp[choose];
info.temp[1] = new List<int>();
info.features = new List<string>();
info.features.Add((choose + 1) + "");
info.features.Add("");
for (int j = 0; j < temp.Length; j++)
{
if (j == choose)
continue;
for (int k = 0; k < temp[j].Count; k++)
{
info.temp[1].Add(temp[j][k]);
}
if (temp[j].Count != 0)
{
info.features[1] = info.features[1] + (j + 1) + ",";
}
}
info.splitIndex = i;
}
}
#endregion
#region 連續變數
else
{
double[] leftCunt = new double[classCount];
//做節點各個類別的數量
double[] rightCount = new double[classCount];
//右節點各個類別的數量
double[] count1 = new double[classCount];
//子集1的統計量
double[] count2 = new double
[node.ClassCount.Length]; //子集2的統計量
for (int j = 0; j < node.ClassCount.Length;
j++)
{
count2[j] = node.ClassCount[j];
}
int all1 = 0;
//子集1的樣本量
int all2 = nums.Count;
//子集2的樣本量
double lastValue = 0;
//上一個記錄的類別
double currentValue = 0;
//當前類別
double lastPoint = 0;
//上一個點的值
double currentPoint = 0;
//當前點的值
double[] values = new double[nums.Count];
for (int j = 0; j < values.Length; j++)
{
values[j] = allData[nums[j]][i];
}
QSort(values, nums, 0, nums.Count - 1);
double lianxuMax = 1;
//連續型屬性的最大熵
#region 尋找最佳的分割點
for (int j = 0; j < nums.Count - 1; j++)
{
currentValue = allData[nums[j]][lieshu -
1];
currentPoint = (allData[nums[j]][i]);
if (j == 0)
{
lastValue = currentValue;
lastPoint = currentPoint;
}
if (currentValue != lastValue &&
currentPoint != lastPoint)
{
double shang1 = getGini(count1,
all1);
double shang2 = getGini(count2,
all2);
double allShang = shang1 * all1 /
(all1 + all2) + shang2 * all2 / (all1 + all2);
//allShang = (totalShang - allShang);
if (lianxuMax > allShang)
{
lianxuMax = allShang;
for (int k = 0; k <
count1.Length; k++)
{
leftCunt[k] = count1[k];
rightCount[k] = count2[k];
}
splitPoint = j;
splitValue = (currentPoint +
lastPoint) / 2;
}
}
all1++;
count1[Convert.ToInt32(currentValue) -
1]++;
count2[Convert.ToInt32(currentValue) -
1]--;
all2--;
lastValue = currentValue;
lastPoint = currentPoint;
}
#endregion
#region 如果超過了區域性值,重設
if (lianxuMax < jubuMax)
{
info.type = 1;
info.splitIndex = i;
info.features=new List<string>()
{splitValue+""};
//finalPoint = splitPoint;
jubuMax = lianxuMax;
info.temp[0] = new List<int>();
info.temp[1] = new List<int>();
for (int k = 0; k < splitPoint; k++)
{
info.temp[0].Add(nums[k]);
}
for (int k = splitPoint; k < nums.Count;
k++)
{
info.temp[1].Add(nums[k]);
}
info.class_Count[0] = new double
[leftCunt.Length];
info.class_Count[1] = new double
[leftCunt.Length];
for (int k = 0; k < leftCunt.Length; k++)
{
info.class_Count[0][k] = leftCunt[k];
info.class_Count[1][k] = rightCount
[k];
}
}
#endregion
}
#endregion
}
#region 沒有尋找到最佳的分裂點,則設定為葉節點
if (info.splitIndex == -1)
{
double[] finalCount = node.ClassCount;
double max = finalCount[0];
int result = 1;
for (int i = 1; i < finalCount.Length; i++)
{
if (finalCount[i] > max)
{
max = finalCount[i];
result = (i + 1);
}
}
node.feature_Type="result";
node.features=new List<String> { "" + result };
return node;
}
#endregion
#region 分裂
int deep = node.deep;
node.SplitFeature = ("" + info.spli