資料探勘--CART演算法的實現
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;

/**
 * A small CART (classification and regression tree) builder for categorical attributes.
 *
 * <p>Each record is one line of space-separated tokens: the last token is the class label and
 * the preceding tokens are attribute values (the number of attributes is taken from the first
 * record). At every node the tree evaluates every binary partition of every attribute's values,
 * keeps the partition with the largest Gini gain and recurses into both halves. Each leaf is
 * printed to standard output as one {@code Rule:} line.
 */
public class Cart {
    /** Input file read by {@link #main} when no path argument is given. */
    private static final String DEFAULT_INPUT = "F:/資料探勘--演算法實現/cart演算法/input.txt";

    /** Maximum depth used by {@link #main} when no depth argument is given. */
    private static final int DEFAULT_MAX_DEPTH = 2;

    /** A node becomes a leaf when no split improves the Gini index by more than this amount. */
    private static final double MIN_GINI_GAIN = 0.01;

    /** Runs of whitespace collapsed to a single space when reading input lines. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /**
     * Space-separated variable names (the input header). Token {@code k} names attribute column
     * {@code k} in printed rules; the last token names the target variable.
     */
    String Var = "";

    /**
     * Computes the Gini gain of splitting {@code Target} into the rows whose value is listed in
     * {@code Split} and all remaining rows.
     *
     * @param Target rows of the form {@code "value label"}, e.g. {@code "a1 c1"}
     * @param Split space-separated values forming the first partition, e.g. {@code "a1 a2 a3"}
     * @return Gini index of {@code Target} minus the size-weighted Gini index of both partitions
     */
    public float Gini_compute(List<String> Target, String Split) {
        Set<String> splitValues = new HashSet<String>();
        for (String value : Split.split(" ")) {
            splitValues.add(value);
        }
        List<String> Target1 = new ArrayList<String>();
        List<String> Target2 = new ArrayList<String>();
        for (String row : Target) {
            if (splitValues.contains(row.split(" ")[0])) {
                Target1.add(row);
            } else {
                Target2.add(row);
            }
        }
        int total = Target1.size() + Target2.size();
        // Evaluation order is kept fixed so equal splits produce bit-identical floats (ties matter).
        float weighted = Gini_index(Target1) * ((float) Target1.size()) / total;
        weighted += Gini_index(Target2) * ((float) Target2.size()) / total;
        return Gini_index(Target) - weighted;
    }

    /**
     * Computes the Gini index {@code 1 - sum(p_k^2)} of the labels in {@code Target}.
     *
     * @param Target rows whose second token (index 1) is the class label
     * @return the Gini index; 1 for an empty list
     */
    public float Gini_index(List<String> Target) {
        Map<String, Integer> labelCounts = new HashMap<String, Integer>();
        for (String row : Target) {
            String label = row.split(" ")[1];
            Integer seen = labelCounts.get(label);
            labelCounts.put(label, seen == null ? 1 : seen + 1);
        }
        int n = Target.size();
        float sum = 0;
        for (int count : labelCounts.values()) {
            float share = ((float) count) / n;
            float squared = share * share;
            sum = sum + squared;
        }
        return 1 - sum;
    }

    /**
     * Finds the best binary partition of column {@code i} of {@code DataSet}.
     *
     * <p>Every subset of the column's distinct values is evaluated with {@link #Gini_compute}
     * (2^k candidates for k distinct values, including the empty and the full set); on ties the
     * first candidate found is kept.
     *
     * @param DataSet records whose last token is the class label
     * @param i index of the attribute column to evaluate
     * @return a two-element list: the best subset (space-separated values) and its gain as a string
     */
    public List<String> Gini_select(List<String> DataSet, int i) {
        List<String> DataSet_i = new ArrayList<String>();
        Set<String> distinctValues = new HashSet<String>();
        for (String row : DataSet) {
            String[] tokens = row.split(" ");
            DataSet_i.add(tokens[i] + " " + tokens[tokens.length - 1]);
            distinctValues.add(tokens[i]);
        }
        StringBuilder joined = new StringBuilder();
        for (String value : distinctValues) {
            joined.append(' ').append(value);
        }
        ArrayList<String> candidates = new ArrayList<String>();
        doGetSubSequences(joined.toString().trim(), "", candidates);

        String max_set = candidates.get(0);
        float max = Gini_compute(DataSet_i, max_set);
        for (int j = 1; j < candidates.size(); j++) {
            float gain = Gini_compute(DataSet_i, candidates.get(j)); // computed once per candidate
            if (gain > max) {
                max = gain;
                max_set = candidates.get(j);
            }
        }
        List<String> return_list = new ArrayList<String>();
        return_list.add(max_set);
        return_list.add(String.valueOf(max));
        return return_list;
    }

    /**
     * Appends every subset of the space-separated tokens in {@code word} to {@code list}, each
     * prefixed by the tokens already collected in {@code s} and trimmed. Subsets without the
     * first token are emitted before subsets containing it, so the empty subset comes first.
     */
    private static void doGetSubSequences(String word, String s, ArrayList<String> list) {
        if (word.length() == 0) {
            list.add(s.trim());
            return;
        }
        String[] headAndTail = word.split(" ", 2);
        String tail = headAndTail.length >= 2 ? headAndTail[1] : "";
        doGetSubSequences(tail, s, list);
        doGetSubSequences(tail, s + " " + headAndTail[0], list);
    }

    /**
     * Recursively grows the CART tree below the node holding {@code DataSet} and prints one rule
     * per leaf.
     *
     * @param DataSet records at this node (last token is the label)
     * @param path rule text accumulated from the root, e.g. {@code "|age:youth"}
     * @param alpha depth of this node
     * @param alpha_max depth at which a node is always turned into a leaf
     */
    public void Cart_tree(List<String> DataSet, String path, int alpha, int alpha_max) {
        // Stop condition 1: depth limit reached or too few records to split.
        if (alpha == alpha_max || DataSet.size() <= 2) {
            write_result(DataSet, path);
            return;
        }
        int count_var = DataSet.get(0).split(" ").length - 1;
        String max_split_L = "";
        float max_Gini = -1;
        int max_index = -1;
        for (int i = 0; i < count_var; i++) {
            // Gini_select enumerates all value subsets, so evaluate each column only once.
            List<String> best = Gini_select(DataSet, i);
            float gain = Float.parseFloat(best.get(1));
            if (gain > max_Gini) {
                max_Gini = gain;
                max_split_L = best.get(0);
                max_index = i;
            }
        }
        // Stop condition 2: no attribute improves purity enough.
        if (max_Gini <= MIN_GINI_GAIN) {
            write_result(DataSet, path);
            return;
        }
        List<String> DataSet_L = new ArrayList<String>();
        List<String> DataSet_R = new ArrayList<String>();
        DataSet_split(DataSet, max_index, max_split_L, DataSet_L, DataSet_R);
        String max_split_R = Compute_split_R(DataSet, max_index, max_split_L);
        String attribute = this.Var.split(" ")[max_index];
        Cart_tree(DataSet_L, path + "|" + attribute + ":" + max_split_L, alpha + 1, alpha_max);
        Cart_tree(DataSet_R, path + "|" + attribute + ":" + max_split_R, alpha + 1, alpha_max);
    }

    /**
     * Prints one leaf: its rule path, record count, majority label and the share of records
     * carrying that label. For an empty {@code DataSet} the printed accuracy is NaN (0/0).
     */
    private void write_result(List<String> DataSet, String path) {
        Map<String, Integer> map = new HashMap<String, Integer>();
        for (String row : DataSet) {
            String[] tokens = row.trim().split(" ");
            String label = tokens[tokens.length - 1];
            Integer seen = map.get(label);
            map.put(label, seen == null ? 1 : seen + 1);
        }
        int sum_count = 0;
        int max_count = 0;
        String max_Category = "";
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            int labelCount = entry.getValue();
            if (labelCount >= max_count) {
                max_count = labelCount;
                max_Category = entry.getKey();
            }
            sum_count = sum_count + labelCount;
        }
        float accuracy_rate = ((float) max_count) / sum_count;
        String[] names = this.Var.split(" ");
        System.out.println("Rule:" + path + ". Count:" + DataSet.size() + ". "
                + names[names.length - 1] + ":" + max_Category + ". Accuracy_rate:" + accuracy_rate);
    }

    /**
     * Returns the values of column {@code index} that occur in {@code DataSet} but are not listed
     * in {@code split_L}, i.e. the right-hand half of the partition, space-separated.
     */
    private String Compute_split_R(List<String> DataSet, int index, String split_L) {
        Set<String> remaining = new HashSet<String>();
        for (String row : DataSet) {
            remaining.add(row.split(" ")[index]);
        }
        for (String value : split_L.trim().split(" ")) {
            remaining.remove(value);
        }
        StringBuilder split_R = new StringBuilder();
        for (String value : remaining) {
            split_R.append(' ').append(value);
        }
        return split_R.toString().trim();
    }

    /**
     * Partitions {@code DataSet}: records whose column {@code max_index} value is listed in
     * {@code max_split_L} are appended to {@code DataSet_L}, all others to {@code DataSet_R}.
     */
    private void DataSet_split(List<String> DataSet, int max_index, String max_split_L,
            List<String> DataSet_L, List<String> DataSet_R) {
        Set<String> leftValues = new HashSet<String>();
        for (String value : max_split_L.trim().split(" ")) {
            leftValues.add(value);
        }
        for (String row : DataSet) {
            if (leftValues.contains(row.split(" ")[max_index])) {
                DataSet_L.add(row);
            } else {
                DataSet_R.add(row);
            }
        }
    }

    /** Trims a raw input line and collapses internal whitespace runs to single spaces. */
    private static String normalize_line(String line) {
        return WHITESPACE.matcher(line.trim()).replaceAll(" ");
    }

    /**
     * Reads the input file and prints the tree rules.
     *
     * <p>The first non-blank line holds the variable names; every further non-blank line is one
     * record. Blank lines are skipped so they cannot become malformed records.
     *
     * @param args optional {@code [inputPath [maxDepth]]}; defaults are {@link #DEFAULT_INPUT}
     *     and {@link #DEFAULT_MAX_DEPTH}
     * @throws IOException if the input file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String inputPath = (args != null && args.length > 0) ? args[0] : DEFAULT_INPUT;
        int maxDepth = (args != null && args.length > 1) ? Integer.parseInt(args[1]) : DEFAULT_MAX_DEPTH;

        String Var = "";
        boolean headerRead = false;
        List<String> DataSet = new ArrayList<String>();
        BufferedReader br = new BufferedReader(new FileReader(inputPath));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                String record = normalize_line(line);
                if (record.isEmpty()) {
                    continue;
                }
                if (!headerRead) {
                    Var = record;
                    headerRead = true;
                } else {
                    DataSet.add(record);
                }
            }
        } finally {
            br.close(); // the reader was previously never closed
        }

        Cart a = new Cart();
        a.Var = Var;
        a.Cart_tree(DataSet, "", 0, maxDepth);
    }
}
輸入:
age income student credit_rating buys_computer
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no
資料格式說明:第一行表示變數名,其中buys_computer是目標變數(程式一律以最後一欄作為目標變數,因此目標變數必須放在最後);其餘的行表示使用者資料,每行一筆紀錄,每個資料單元以空格分開
輸出結果:
Rule:|age:middle_aged. Count:4. buys_computer:yes. Accuracy_rate:1.0
Rule:|age:senior youth|student:yes. Count:5. buys_computer:yes. Accuracy_rate:0.8
Rule:|age:senior youth|student:no. Count:5. buys_computer:no. Accuracy_rate:0.8