程式人生 > 資料探勘--Cart演算法的實現

資料探勘--Cart演算法的實現

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;




/**
 * Minimal CART-style decision tree for categorical data.
 *
 * <p>Every data row is a single string of space-separated cells; the last cell
 * is the class label (target variable) and all preceding cells are categorical
 * attributes. At each node the tree tries, for every attribute, every subset
 * of that attribute's distinct values as a binary split and keeps the one with
 * the largest Gini gain. Leaves are printed to standard output as rules.
 *
 * <p>Note: enumerating every subset is exponential in the number of distinct
 * values of an attribute, so this is only practical for low-cardinality
 * attributes.
 */
public class Cart {
	/** A split must gain strictly more than this, otherwise the node becomes a leaf. */
	private static final double MIN_GAIN=0.01;

	/** Input file used by {@link #main(String[])} when no path is given on the command line. */
	private static final String DEFAULT_INPUT="F:/資料探勘--演算法實現/cart演算法/input.txt";

	/** Space-separated column names (the header line); the last name is the target variable. */
	String Var="";

	/**
	 * Computes the Gini gain (impurity reduction) of one binary split.
	 *
	 * @param Target rows of the form {@code "attributeValue classLabel"}
	 * @param Split  space-separated attribute values that form the first partition;
	 *               every other value goes to the second partition
	 * @return Gini index of {@code Target} minus the size-weighted Gini index of
	 *         the two partitions; {@code 0} when {@code Target} is empty
	 */
	public float Gini_compute(List<String> Target,String Split){
		// An empty target would otherwise evaluate 0/0 and yield NaN.
		if(Target.isEmpty()){
			return 0;
		}
		Set<String> splitValues=new HashSet<String>();
		for(String value:Split.split(" ")){
			splitValues.add(value);
		}
		List<String> inside=new ArrayList<String>();
		List<String> outside=new ArrayList<String>();
		for(String row:Target){
			if(splitValues.contains(row.split(" ")[0])){
				inside.add(row);
			}else{
				outside.add(row);
			}
		}
		int total=inside.size()+outside.size();
		float weighted=Gini_index(inside)*((float)inside.size())/total;
		weighted+=Gini_index(outside)*((float)outside.size())/total;
		return Gini_index(Target)-weighted;
	}

	/**
	 * Computes the Gini index of a set of rows.
	 *
	 * @param Target rows of the form {@code "attributeValue classLabel"}; only the
	 *               second cell (the class label) is used
	 * @return {@code 1 - sum(p_k^2)} over the class frequencies {@code p_k};
	 *         {@code 1} for an empty list
	 */
	public float Gini_index(List<String> Target){
		Map<String,Integer> counts=new HashMap<String,Integer>();
		for(String row:Target){
			String label=row.split(" ")[1];
			Integer seen=counts.get(label);
			counts.put(label,seen==null?1:seen+1);
		}
		int total=Target.size();
		float sumOfSquares=0;
		for(Map.Entry<String,Integer> entry:counts.entrySet()){
			float fraction=((float)entry.getValue())/total;
			sumOfSquares=sumOfSquares+fraction*fraction;
		}
		return 1-sumOfSquares;
	}

	/**
	 * Finds the best binary split of column {@code i}.
	 *
	 * @param DataSet full data rows (space-separated cells, class label last)
	 * @param i       zero-based index of the attribute column to split on
	 * @return a two-element list: the space-separated value subset with the
	 *         highest Gini gain (the first one found wins ties), followed by
	 *         that gain formatted with {@link String#valueOf(float)}
	 */
	public List<String> Gini_select(List<String> DataSet,int i){
		List<String> pairs=new ArrayList<String>();
		Set<String> distinctValues=new HashSet<String>();
		for(String row:DataSet){
			String[] cells=row.split(" ");
			pairs.add(cells[i]+" "+cells[cells.length-1]);
			distinctValues.add(cells[i]);
		}
		StringBuilder joined=new StringBuilder();
		for(String value:distinctValues){
			joined.append(" ").append(value);
		}
		ArrayList<String> subsets=new ArrayList<String>();
		doGetSubSequences(joined.toString().trim(),"",subsets);

		// Evaluate each candidate exactly once; strict '>' keeps the first maximum.
		String bestSubset=subsets.get(0);
		float bestGain=Gini_compute(pairs,bestSubset);
		for(int j=1;j<subsets.size();j++){
			float gain=Gini_compute(pairs,subsets.get(j));
			if(gain>bestGain){
				bestGain=gain;
				bestSubset=subsets.get(j);
			}
		}
		List<String> result=new ArrayList<String>();
		result.add(bestSubset);
		result.add(String.valueOf(bestGain));
		return result;
	}

	/**
	 * Appends every subset of the space-separated tokens in {@code word} to
	 * {@code list}, each subset being a space-separated string. The empty subset
	 * and the full set are included.
	 *
	 * @param word remaining tokens, space-separated
	 * @param s    subset accumulated so far (may carry a leading space)
	 * @param list receives the finished subsets
	 */
	private static void doGetSubSequences(String word, String s,ArrayList<String> list) {
		if (word.length() == 0) {
			list.add(s.trim());
			return;
		}
		String[] headAndTail=word.split(" ",2);
		String tail=headAndTail.length>=2?headAndTail[1]:"";
		// First the subsets without the head token, then those with it.
		doGetSubSequences(tail, s, list);
		doGetSubSequences(tail, s + " "+headAndTail[0], list);
	}

	/**
	 * Recursively grows the tree and prints one rule per leaf.
	 *
	 * @param DataSet   rows that reached this node; an empty list prints nothing
	 * @param path      rule prefix accumulated from the root
	 * @param alpha     current depth
	 * @param alpha_max depth at which growth stops
	 */
	public void Cart_tree(List<String> DataSet,String path,int alpha,int alpha_max){
		// No rows means there is no rule to report (and no accuracy to compute).
		if(DataSet.isEmpty()){
			return;
		}
		// Stop condition 1: depth limit reached or too few rows to split.
		if(alpha==alpha_max || DataSet.size()<=2){
			write_result(DataSet,path);
			return;
		}
		int count_var=DataSet.get(0).split(" ").length-1;
		String max_split_L="";
		float max_Gini=-1;
		int max_index=-1;
		for(int i=0;i<count_var;i++){
			// Gini_select is exponential in the column's cardinality: call it once.
			List<String> candidate=Gini_select(DataSet,i);
			float gain=Float.parseFloat(candidate.get(1));
			if(gain>max_Gini){
				max_Gini=gain;
				max_split_L=candidate.get(0);
				max_index=i;
			}
		}
		// Stop condition 2: the best split does not improve purity enough.
		if(max_Gini<=MIN_GAIN){
			write_result(DataSet,path);
			return;
		}
		List<String> DataSet_L=new ArrayList<String>();
		List<String> DataSet_R=new ArrayList<String>();
		DataSet_split(DataSet,max_index,max_split_L,DataSet_L,DataSet_R);
		String max_split_R=Compute_split_R(DataSet,max_index,max_split_L);
		String columnName=this.Var.split(" ")[max_index];
		Cart_tree(DataSet_L,path+"|"+columnName+":"+max_split_L,alpha+1,alpha_max);
		Cart_tree(DataSet_R,path+"|"+columnName+":"+max_split_R,alpha+1,alpha_max);
	}

	/**
	 * Prints a leaf: the rule path, the row count, the majority class and the
	 * fraction of rows belonging to that class.
	 *
	 * @param DataSet rows that ended up in the leaf
	 * @param path    rule path leading to the leaf
	 */
	private void write_result(List<String> DataSet, String path) {
		Map<String,Integer> map=new HashMap<String,Integer>();
		for(String row:DataSet){
			String[] cells=row.trim().split(" ");
			String category=cells[cells.length-1];
			Integer seen=map.get(category);
			map.put(category,seen==null?1:seen+1);
		}
		int sum_count=0;
		int max_count=0;
		String max_Category="";
		for(Map.Entry<String,Integer> entry:map.entrySet()){
			// '>=' means the last class in iteration order wins a tie.
			if(entry.getValue()>=max_count){
				max_count=entry.getValue();
				max_Category=entry.getKey();
			}
			sum_count=sum_count+entry.getValue();
		}
		int count=DataSet.size();
		float accuracy_rate=((float)max_count)/sum_count;
		String[] names=this.Var.split(" ");
		System.out.println("Rule:"+path+".   Count:"+count+".   "+names[names.length-1]+":"+max_Category+".   Accuracy_rate:"+accuracy_rate);
	}

	/**
	 * Returns the values of column {@code index} that are not part of
	 * {@code split_L}, i.e. the other half of the binary split.
	 *
	 * @param DataSet rows being split
	 * @param index   zero-based column index
	 * @param split_L space-separated values of the left partition
	 * @return space-separated values of the right partition
	 */
	private String Compute_split_R(List<String> DataSet, int index,
			String split_L) {
		Set<String> remaining=new HashSet<String>();
		for(String row:DataSet){
			remaining.add(row.split(" ")[index]);
		}
		for(String leftValue:split_L.trim().split(" ")){
			remaining.remove(leftValue);
		}
		StringBuilder split_R=new StringBuilder();
		for(String value:remaining){
			split_R.append(" ").append(value);
		}
		return split_R.toString().trim();
	}

	/**
	 * Distributes the rows of {@code DataSet} into {@code DataSet_L} (value of
	 * column {@code max_index} is listed in {@code max_split_L}) and
	 * {@code DataSet_R} (all other rows).
	 *
	 * @param DataSet     rows to split
	 * @param max_index   zero-based column index to test
	 * @param max_split_L space-separated values that select the left partition
	 * @param DataSet_L   receives the matching rows
	 * @param DataSet_R   receives the remaining rows
	 */
	private void DataSet_split(List<String> DataSet, int max_index,
			String max_split_L, List<String> DataSet_L, List<String> DataSet_R) {
		Set<String> leftValues=new HashSet<String>();
		for(String value:max_split_L.trim().split(" ")){
			leftValues.add(value);
		}
		for(String row:DataSet){
			if(leftValues.contains(row.split(" ")[max_index])){
				DataSet_L.add(row);
			}else{
				DataSet_R.add(row);
			}
		}
	}

	/**
	 * Reads a data file and prints the rules of a depth-2 tree.
	 *
	 * <p>The first non-blank line is the header (column names); every following
	 * non-blank line is one data row. Blank lines are skipped.
	 *
	 * @param args optional: {@code args[0]} is the input file path; when absent
	 *             the built-in default path is used
	 * @throws IOException if the file cannot be opened or read
	 */
	public static void main(String[] args) throws IOException {
		String inputPath=args.length>0?args[0]:DEFAULT_INPUT;
		List<String> DataSet=new ArrayList<String>();
		String header=null;
		BufferedReader br=new BufferedReader(new FileReader(inputPath));
		try{
			String line;
			while((line=br.readLine())!=null){
				// A blank line has no cells and would break column indexing later.
				if(line.trim().length()==0){continue;}
				if(header==null){header=line;continue;}
				DataSet.add(line);
			}
		}finally{
			// Release the file handle even when reading fails.
			br.close();
		}
		Cart a=new Cart();
		a.Var=header==null?"":header;
		a.Cart_tree(DataSet,"",0,2);
	}

}

輸入:

age income student credit_rating buys_computer
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no

資料格式說明:第一行是以空格分隔的變數名,其中最後一欄(此例為buys_computer)必須是目標變數;其餘每一行表示一筆使用者資料,每個資料單元同樣以空格分開

輸出結果:

Rule:|age:middle_aged.   Count:4.   buys_computer:yes.   Accuracy_rate:1.0
Rule:|age:senior youth|student:yes.   Count:5.   buys_computer:yes.   Accuracy_rate:0.8
Rule:|age:senior youth|student:no.   Count:5.   buys_computer:no.   Accuracy_rate:0.8