1. 程式人生 > >決策樹演算法原理及JAVA實現(ID3)

決策樹演算法原理及JAVA實現(ID3)

package sequence.machinelearning.decisiontree.myid3;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.LinkedList;

public class MyID3 {

    private static LinkedList<String> attribute = new LinkedList<String>(); // 儲存屬性的名稱
    private static LinkedList<ArrayList<String>> attributevalue = new LinkedList<ArrayList<String>>(); // 儲存每個屬性的取值
    private static LinkedList<String[]> data = new LinkedList<String[]>();; // 原始資料
   
    public static final String patternString = "@attribute(.*)[{](.*?)[}]";
	public static String[] yesNo;
	public static TreeNode root;
	


    /**
     * 
     * @param lines 傳入要分析的資料集
     * @param index 哪個屬性?attribute的index
     */
    public Double getGain(LinkedList<String[]> lines,int index){
    	Double gain=-1.0;
    	List<Double> li=new ArrayList<Double>();
    	//統計Yes No的次數
    	for(int i=0;i<yesNo.length;i++){
    		Double sum=0.0;
    		for(int j=0;j<lines.size();j++){
    			String[] line=lines.get(j);
    			//data為結構化資料,如果資料最後一列==yes,sum+1
    			if(line[line.length-1].equals(yesNo[i])){
    				sum=sum+1;
    			}
    		}
    		li.add(sum);
    	}
    	//計算Entropy(S)計算Entropy(S) 見參考書《機器學習 》Tom.Mitchell著  第3.4.1.2節
    	Double entropyS=TheMath.getEntropy(lines.size(), li);
    	//下面計算gain
    	
    	List<String> la=attributevalue.get(index);
    	List<Point> lasv=new ArrayList<Point>();
    	for(int n=0;n<la.size();n++){
    		String attvalue=la.get(n);
        	//統計Yes No的次數
    		List<Double> lisub=new ArrayList<Double>();//如:sunny 是yes時發生的次數,是no發生的次數
    		Double Sv=0.0;//公式3.4中的Sv 見參考書《機器學習(Tom.Mitchell著)》
        	for(int i=0;i<yesNo.length;i++){
        		Double sum=0.0;
        		for(int j=0;j<lines.size();j++){
        			String[] line=lines.get(j);
        			//data為結構化資料,如果資料最後一列==yes,sum+1
        			if(line[index].equals(attvalue)&&line[line.length-1].equals(yesNo[i])){
        				sum=sum+1;
        			}
        		}
        		Sv=Sv+sum;//計算總數
        		lisub.add(sum);
        	}
        	//計算Entropy(S) 見參考書《機器學習(Tom.Mitchell著)》
        	Double entropySv=TheMath.getEntropy(Sv.intValue(), lisub);
        	//
        	Point p=new Point();
        	p.setSv(Sv);
        	p.setEntropySv(entropySv);
        	lasv.add(p);
    	}
    	gain=TheMath.getGain(entropyS,lines.size(),lasv);
    	return gain;
    }
    //尋找最大的資訊增益,將最大的屬性定為當前節點,並返回該屬性所在list的位置和gain值
    public Maxgain getMaxGain(LinkedList<String[]> lines){
    	if(lines==null||lines.size()<=0){
    		return null;
    	}
    	Maxgain maxgain = new Maxgain();
    	Double maxvalue=0.0;
    	int maxindex=-1;
    	for(int i=0;i<attribute.size();i++){
    		Double tmp=getGain(lines,i);
    		if(maxvalue< tmp){
    			maxvalue=tmp;
    			maxindex=i;
    		}
    	}
    	maxgain.setMaxgain(maxvalue);
    	maxgain.setMaxindex(maxindex);
    	return maxgain;
    }
    //剪取陣列
    public LinkedList<String[]>  filterLines(LinkedList<String[]> lines, String attvalue, int index){
    	LinkedList<String[]> newlines=new LinkedList<String[]>();
    	for(int i=0;i<lines.size();i++){
    		String[] line=lines.get(i);
    		if(line[index].equals(attvalue)){
    			newlines.add(line);
    		}
    	}
    	
    	return newlines;
    }
    public void createDTree(){
    	root=new TreeNode();
    	Maxgain maxgain=getMaxGain(data);
    	if(maxgain==null){
    		System.out.println("沒有資料集,請檢查!");
    	}
    	int maxKey=maxgain.getMaxindex();
    	String nodename=attribute.get(maxKey);
    	root.setName(nodename);
    	root.setLiatts(attributevalue.get(maxKey));
    	insertNode(data,root,maxKey);
    }
    /**
     * 
     * @param lines 傳入的資料集,作為新的遞迴資料集
     * @param node 深入此節點
     * @param index 屬性位置
     */
    public void insertNode(LinkedList<String[]> lines,TreeNode node,int index){
    	List<String> liatts=node.getLiatts();
    	for(int i=0;i<liatts.size();i++){
    		String attname=liatts.get(i);
    		LinkedList<String[]> newlines=filterLines(lines,attname,index);
    		if(newlines.size()<=0){
    	    	System.out.println("出現異常,迴圈結束");
    	    	return;
    	    }
    		Maxgain maxgain=getMaxGain(newlines);
    		double gain=maxgain.getMaxgain();
    		Integer maxKey=maxgain.getMaxindex();
    		//不等於0繼續遞迴,等於0說明是葉子節點,結束遞迴。
    		if(gain!=0){
    			TreeNode subnode=new TreeNode();
    			subnode.setParent(node);
    			subnode.setFatherAttribute(attname);
    			String nodename=attribute.get(maxKey);
    			subnode.setName(nodename);
    			subnode.setLiatts(attributevalue.get(maxKey));
    			node.addChild(subnode);
    			//不等於0,繼續遞迴
    			insertNode(newlines,subnode,maxKey);
    		}else{
    			TreeNode subnode=new TreeNode();
    			subnode.setParent(node);
    			subnode.setFatherAttribute(attname);
    			//葉子節點是yes還是no?取新行中最後一個必是其名稱,因為只有完全是yes,或完全是no的情況下才會是葉子節點
    			String[] line=newlines.get(0);
    			String nodename=line[line.length-1];
    			subnode.setName(nodename);
    			node.addChild(subnode);
    		}
    	}
    }
	//輸出決策樹
	public void printDTree(TreeNode node)
	{
		if(node.getChildren()==null){
			System.out.println("--"+node.getName());
			return;
		}
		System.out.println(node.getName());
		List<TreeNode> childs = node.getChildren();
		for (int i = 0; i < childs.size(); i++)
		{
			System.out.println(childs.get(i).getFatherAttribute());
			printDTree(childs.get(i));
		}
	}
    public static void main(String[] args) {
		// TODO Auto-generated method stub
    	MyID3 myid3 = new MyID3();
    	myid3.readARFF(new File("datafile/decisiontree/test/in/weather.nominal.arff"));
    	myid3.createDTree();
    	myid3.printDTree(root);
	}
    //讀取arff檔案,給attribute、attributevalue、data賦值
    public void readARFF(File file) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line;
            Pattern pattern = Pattern.compile(patternString);
            while ((line = br.readLine()) != null) {
            	if (line.startsWith("@decision")) {
                   line = br.readLine();
                        if(line=="")
                            continue;
                        yesNo = line.split(",");
                }
            	Matcher matcher = pattern.matcher(line);
                if (matcher.find()) {
                    attribute.add(matcher.group(1).trim());
                    String[] values = matcher.group(2).split(",");
                    ArrayList<String> al = new ArrayList<String>(values.length);
                    for (String value : values) {
                        al.add(value.trim());
                    }
                    attributevalue.add(al);
                } else if (line.startsWith("@data")) {
                    while ((line = br.readLine()) != null) {
                        if(line=="")
                            continue;
                        String[] row = line.split(",");
                        data.add(row);
                    }
                } else {
                    continue;
                }
            }
            br.close();
        } catch (IOException e1) {
            e1.printStackTrace();
        }
    }
}