程式人生 > 樸素貝葉斯演算法 Java 實現

樸素貝葉斯演算法Java 實現

對於樸素貝葉斯演算法,相信做資料探勘和推薦系統的小夥伴們都耳熟能詳了,演算法原理我就不囉嗦了。這裡主要想用 Java 程式碼實現樸素貝葉斯演算法,思路如下:

1. 用 JavaBean + ArrayList 儲存訓練資料

2. 用樣本資料進行訓練,並對未知樣本進行分類

具體的程式碼如下:

package NB;
/**
 * One training sample: four observed attributes (age, income, student,
 * credit_rating) together with the buys_computer value recorded for it.
 */
public class JavaBean {
	int age;
	String income;
	String student;
	String credit_rating;
	String buys_computer;

	/** Creates a sample whose fields all hold their Java defaults (0 / null). */
	public JavaBean() {
		this(0, null, null, null, null);
	}

	/**
	 * Creates a fully populated sample.
	 *
	 * @param age           value stored in the age field
	 * @param income        value stored in the income field
	 * @param student       value stored in the student field
	 * @param credit_rating value stored in the credit_rating field
	 * @param buys_computer value stored in the buys_computer field
	 */
	public JavaBean(int age, String income, String student, String credit_rating, String buys_computer) {
		this.age = age;
		this.income = income;
		this.student = student;
		this.credit_rating = credit_rating;
		this.buys_computer = buys_computer;
	}

	/** @return the stored age */
	public int getAge() {
		return this.age;
	}

	/** @param age new age value */
	public void setAge(int age) {
		this.age = age;
	}

	/** @return the stored income value */
	public String getIncome() {
		return this.income;
	}

	/** @param income new income value */
	public void setIncome(String income) {
		this.income = income;
	}

	/** @return the stored student value */
	public String getStudent() {
		return this.student;
	}

	/** @param student new student value */
	public void setStudent(String student) {
		this.student = student;
	}

	/** @return the stored credit_rating value */
	public String getCredit_rating() {
		return this.credit_rating;
	}

	/** @param credit_rating new credit_rating value */
	public void setCredit_rating(String credit_rating) {
		this.credit_rating = credit_rating;
	}

	/** @return the stored buys_computer value */
	public String getBuys_computer() {
		return this.buys_computer;
	}

	/** @param buys_computer new buys_computer value */
	public void setBuys_computer(String buys_computer) {
		this.buys_computer = buys_computer;
	}

	/**
	 * Renders every field as {@code name=value}, comma separated, inside
	 * {@code JavaBean [...]}.
	 */
	@Override
	public String toString() {
		StringBuilder text = new StringBuilder("JavaBean [");
		text.append("age=").append(age);
		text.append(", income=").append(income);
		text.append(", student=").append(student);
		text.append(", credit_rating=").append(credit_rating);
		text.append(", buys_computer=").append(buys_computer);
		text.append(']');
		return text.toString();
	}
}
演算法實現的部分:
package NB;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

public class TestNB {

	/** Training samples loaded so far; {@link #splitt(String)} appends to it. */
	public static  ArrayList<JavaBean> list = new ArrayList<JavaBean>();
	/** Number of sample lines that {@link #txt2String(File)} has parsed successfully. */
	static int data_length=0;

	/** Index of the "Yes" label in the per-class counter arrays. */
	private static final int YES = 0;
	/** Index shared by every label other than "Yes" in the per-class counter arrays. */
	private static final int NO = 1;

	/**
	 * Loads the training file into {@link #list} and classifies one fixed sample.
	 *
	 * @param args optional; {@code args[0]} overrides the default training file
	 *             path {@code E://test.txt}
	 */
	public static void main(String[] args) {
		String path = (args != null && args.length > 0) ? args[0] : "E://test.txt";
		// 1. read the samples into the list
		txt2String(new File(path));
		// 2. classify the sample under test
		testData(25,"Medium","Yes","Fair");
	}

	/**
	 * Reads the file line by line and stores every well-formed line as a training
	 * sample. Blank lines are ignored; a malformed line is reported on
	 * {@code System.err} and skipped, so it neither aborts loading nor inflates
	 * {@link #data_length}. An I/O failure is reported via its stack trace.
	 *
	 * @param file whitespace-separated sample file, one sample per line
	 */
	public static void txt2String(File file) {
		// try-with-resources closes the reader even when reading fails part-way
		try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
			String line;
			while ((line = reader.readLine()) != null) {
				if (line.trim().isEmpty()) {
					continue;
				}
				try {
					splitt(line);
					// count only lines that really produced a sample
					data_length++;
				} catch (IllegalArgumentException e) {
					System.err.println("Skipping malformed sample line \"" + line + "\": " + e.getMessage());
				}
			}
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Parses one sample line (age, income, student, credit_rating, buys_computer,
	 * separated by whitespace) and appends it to {@link #list}.
	 *
	 * @param str the raw line; surrounding whitespace is ignored
	 * @throws IllegalArgumentException if the line has fewer than five fields, or
	 *         (as {@link NumberFormatException}) if the first field is not an integer
	 */
	public static void splitt(String str){
		String[] fields = str.trim().split("[\\p{Space}]+");
		if (fields.length < 5) {
			throw new IllegalArgumentException("expected 5 fields but found " + fields.length);
		}
		int age = Integer.parseInt(fields[0]);
		list.add(new JavaBean(age, fields[1], fields[2], fields[3], fields[4]));
	}

	/**
	 * Counts, per class, how many samples in {@link #list} match each attribute
	 * of the given sample, prints the counts, the class priors and the two
	 * products "matching ratios times prior", and reports which product is
	 * larger. No smoothing is applied: a zero count for any attribute makes the
	 * product for that class zero. A class with no samples gets a product of
	 * zero instead of NaN, so it can never win the comparison.
	 *
	 * @param age age of the sample under test
	 * @param a   income of the sample under test
	 * @param b   student value of the sample under test
	 * @param c   credit rating of the sample under test
	 */
	public static void testData(int age,String a,String b,String c){
		// one slot per class: [YES] and [NO]
		int[] classCount = new int[2];
		int[] ageCount = new int[2];
		int[] incomeCount = new int[2];
		int[] studentCount = new int[2];
		int[] creditCount = new int[2];

		for (JavaBean sample : list) {
			// any label other than "Yes" (including null) is counted as "No"
			int cls = "Yes".equals(sample.getBuys_computer()) ? YES : NO;
			classCount[cls]++;
			if (matches(sample.getIncome(), a)) {
				incomeCount[cls]++;
			}
			if (matches(sample.getStudent(), b)) {
				studentCount[cls]++;
			}
			if (matches(sample.getCredit_rating(), c)) {
				creditCount[cls]++;
			}
			if (sample.getAge() == age) {
				ageCount[cls]++;
			}
		}

		System.out.println("購買的歷史個數:"+classCount[YES]);
		System.out.println("不買的歷史個數:"+classCount[NO]);

		System.out.println("購買+age:"+ageCount[YES]);
		System.out.println("不買+age:"+ageCount[NO]);

		System.out.println("購買+income:"+incomeCount[YES]);
		System.out.println("不買+income:"+incomeCount[NO]);

		System.out.println("購買+stundent:"+studentCount[YES]);
		System.out.println("不買+student:"+studentCount[NO]);

		System.out.println("購買+credit:"+creditCount[YES]);
		System.out.println("不買+credit:"+creditCount[NO]);

		// Priors are taken over the samples actually counted, so they stay
		// consistent even if data_length and the list have drifted apart.
		int total = classCount[YES] + classCount[NO];
		double buy_yes = ratio(classCount[YES], total);
		double buy_no = ratio(classCount[NO], total);
		System.out.println("訓練資料中買的概率:"+buy_yes);
		System.out.println("訓練資料中不買的概率:"+buy_no);

		// Factor order is kept as age, income, student, credit, prior so the
		// floating-point result is reproducible digit for digit.
		double nb_buy_yes = ratio(ageCount[YES], classCount[YES])
				* ratio(incomeCount[YES], classCount[YES])
				* ratio(studentCount[YES], classCount[YES])
				* ratio(creditCount[YES], classCount[YES])
				* buy_yes;
		double nb_buy_no = ratio(ageCount[NO], classCount[NO])
				* ratio(incomeCount[NO], classCount[NO])
				* ratio(studentCount[NO], classCount[NO])
				* ratio(creditCount[NO], classCount[NO])
				* buy_no;
		System.out.println("新使用者買的概率:"+nb_buy_yes);
		System.out.println("新使用者不買的概率:"+nb_buy_no);
		if (nb_buy_yes > nb_buy_no) {
			System.out.println("新使用者買的概率大");
		} else {
			System.out.println("新使用者不買的概率大");
		}
	}

	/** Returns count/total as a double, or 0.0 when total is zero (avoids NaN). */
	private static double ratio(int count, int total) {
		return total == 0 ? 0.0 : 1.0 * count / total;
	}

	/** Returns true when the stored value is non-null and equals the wanted value. */
	private static boolean matches(String value, String wanted) {
		return value != null && value.equals(wanted);
	}
}

樣本資料(即 E://test.txt 的內容,每行以空白分隔,依序為 age、income、student、credit_rating、buys_computer):
25  High    No  Fair       No
25  High    No  Excellent  No
33  High    No  Fair       Yes
41  Medium  No  Fair       Yes     
41  Low     Yes Fair       Yes
41  Low     Yes Excellent  No
33  Low     Yes Excellent  Yes
25  Medium  No  Fair       No
25  Low     Yes Fair       Yes
41  Medium  Yes Fair       Yes
25  Medium  Yes Excellent  Yes
33  Medium  No  Excellent  Yes
33  High    Yes Fair       Yes
41  Medium  No  Excellent  No

對於未知使用者 (age=25, income=Medium, student=Yes, credit_rating=Fair) 得出的結果:
購買的歷史個數:9
不買的歷史個數:5
購買+age:2
不買+age:3
購買+income:4
不買+income:2
購買+stundent:6
不買+student:1
購買+credit:6
不買+credit:2
訓練資料中買的概率:0.6428571428571429
訓練資料中不買的概率:0.35714285714285715
新使用者買的概率:0.028218694885361547
新使用者不買的概率:0.006857142857142858
新使用者買的概率大