程式人生 > 樸素貝葉斯文字分類 Java 實現

樸素貝葉斯文字分類java實現


import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.data.util.IoUtil;

public
class NativeBayes { /** * 預設頻率 */ private double defaultFreq = 0.1; /** * 訓練資料的比例 */ private Double trainingPercent = 0.8; private Map<String, List<String>> files_all = new HashMap<String, List<String>>(); private Map<String, List<String>> files_train = new
HashMap<String, List<String>>(); private Map<String, List<String>> files_test = new HashMap<String, List<String>>(); public NativeBayes() { } /** * 每個分類的頻率 */ private Map<String, Integer> classFreq = new HashMap<String, Integer>(); private
Map<String, Double> ClassProb = new HashMap<String, Double>(); /** * 特徵總數 */ private Set<String> WordDict = new HashSet<String>(); private Map<String, Map<String, Integer>> classFeaFreq = new HashMap<String, Map<String, Integer>>(); private Map<String, Map<String, Double>> ClassFeaProb = new HashMap<String, Map<String, Double>>(); private Map<String, Double> ClassDefaultProb = new HashMap<String, Double>(); /** * 計算準確率 * @param reallist 真實類別 * @param pridlist 預測類別 */ public void Evaluate(List<String> reallist, List<String> pridlist){ double correctNum = 0.0; for (int i = 0; i < reallist.size(); i++) { if(reallist.get(i) == pridlist.get(i)){ correctNum += 1; } } double accuracy = correctNum / reallist.size(); System.out.println("準確率為:" + accuracy); } /** * 計算精確率和召回率 * @param reallist * @param pridlist * @param classname */ public void CalPreRec(List<String> reallist, List<String> pridlist, String classname){ double correctNum = 0.0; double allNum = 0.0;//測試資料中,某個分類的文章總數 double preNum = 0.0;//測試資料中,預測為該分類的文章總數 for (int i = 0; i < reallist.size(); i++) { if(reallist.get(i) == classname){ allNum += 1; if(reallist.get(i) == pridlist.get(i)){ correctNum += 1; } } if(pridlist.get(i) == classname){ preNum += 1; } } System.out.println(classname + " 精確率(跟預測分類比較):" + correctNum / preNum + " 召回率(跟真實分類比較):" + correctNum / allNum); } /** * 用模型進行預測 */ public void PredictTestData() { List<String> reallist=new ArrayList<String>(); List<String> pridlist=new ArrayList<String>(); for (Entry<String, List<String>> entry : files_test.entrySet()) { String realclassname = entry.getKey(); List<String> files = entry.getValue(); for (String file : files) { reallist.add(realclassname); List<String> classnamelist=new ArrayList<String>(); List<Double> scorelist=new ArrayList<Double>(); for (Entry<String, Double> entry_1 : ClassProb.entrySet()) { String classname = entry_1.getKey(); //先驗概率 Double score = Math.log(entry_1.getValue()); String[] words = IoUtil.readFromFile(new File(file)).split(" 
"); for (String word : words) { if(!WordDict.contains(word)){ continue; } if(ClassFeaProb.get(classname).containsKey(word)){ score += Math.log(ClassFeaProb.get(classname).get(word)); }else{ score += Math.log(ClassDefaultProb.get(classname)); } } classnamelist.add(classname); scorelist.add(score); } Double maxProb = Collections.max(scorelist); int idx = scorelist.indexOf(maxProb); pridlist.add(classnamelist.get(idx)); } } Evaluate(reallist, pridlist); for (String cname : files_test.keySet()) { CalPreRec(reallist, pridlist, cname); } } /** * 模型訓練 */ public void createModel() { double sum = 0.0; for (Entry<String, Integer> entry : classFreq.entrySet()) { sum+=entry.getValue(); } for (Entry<String, Integer> entry : classFreq.entrySet()) { ClassProb.put(entry.getKey(), entry.getValue()/sum); } for (Entry<String, Map<String, Integer>> entry : classFeaFreq.entrySet()) { sum = 0.0; String classname = entry.getKey(); for (Entry<String, Integer> entry_1 : entry.getValue().entrySet()){ sum += entry_1.getValue(); } double newsum = sum + WordDict.size()*defaultFreq; Map<String, Double> feaProb = new HashMap<String, Double>(); ClassFeaProb.put(classname, feaProb); for (Entry<String, Integer> entry_1 : entry.getValue().entrySet()){ String word = entry_1.getKey(); feaProb.put(word, (entry_1.getValue() +defaultFreq) /newsum); } ClassDefaultProb.put(classname, defaultFreq/newsum); } } /** * 載入訓練資料 */ public void loadTrainData(){ for (Entry<String, List<String>> entry : files_train.entrySet()) { String classname = entry.getKey(); List<String> docs = entry.getValue(); classFreq.put(classname, docs.size()); Map<String, Integer> feaFreq = new HashMap<String, Integer>(); classFeaFreq.put(classname, feaFreq); for (String doc : docs) { String[] words = IoUtil.readFromFile(new File(doc)).split(" "); for (String word : words) { WordDict.add(word); if(feaFreq.containsKey(word)){ int num = feaFreq.get(word) + 1; feaFreq.put(word, num); }else{ feaFreq.put(word, 1); } } } } 
System.out.println(classFreq.size()+" 分類, " + WordDict.size()+" 特徵詞"); } /** * 將資料分為訓練資料和測試資料 * * @param dataDir */ public void splitData(String dataDir) { // 用檔名區分類別 Pattern pat = Pattern.compile("\\d+([a-z]+?)\\."); dataDir = "testdata/allfiles"; File f = new File(dataDir); File[] files = f.listFiles(); for (File file : files) { String fname = file.getName(); Matcher m = pat.matcher(fname); if (m.find()) { String cname = m.group(1); if (files_all.containsKey(cname)) { files_all.get(cname).add(file.toString()); } else { List<String> tmp = new ArrayList<String>(); tmp.add(file.toString()); files_all.put(cname, tmp); } } else { System.out.println("err: " + file); } } System.out.println("統計資料:"); for (Entry<String, List<String>> entry : files_all.entrySet()) { String cname = entry.getKey(); List<String> value = entry.getValue(); // System.out.println(cname + " : " + value.size()); List<String> train = new ArrayList<String>(); List<String> test = new ArrayList<String>(); for (String str : value) { if (Math.random() <= trainingPercent) {// 80%用來訓練 , 20%測試 train.add(str); } else { test.add(str); } } files_train.put(cname, train); files_test.put(cname, test); } System.out.println("所有檔案數:"); printStatistics(files_all); System.out.println("訓練檔案數:"); printStatistics(files_train); System.out.println("測試檔案數:"); printStatistics(files_test); } /** * 列印統計資訊 * * @param m */ public void printStatistics(Map<String, List<String>> m) { for (Entry<String, List<String>> entry : m.entrySet()) { String cname = entry.getKey(); List<String> value = entry.getValue(); System.out.println(cname + " : " + value.size()); } System.out.println("--------------------------------"); } public static void main(String[] args) { NativeBayes bayes = new NativeBayes(); bayes.splitData(null); bayes.loadTrainData(); bayes.createModel(); bayes.PredictTestData(); } } 所有檔案數: sports : 1018 auto : 1020 business : 1028 -------------------------------- 訓練檔案數: sports : 791 auto : 812 business : 808 
-------------------------------- 測試檔案數: sports : 227 auto : 208 business : 220 -------------------------------- 分類, 39613 特徵詞 準確率為:0.9801526717557252 sports 精確率(跟預測分類比較):0.9956140350877193 召回率(跟真實分類比較):1.0 auto 精確率(跟預測分類比較):0.9579439252336449 召回率(跟真實分類比較):0.9855769230769231 business 精確率(跟預測分類比較):0.9859154929577465 召回率(跟真實分類比較):0.9545454545454546 統計資料: 所有檔案數: sports : 1018 auto : 1020 business : 1028 -------------------------------- 訓練檔案數: sports : 827 auto : 833 business : 825 -------------------------------- 測試檔案數: sports : 191 auto : 187 business : 203 -------------------------------- 分類, 39907 特徵詞 準確率為:0.9759036144578314 sports 精確率(跟預測分類比較):0.9894736842105263 召回率(跟真實分類比較):0.9842931937172775 auto 精確率(跟預測分類比較):0.9836956521739131 召回率(跟真實分類比較):0.9679144385026738 business 精確率(跟預測分類比較):0.9565217391304348 召回率(跟真實分類比較):0.9753694581280788