1. 程式人生 > >貝葉斯演算法Java實現

貝葉斯演算法Java實現

前言:樸素貝葉斯分類演算法是一種基於貝葉斯定理的簡單概率分類演算法。貝葉斯分類的基礎是概率推理,就是在各種條件的存在不確定,僅知其出現概率的情況下,如何完成推理和決策任務。概率推理是與確定性推理相對應的。而樸素貝葉斯分類器是基於獨立假設的,即假設樣本每個特徵與其他特徵都不相關。

樸素貝葉斯分類器依靠精確的自然概率模型,在有監督學習的樣本集中能獲取得非常好的分類效果。在許多實際應用中,樸素貝葉斯模型引數估計使用最大似然估計方法,換言之樸素貝葉斯模型能工作並沒有用到貝葉斯概率或者任何貝葉斯模型。

儘管是帶著這些樸素思想和過於簡單化的假設,但樸素貝葉斯分類器在很多複雜的現實情形中仍能夠取得相當好的效果。

package Bayes;

import java.util.ArrayList;  
import java.util.HashMap;  
import java.util.Map;  
import java.math.BigDecimal;
public class Bayes {  

    //將訓練集按巡邏集合的最後一個值進行分類  
    Map<String, ArrayList<ArrayList<String>>> datasOfClass(ArrayList<ArrayList<String>> datas){  
        Map<String, ArrayList<ArrayList<String>>> map = new
HashMap<String, ArrayList<ArrayList<String>>>(); ArrayList<String> t = null; String c = ""; for (int i = 0; i < datas.size(); i++) { t = datas.get(i); c = t.get(t.size() - 1); if (map.containsKey(c)) { map.get
(c).add(t); } else { ArrayList<ArrayList<String>> nt = new ArrayList<ArrayList<String>>(); nt.add(t); map.put(c, nt); } } return map; } //在訓練資料的基礎上預測測試元組的類別 ,testT的各個屬性在結果集裡面出現的概率相乘最高的,即是結果 public String predictClass(ArrayList<ArrayList<String>> datas, ArrayList<String> testT) { Map<String, ArrayList<ArrayList<String>>> doc = this.datasOfClass(datas); //將訓練集元素劃分儲存在資料裡 Object classes[] = doc.keySet().toArray(); double maxP = 0.00; int maxPIndex = -1; //testT的各個屬性在結果集裡面出現的概率相乘最高的,即使結果集 for (int i = 0; i < doc.size(); i++) { String c = classes[i].toString(); ArrayList<ArrayList<String>> d = doc.get(c); BigDecimal b1 = new BigDecimal(Double.toString(d.size())); BigDecimal b2 = new BigDecimal(Double.toString(datas.size())); //b1除以b2得到一個精度為3的雙浮點數 double pOfC = b1.divide(b2,3,BigDecimal.ROUND_HALF_UP).doubleValue(); for (int j = 0; j < testT.size(); j++) { double pv = this.pOfV(d, testT.get(j), j); BigDecimal b3 = new BigDecimal(Double.toString(pOfC)); BigDecimal b4 = new BigDecimal(Double.toString(pv)); //b3乘以b4得到一個浮點數 pOfC=b3.multiply(b4).doubleValue(); } if(pOfC > maxP){ maxP = pOfC; maxPIndex = i; } } return classes[maxPIndex].toString(); } // 計算指定屬性到訓練集出現的頻率 private double pOfV(ArrayList<ArrayList<String>> d, String value, int index) { double p = 0.00; int count = 0; int total = d.size(); for (int i = 0; i < total; i++) { if(d.get(i).get(index).equals(value)){ count++; } } BigDecimal b1 = new BigDecimal(Double.toString(count)); BigDecimal b2 = new BigDecimal(Double.toString(total)); //b1除以b2得到一個精度為3的雙浮點數 p = b1.divide(b2,3,BigDecimal.ROUND_HALF_UP).doubleValue(); return p; } }
package Bayes;

import java.io.BufferedReader;  
import java.io.IOException;  
import java.io.InputStreamReader;  
import java.util.ArrayList;  

public class TestBayes {  

    //讀取測試元組
    public ArrayList<String> readTestData() throws IOException{  
        ArrayList<String> candAttr = new ArrayList<String>();  
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));  
        String str = "";  
        while (!(str = reader.readLine()).equals("")) {
            //string分析器
            String[] tokenizer = str.split(" ");
            for(int i=0;i<tokenizer.length;i++){
                candAttr.add(tokenizer[i]);
            } 
        }  
        return candAttr;  
    }  

    //讀取訓練集
    public ArrayList<ArrayList<String>> readData() throws IOException {  
        ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();  
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));  
        String str = "";  
        while (!(str = reader.readLine()).equals("")) {  
            String[] tokenizer = str.split(" ");  
            ArrayList<String> s = new ArrayList<String>();  
            for(int i=0;i<tokenizer.length;i++){
                s.add(tokenizer[i]);
            } 
            datas.add(s);  
        }  
        return datas;  
    }  

    public static void main(String[] args) {  
        TestBayes tb = new TestBayes();  
        ArrayList<ArrayList<String>> datas = null;  
        ArrayList<String> testT = null;  
        Bayes bayes = new Bayes();  
        try {  
            System.out.println("請輸入訓練資料");  
            datas = tb.readData();  
            while (true) {  
                System.out.println("請輸入測試元組");  
                testT = tb.readTestData();  
                String c = bayes.predictClass(datas, testT);  
                System.out.println("The class is: " + c);  
            }  
        } catch (IOException e) {  
            e.printStackTrace();  
        }  
    }  
}  

測試結果:

請輸入訓練資料
youth high no fair no  
youth high no excellent no  
middle_aged high no fair yes  
senior medium no fair yes  
senior low yes fair yes  
senior low yes excellent no  
middle_aged low yes excellent yes  
youth medium no fair no  
youth low yes fair yes  
senior medium yes fair yes  
youth medium yes excellent yes  
middle_aged medium no excellent yes  
middle_aged high yes fair yes  
senior medium no excellent no