1. 程式人生 > >機器學習_13.樸素貝葉斯

機器學習_13.樸素貝葉斯

樸素貝葉斯

樸素貝葉斯試講連續取值的輸入對映為離散取值的輸出的演算法,用於解決分類問題。基本思想在與分析待分類樣本出現每個輸出類別中的後驗概率,並取最大後驗概率的類別作為分類的輸出。從模型最優化的角度看,樸素貝葉斯分類是平均意義上預測能力最優的模型,即使期望風險最小化。其中,期望風險是風險函式的數學期望,即預測的誤差。

已知聯合概率分佈為:

P(Y)被稱為類先驗概率,P(X|Y)為類似然概率,似然概率往往是m^n數量級的計算量級別。

如果此時使用了條件獨立性假設,假設所有屬性相互獨立,這樣可以保證在對分類結果的影響中,每個屬性是獨立的,就可以將類條件概率變成屬性條件概率的乘積:

事實上獨立性假設是比較苛刻的條件,現實生活中多種事物之間的因果關係是複雜的,但對數學計算上的極大簡化足以忽略效能上的稍微降低。在後驗概率的求解中,邊界概率P(X)是可以忽略的。在先驗概率和似然概率都已知後,後驗概率就可以求解了,其最大值對應的類也是樸素貝葉斯分類器的輸出。

                       

此外,我們需要假定連續型屬性資料滿足正態分佈,並根據每個類別下的訓練資料得到其均值和方差。

上面談到的期望風險為單詞預測中聯合概率分佈的數學期望。

如果將訓練中的錯誤資料作為誤差,期望風險就變成了1-P(X|Y),後驗概率最大化和期望風險最小化就可以成功以一個趨勢方向變化了,可以說此時兩者在說同一個標準。

樸素貝葉斯分類中,也許會出現樣本不足使得判斷出現偏差導致的分類錯誤,通過學習極客實踐的文件,我們知道了可以新增“拉普拉斯平滑”這一步驟,即:在分子上新增一個較小的修正量,在分母上則新增這個修正量與分類數目的乘積,當訓練及的資料量較大時,就可以忽略修正量對先驗概率的影響。較小的修正量這裡起到的作用是保證在滿足概率基本性質的條件下,避免零概率對分類結果的影響。

說到上述各種假設中與現實情況的誤差,可能會導致後驗概率的計算實際上不精確,但在分類上中大部分會指向同一個結果。也許兩個單獨的屬性之間依賴關係是很強的,但許許多多的屬性放在一起,可能會出現抵消現象。影響樸素貝葉斯的分類的事所有屬性之間的依賴關係在不同類別上的分佈,而不僅僅是依賴關係本身。在這個基礎上,也出現了“半樸素貝葉斯分類器”的方法。

樸素貝葉斯分類器應用廣泛,如針對關鍵詞判斷垃圾郵件、判斷殭屍粉絲還是活躍賬戶,樸素貝葉斯在二元分類中很有效果。

情景:

程式碼:

import numpy
from sklearn.naive_bayes import GaussianNB
 
 
def test_gaussian_nb():
    X = numpy.array([
        [6, 180, 12],
        [5.92, 190, 11],
        [5.58, 170, 12],
        [5.92, 165, 10],
        [5, 100, 6],
        [5.5, 150, 8],
        [5.42, 130, 7],
        [5.75, 150, 9],
    ])
 
    Y = numpy.array([1, 1, 1, 1, 0, 0, 0, 0])
 
    gnb = GaussianNB()
    gnb.fit(X, Y)
 
    test = numpy.array([6, 130, 8]).reshape(1, -1)
    result = gnb.predict(test)
    print(result)
    print("\n")
 
    result = gnb.predict_proba(test)
    print(result[0][0])
    print(result[0][1])
 
 
if __name__ == '__main__':
    test_gaussian_nb()

借鑑:https://blog.csdn.net/grafx/article/details/77823503

打球例項-Java實現

weather.nominal.arff

#存放做決策的屬性,一般是或否
@decision
yes,no
@attribute outlook {sunny, overcast, rainy}
@attribute temperature {hot, mild, cool}
@attribute humidity {high, normal}
@attribute windy {TRUE, FALSE}
@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no

求出各個值的概率:

trainresult.arff

@decision P(yes) {0.6428571428571429}
@decision P(no) {0.35714285714285715}
@data
P(outlook=sunny|yes),0.2222222222222222
P(outlook=sunny|no),0.6
P(outlook=overcast|yes),0.4444444444444444
P(outlook=overcast|no),0.0
P(outlook=rainy|yes),0.3333333333333333
P(outlook=rainy|no),0.4
P(temperature=hot|yes),0.2222222222222222
P(temperature=hot|no),0.4
P(temperature=mild|yes),0.4444444444444444
P(temperature=mild|no),0.4
P(temperature=cool|yes),0.3333333333333333
P(temperature=cool|no),0.2
P(humidity=high|yes),0.3333333333333333
P(humidity=high|no),0.8
P(humidity=normal|yes),0.6666666666666666
P(humidity=normal|no),0.2
P(windy=TRUE|yes),0.3333333333333333
P(windy=TRUE|no),0.6
P(windy=FALSE|yes),0.6666666666666666
P(windy=FALSE|no),0.4

main.java

package sequence.machinelearning.naivebayes.bayesdemo;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;

public class Main {

	public static void main(String[] args) throws IOException {
		// TODO Auto-generated method stub
		Main m=new Main();
		m.stringBufferDemo();
		//m.fileWriter("D:/test.txt");
		m.readF1();
	}
	
	public void fileWriter(String fileName) throws IOException{
        //建立一個FileWriter物件
        FileWriter fw = new FileWriter(fileName);
        //遍歷clist集合寫入到fileName中
        for (int i=0;i<10;i++){
            fw.write("第"+i+"行----");
            fw.write("\n");
        }
        //重新整理緩衝區
        fw.flush();
        //關閉檔案流物件
        fw.close();
    }

	
	
	/**
    * 利用StringBuffer寫檔案
    * 該方法可以設定使用何種編碼,有效解決中文問題。
    * @throws IOException
    */
   
   public void stringBufferDemo() throws IOException
   {
       String src="datafile/naivebayes/train/out/result.arff";
       delfile(src);
       File file=new File(src);
       if(file.exists())
           file.createNewFile();
       FileOutputStream out=new FileOutputStream(file,true);
       for(int i=0;i<10;i++)
       {
           StringBuffer sb=new StringBuffer();
           sb.append("這是第"+i+"行 \n");//如果不加"/n"則不能實現換行。
           System.out.print(sb.toString());
           
           out.write(sb.toString().getBytes("utf-8"));
       }
       out.close();
   }
   public void delfile(String filepath){
	   File file=new File(filepath);   
	       if(file.exists())   
	      {   
	           //file.createNewFile(); 
			   file.delete();   
	       }    

   }
	public void readF1() throws IOException {      
		
		//String filePath="scripts/clustering/canopy/canopy.dat";
		String filePath="datafile/naivebayes/train/out/result";
		BufferedReader br = new BufferedReader(new InputStreamReader(
       new FileInputStream(filePath)));
       for (String line = br.readLine(); line != null; line = br.readLine()) {
           if(line.length()==0||"".equals(line))continue;
       	String[] str=line.split(",");   
       	
       	   
       }
       br.close();
       
   }


}

Test.java

package sequence.machinelearning.naivebayes.bayesdemo;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class Test {

	private static Map<String,Double> cmap=new HashMap<String,Double>();
	private static Map<String,Double> pmap=new HashMap<String,Double>();
    public static final String patternString = "@decision(.*)[{](.*?)[}]";
	public BigDecimal getProbability(String[] line,String decision){
		
		String ckey="P("+decision+")";
		//獲取P(yes)的概率
		BigDecimal result=new BigDecimal(cmap.get(ckey));
			for(int j=0;j<line.length;j++){
				String attval=line[j].toString();
				String pkey="P("+Train.lisatt.get(j)+"="+attval+"|"+decision+")";
				//取得P(outlook=sunny|yes)的概率相
				BigDecimal pi=new BigDecimal(pmap.get(pkey));
				result=result.multiply(pi);
			}
		//System.out.println(arraytoString(line)+" 為"+decision+"的參考數值是:"+result.toString().substring(0,5));
		return result;
	}
	public void printResult(){
		for(int i=0;i<Train.listdata.size();i++){
			String[] line=Train.listdata.get(i);
			BigDecimal p=new BigDecimal(0);
			int index=-1;
			for(int j=0;j<Train.sort.size();j++){
				BigDecimal pnext=getProbability(line,Train.sort.get(j));
				if(p.compareTo(pnext)==-1){
					p=pnext;
					index=j;
				}
			}
			System.out.println(arraytoString(line)+"   判斷的結果是:"+Train.sort.get(index)+"	      --參考數值是:"+p.toString().substring(0,5));
		}
	}
	
	public static void main(String[] args) {
		// TODO Auto-generated method stub
		Train train=new Train();
		//讀取測試集
		train.readARFF(new File("datafile/naivebayes/test/in/test.arff"));
		Test test=new Test();
		//讀取訓練結果
		test.readResult(new File("datafile/naivebayes/train/out/trainresult.arff"));
		test.printResult();
	}
	//陣列轉字串
	public String arraytoString(String[] line){
		
		StringBuffer sb = new StringBuffer();
		for(int i = 0; i < line.length; i++){
		 sb. append(line[i]+",");
		}
        String newStr = sb.toString();
        return newStr.substring(0, newStr.length()-1);
	}
    //讀取arff檔案,給attribute、attributevalue、data賦值
    public void readResult(File file) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line;
            Pattern pattern = Pattern.compile(patternString);
            while ((line = br.readLine()) != null) {
            	Matcher matcher = pattern.matcher(line);
                if (matcher.find()) {
                	String[] values = matcher.group(2).split(",");
                    Double val=Double.valueOf(values[0]);
                    cmap.put(matcher.group(1).trim(), val);
                } else if (line.startsWith("@data")) {
                    while ((line = br.readLine()) != null) {
                        if(line=="")
                            continue;
                        String[] row = line.split(",");
                        Double val=Double.valueOf(row[1]);
                        pmap.put(row[0], val);
                    }
                } else {
                    continue;
                }
            }
            br.close();
        } catch (IOException e1) {
            e1.printStackTrace();
        }
    }

}

Train.java

package sequence.machinelearning.naivebayes.bayesdemo;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.LinkedList;
/**
 * 案例:http://www.cnblogs.com/zhangchaoyang/articles/2586402.html
 * @author Jamas
 *  也參考了這篇文章:http://www.cnblogs.com/leoo2sk/archive/2010/09/17/naive-bayesian-classifier.html
 */
public class Train {

    public static LinkedList<String> lisatt = new LinkedList<String>(); // 儲存屬性的名稱:outlook,temperature,humidity,windy
    public static LinkedList<ArrayList<String>> lisvals = new LinkedList<ArrayList<String>>(); //outlook:sunny,overcast,rainy 儲存每個屬性的取值,屬性的特徵
    public static LinkedList<String[]> listdata = new LinkedList<String[]>();; // 原始資料
   
    public static final String patternString = "@attribute(.*)[{](.*?)[}]";
    //儲存分類,比如,是,否。再比如:檢測SNS社群中不真實賬號,是真實使用者還是殭屍使用者
    public static LinkedList<String> sort=new LinkedList<String>();
	
    //計算P(F1|C)P(F2|C)...P(Fn|C)P(C),並儲存為文字檔案 
    /**
     * 為了避免零頻問題,對每個計數加1,只要數量足夠大,加1是可以忽略的
     * @throws IOException
     */
    public void CountProbility() throws IOException{
    	
        String src="datafile/naivebayes/train/out/trainresult.arff";
        delfile(src);
        File file=new File(src);
        if(file.exists())
            file.createNewFile();
        FileOutputStream out=new FileOutputStream(file,true);
        Map<String,Integer> map=new HashMap<String,Integer>();
    	//先計算判定結果的概率,儲存為檔案
    	for(int i=0;i<sort.size();i++){
            //第一個for對取出sort,第二個for對data中的sort進行計數
    		//避免零頻問題,對各項計數加1
    		Integer sum=1;
            String sortname=sort.get(i);
            Double probability=0.0;
            
            for(int j=0;j<listdata.size();j++){
    			String[] line=listdata.get(j);
    			if(line[line.length-1].equals(sortname)){
    				sum=sum+1;
    			}
     	    }
    		map.put(sortname, sum);
    		probability=Double.valueOf(sum)/Double.valueOf(listdata.size());
    		//寫入檔案
    		StringBuffer sb=new StringBuffer();
            sb.append("@decision P("+sortname+") {"+probability.toString()+"}\n");//如果不加"/n"則不能實現換行。
            System.out.print(sb.toString());
            
            out.write(sb.toString().getBytes("utf-8"));
    	}
    	out.write("@data\n".getBytes("utf-8"));
    	System.out.print("@data\n");
    	//先計算判定結果的概率,儲存為檔案
    	//out.close(); //到最後寫完的時候再關閉
    	//分別統計P(F1|C)P(F2|C)...P(Fn|C)的個數,參考:http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html
    	 //對屬性進行迴圈
        for(int i=0;i<lisatt.size();i++){
        	
        	String attname=lisatt.get(i);
        	List<String> lisval=lisvals.get(i);
        	//對屬性的特徵進行迴圈
        	for(int j=0;j<lisval.size();j++){
        		String attval=lisval.get(j);
        		//先取出sort(yes 還是no情況)
        		for(int n=0;n<sort.size();n++){
        			//避免零頻問題,對各項計數加1
        			Integer sum=1;
                    String sortname=sort.get(n);
                    Double probability=0.0;
                    
                    //取出資料集進行for
                    for(int k=0;k<listdata.size();k++){
            			String[] line=listdata.get(k);
            			if(line[line.length-1].equals(sortname)&&line[i].equals(attval)){
            				sum=sum+1;
            			}
             	    }
                    
            		probability=Double.valueOf(sum)/Double.valueOf(map.get(sortname));
            		//寫入檔案
            		StringBuffer sb=new StringBuffer();
                    sb.append("P("+attname+"="+attval+"|"+sortname+"),"+probability+"\n");//如果不加"/n"則不能實現換行。
                    System.out.print(sb.toString());
                    out.write(sb.toString().getBytes("utf-8"));
        		}
        		
        	}
        }
        out.close();
    }
    
    
    
    //讀取arff檔案,給attribute、attributevalue、data賦值
    public void readARFF(File file) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line;
            Pattern pattern = Pattern.compile(patternString);
            while ((line = br.readLine()) != null) {
            	if (line.startsWith("@decision")) {
                   line = br.readLine();
                        if(line=="")
                            continue;
                        String[] type = line.split(",");
                        for(int i=0;i<type.length;i++){
                        	sort.add(type[i].trim());
                        }
                }
            	Matcher matcher = pattern.matcher(line);
                if (matcher.find()) {
                	lisatt.add(matcher.group(1).trim());
                    String[] values = matcher.group(2).split(",");
                    ArrayList<String> al = new ArrayList<String>(values.length);
                    for (String value : values) {
                        al.add(value.trim());
                    }
                    lisvals.add(al);
                } else if (line.startsWith("@data")) {
                    while ((line = br.readLine()) != null) {
                        if(line=="")
                            continue;
                        String[] row = line.split(",");
                        listdata.add(row);
                    }
                } else {
                    continue;
                }
            }
            br.close();
        } catch (IOException e1) {
            e1.printStackTrace();
        }
    }
	public static void main(String[] args) throws IOException {
		// TODO Auto-generated method stub
		Train train=new Train();
		train.readARFF(new File("datafile/naivebayes/train/in/weather.nominal.arff"));
		train.CountProbility();
		
	}
	public void delfile(String filepath){
		   File file=new File(filepath);   
		       if(file.exists())   
		      {   
		           //file.createNewFile(); 
				   file.delete();   
		       }    

	   }
}

package-info.java


package sequence.machinelearning.naivebayes.bayesdemo;

借鑑:https://blog.csdn.net/jameshadoop/article/details/35276083