決策樹演算法原理及JAVA實現(ID3)
阿新 • • 發佈:2018-12-25
package sequence.machinelearning.decisiontree.myid3; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.LinkedList; public class MyID3 { private static LinkedList<String> attribute = new LinkedList<String>(); // 儲存屬性的名稱 private static LinkedList<ArrayList<String>> attributevalue = new LinkedList<ArrayList<String>>(); // 儲存每個屬性的取值 private static LinkedList<String[]> data = new LinkedList<String[]>();; // 原始資料 public static final String patternString = "@attribute(.*)[{](.*?)[}]"; public static String[] yesNo; public static TreeNode root; /** * * @param lines 傳入要分析的資料集 * @param index 哪個屬性?attribute的index */ public Double getGain(LinkedList<String[]> lines,int index){ Double gain=-1.0; List<Double> li=new ArrayList<Double>(); //統計Yes No的次數 for(int i=0;i<yesNo.length;i++){ Double sum=0.0; for(int j=0;j<lines.size();j++){ String[] line=lines.get(j); //data為結構化資料,如果資料最後一列==yes,sum+1 if(line[line.length-1].equals(yesNo[i])){ sum=sum+1; } } li.add(sum); } //計算Entropy(S)計算Entropy(S) 見參考書《機器學習 》Tom.Mitchell著 第3.4.1.2節 Double entropyS=TheMath.getEntropy(lines.size(), li); //下面計算gain List<String> la=attributevalue.get(index); List<Point> lasv=new ArrayList<Point>(); for(int n=0;n<la.size();n++){ String attvalue=la.get(n); //統計Yes No的次數 List<Double> lisub=new ArrayList<Double>();//如:sunny 是yes時發生的次數,是no發生的次數 Double Sv=0.0;//公式3.4中的Sv 見參考書《機器學習(Tom.Mitchell著)》 for(int i=0;i<yesNo.length;i++){ Double sum=0.0; for(int j=0;j<lines.size();j++){ String[] line=lines.get(j); //data為結構化資料,如果資料最後一列==yes,sum+1 if(line[index].equals(attvalue)&&line[line.length-1].equals(yesNo[i])){ sum=sum+1; } } Sv=Sv+sum;//計算總數 lisub.add(sum); } //計算Entropy(S) 見參考書《機器學習(Tom.Mitchell著)》 Double entropySv=TheMath.getEntropy(Sv.intValue(), lisub); // Point p=new Point(); p.setSv(Sv); p.setEntropySv(entropySv); lasv.add(p); } gain=TheMath.getGain(entropyS,lines.size(),lasv); return gain; } //尋找最大的資訊增益,將最大的屬性定為當前節點,並返回該屬性所在list的位置和gain值 public Maxgain getMaxGain(LinkedList<String[]> lines){ if(lines==null||lines.size()<=0){ return null; } Maxgain maxgain = new Maxgain(); Double maxvalue=0.0; int maxindex=-1; for(int i=0;i<attribute.size();i++){ Double tmp=getGain(lines,i); if(maxvalue< tmp){ maxvalue=tmp; maxindex=i; } } maxgain.setMaxgain(maxvalue); maxgain.setMaxindex(maxindex); return maxgain; } //剪取陣列 public LinkedList<String[]> filterLines(LinkedList<String[]> lines, String attvalue, int index){ LinkedList<String[]> newlines=new LinkedList<String[]>(); for(int i=0;i<lines.size();i++){ String[] line=lines.get(i); if(line[index].equals(attvalue)){ newlines.add(line); } } return newlines; } public void createDTree(){ root=new TreeNode(); Maxgain maxgain=getMaxGain(data); if(maxgain==null){ System.out.println("沒有資料集,請檢查!"); } int maxKey=maxgain.getMaxindex(); String nodename=attribute.get(maxKey); root.setName(nodename); root.setLiatts(attributevalue.get(maxKey)); insertNode(data,root,maxKey); } /** * * @param lines 傳入的資料集,作為新的遞迴資料集 * @param node 深入此節點 * @param index 屬性位置 */ public void insertNode(LinkedList<String[]> lines,TreeNode node,int index){ List<String> liatts=node.getLiatts(); for(int i=0;i<liatts.size();i++){ String attname=liatts.get(i); LinkedList<String[]> newlines=filterLines(lines,attname,index); if(newlines.size()<=0){ System.out.println("出現異常,迴圈結束"); return; } Maxgain maxgain=getMaxGain(newlines); double gain=maxgain.getMaxgain(); Integer maxKey=maxgain.getMaxindex(); //不等於0繼續遞迴,等於0說明是葉子節點,結束遞迴。 if(gain!=0){ TreeNode subnode=new TreeNode(); subnode.setParent(node); subnode.setFatherAttribute(attname); String nodename=attribute.get(maxKey); subnode.setName(nodename); subnode.setLiatts(attributevalue.get(maxKey)); node.addChild(subnode); //不等於0,繼續遞迴 insertNode(newlines,subnode,maxKey); }else{ TreeNode subnode=new TreeNode(); subnode.setParent(node); subnode.setFatherAttribute(attname); //葉子節點是yes還是no?取新行中最後一個必是其名稱,因為只有完全是yes,或完全是no的情況下才會是葉子節點 String[] line=newlines.get(0); String nodename=line[line.length-1]; subnode.setName(nodename); node.addChild(subnode); } } } //輸出決策樹 public void printDTree(TreeNode node) { if(node.getChildren()==null){ System.out.println("--"+node.getName()); return; } System.out.println(node.getName()); List<TreeNode> childs = node.getChildren(); for (int i = 0; i < childs.size(); i++) { System.out.println(childs.get(i).getFatherAttribute()); printDTree(childs.get(i)); } } public static void main(String[] args) { // TODO Auto-generated method stub MyID3 myid3 = new MyID3(); myid3.readARFF(new File("datafile/decisiontree/test/in/weather.nominal.arff")); myid3.createDTree(); myid3.printDTree(root); } //讀取arff檔案,給attribute、attributevalue、data賦值 public void readARFF(File file) { try { FileReader fr = new FileReader(file); BufferedReader br = new BufferedReader(fr); String line; Pattern pattern = Pattern.compile(patternString); while ((line = br.readLine()) != null) { if (line.startsWith("@decision")) { line = br.readLine(); if(line=="") continue; yesNo = line.split(","); } Matcher matcher = pattern.matcher(line); if (matcher.find()) { attribute.add(matcher.group(1).trim()); String[] values = matcher.group(2).split(","); ArrayList<String> al = new ArrayList<String>(values.length); for (String value : values) { al.add(value.trim()); } attributevalue.add(al); } else if (line.startsWith("@data")) { while ((line = br.readLine()) != null) { if(line=="") continue; String[] row = line.split(","); data.add(row); } } else { continue; } } br.close(); } catch (IOException e1) { e1.printStackTrace(); } } }