1. 程式人生 > 文字分類——KNN演算法

文字分類——KNN演算法

上一篇文章已經描述了用樸素貝葉斯演算法對newsgroup資料集進行分類的實現,這篇文章採用KNN演算法實現newsgroup的分類。

文中程式碼參考:http://blog.csdn.net/yangliuy/article/details/7401142

1、KNN演算法描述

對於KNN演算法,前面有一篇文章介紹過其思想,但那篇文章的示例採用的是模擬的數值資料。本文將採用KNN進行文字分類。演算法步驟如下:

(1)文字預處理並向量化,以特徵詞的TF*IDF值作為向量的分量(上一篇文章已經處理)

(2)當新文字到達後,根據特徵詞計算新文字的向量

(3)在訓練文字中選出與新文字最相近的K個文字,相似度用向量夾角的餘弦值度量。

         注:K的值目前沒有好的辦法確定,只有根據實驗來調整K的值

(4)在新文字的K個相似文字中,依次計算每個類的權重,每個類的權重等於K個文字中屬於該類的訓練樣本與測試樣本的相似度之和。

(5)比較類的權重,將文字分到權重最大的那個類別中

2、KNN演算法實現

KNN演算法的實現要注意

(1)用TreeMap<String,TreeMap<String,Double>>儲存測試集和訓練集
(2)注意要以"類目_檔名"作為每個檔案的key,才能避免不同類目下同名(但內容不同)的檔案互相覆蓋

package com.datamine.NaiveBayes;

import java.io.*;
import java.util.*;

/**
 * KNN (k-nearest-neighbour) text classifier. The similarity between two document
 * vectors is the cosine of the angle between them.
 * @author Administrator
 */
public class KNNClassifier {

	/** Number of nearest neighbours that vote when no explicit k is given (tuned by experiment). */
	private static final int DEFAULT_K = 20;

	/**
	 * Classifies every test document with KNN and writes one result line per document.
	 * <p>
	 * Both input files hold one document per line in the form
	 * {@code category fileName word1 weight1 word2 weight2 ...}. Each document is keyed by
	 * {@code category_fileName}, so same-named files in different categories do not overwrite
	 * each other. Each output line is {@code category_fileName predictedCategory}.
	 * Large training sets may need a bigger JVM heap (e.g. -Xmx), otherwise a Java heap
	 * overflow error can occur.
	 *
	 * @param trainFiles file containing the vectors of all training documents
	 * @param testFiles  file containing the vectors of all test documents
	 * @param knnResultFile path of the KNN classification result file
	 * @throws Exception if reading the inputs or writing the result fails
	 */
	private void doProcess(String trainFiles, String testFiles,
			String knnResultFile) throws Exception {
		Map<String, TreeMap<String, Double>> trainFileNameWordTFMap = readVectorFile(trainFiles);
		Map<String, TreeMap<String, Double>> testFileNameWordTFMap = readVectorFile(testFiles);

		FileWriter knnClassifyResultWriter = new FileWriter(knnResultFile);
		try {
			for (Map.Entry<String, TreeMap<String, Double>> me : testFileNameWordTFMap.entrySet()) {
				String classifyResult = knnComputeCate(me.getKey(), me.getValue(), trainFileNameWordTFMap);
				knnClassifyResultWriter.append(me.getKey() + " " + classifyResult + "\n");
				// Flush after every document so progress is visible in the result file during long runs.
				knnClassifyResultWriter.flush();
			}
		} finally {
			knnClassifyResultWriter.close();
		}
	}

	/**
	 * Reads a document-vector file into a map keyed by {@code category_fileName}.
	 * Each line is expected to look like {@code category fileName word weight word weight ...}.
	 * Lines with fewer than two tokens are skipped, and a trailing word without a weight is ignored.
	 * The file is read with the platform default charset, as the original FileReader-based code did.
	 *
	 * @param vectorFile path of the vector file
	 * @return map (sorted by key) from {@code category_fileName} to that document's
	 *         {@code <word, weight>} vector
	 * @throws IOException if the file cannot be read
	 * @throws NumberFormatException if a weight token is not a valid double
	 */
	private static Map<String, TreeMap<String, Double>> readVectorFile(String vectorFile) throws IOException {
		Map<String, TreeMap<String, Double>> fileNameWordTFMap = new TreeMap<String, TreeMap<String, Double>>();
		BufferedReader reader = new BufferedReader(new FileReader(new File(vectorFile)));
		try {
			String line;
			while ((line = reader.readLine()) != null) {
				String[] tokens = line.split(" ");
				if (tokens.length < 2) {
					continue; // blank or malformed line: no category/file name to key it on
				}
				TreeMap<String, Double> wordTFMap = new TreeMap<String, Double>();
				// tokens[0] = category, tokens[1] = file name, then (word, weight) pairs
				for (int i = 2; i + 1 < tokens.length; i += 2) {
					wordTFMap.put(tokens[i], Double.valueOf(tokens[i + 1]));
				}
				fileNameWordTFMap.put(tokens[0] + "_" + tokens[1], wordTFMap);
			}
		} finally {
			reader.close();
		}
		return fileNameWordTFMap;
	}

	/**
	 * Classifies one test document using the {@link #DEFAULT_K} nearest training documents.
	 *
	 * @param testFileName test document key, "category_fileName" (not used in the computation)
	 * @param testWordTFMap test document vector {@code <word, weight>}
	 * @param trainFileNameWordTFMap training documents {@code <category_fileName, vector>}
	 * @return the category with the highest summed similarity among the neighbours, or
	 *         {@code null} if no neighbour has a positive similarity
	 */
	private String knnComputeCate(String testFileName, Map<String, Double> testWordTFMap,
			Map<String, TreeMap<String, Double>> trainFileNameWordTFMap) {
		return knnComputeCate(testFileName, testWordTFMap, trainFileNameWordTFMap, DEFAULT_K);
	}

	/**
	 * Classifies one test document with KNN.
	 * <p>
	 * Computes the cosine similarity to every training document and takes the {@code k} most
	 * similar ones. For each category it sums the similarities of the neighbours in that
	 * category and returns the category with the largest sum. The best k is found by experiment.
	 *
	 * @param testFileName test document key, "category_fileName" (not used in the computation)
	 * @param testWordTFMap test document vector {@code <word, weight>}
	 * @param trainFileNameWordTFMap training documents {@code <category_fileName, vector>}
	 * @param k number of nearest neighbours that vote; if larger than the training set, all
	 *          training documents vote; if not positive, nobody votes
	 * @return the category with the highest summed similarity among the k neighbours, or
	 *         {@code null} if no neighbour has a positive similarity
	 */
	private String knnComputeCate(String testFileName, Map<String, Double> testWordTFMap,
			Map<String, TreeMap<String, Double>> trainFileNameWordTFMap, int k) {

		// <category_fileName, similarity>
		HashMap<String, Double> simMap = new HashMap<String, Double>();
		for (Map.Entry<String, TreeMap<String, Double>> me : trainFileNameWordTFMap.entrySet()) {
			simMap.put(me.getKey(), computeSim(testWordTFMap, me.getValue()));
		}

		// Sort training documents by similarity, highest first. The comparator breaks ties by
		// key, so documents with equal similarity are all kept instead of collapsing into one entry.
		TreeMap<String, Double> sortedSimMap = new TreeMap<String, Double>(new ByValueComparator(simMap));
		sortedSimMap.putAll(simMap);

		// Sum the similarities of exactly k nearest neighbours per category.
		Map<String, Double> cateSimMap = new TreeMap<String, Double>();
		int count = 0;
		for (Map.Entry<String, Double> me : sortedSimMap.entrySet()) {
			if (count >= k) {
				break;
			}
			count++;
			String categoryName = me.getKey().split("_")[0];
			Double tempSim = cateSimMap.get(categoryName);
			cateSimMap.put(categoryName, tempSim == null ? me.getValue() : tempSim + me.getValue());
		}

		// Pick the category with the largest summed similarity (must be strictly positive).
		double maxSim = 0;
		String bestCate = null;
		for (Map.Entry<String, Double> me : cateSimMap.entrySet()) {
			if (me.getValue() > maxSim) {
				bestCate = me.getKey();
				maxSim = me.getValue();
			}
		}
		return bestCate;
	}

	/**
	 * Computes the cosine similarity between a test vector and a training vector:
	 * sim(D1,D2) = (D1*D2) / (|D1|*|D2|).
	 * <p>
	 * Example: D1(a 30; b 20; c 20; d 10), D2(a 40; c 30; d 20; e 10)
	 * D1*D2 = 30*40 + 20*0 + 20*30 + 10*20 + 0*10 = 2000,
	 * |D1| = sqrt(30*30+20*20+20*20+10*10) = sqrt(1800),
	 * |D2| = sqrt(40*40+30*30+20*20+10*10) = sqrt(3000),
	 * sim = 0.86.
	 *
	 * @param testWordTFMap  {@code <word, weight>} vector of the current test document
	 * @param trainWordTFMap {@code <word, weight>} vector of the current training document
	 * @return the cosine similarity, or 0 if either vector has zero length (the cosine is
	 *         undefined there and would otherwise be NaN)
	 */
	private double computeSim(Map<String, Double> testWordTFMap,
			TreeMap<String, Double> trainWordTFMap) {

		// mul = test*train, testAbs = |test|^2, trainAbs = |train|^2
		double mul = 0, testAbs = 0, trainAbs = 0;
		for (Map.Entry<String, Double> me : testWordTFMap.entrySet()) {
			Double trainWeight = trainWordTFMap.get(me.getKey());
			if (trainWeight != null) {
				mul += me.getValue() * trainWeight;
			}
			testAbs += me.getValue() * me.getValue();
		}
		for (Double weight : trainWordTFMap.values()) {
			trainAbs += weight * weight;
		}
		if (testAbs == 0 || trainAbs == 0) {
			return 0;
		}
		return mul / (Math.sqrt(testAbs) * Math.sqrt(trainAbs));
	}

	/**
	 * Builds the "correct category" file from a KNN result file. Each line of the output is
	 * {@code category_fileName category}, where the category is the prefix of the key before
	 * the first '_'. The accuracy and confusion matrix can then be computed with the Naive
	 * Bayes classifier's methods.
	 *
	 * @param knnResultFile classification result file, lines {@code <"category_fileName" predicted>}
	 * @param knnRightFile  output file, lines {@code <"category_fileName" correct>}
	 * @throws IOException if reading or writing fails
	 */
	private void createRightFile(String knnResultFile, String knnRightFile) throws IOException {
		BufferedReader fileBR = new BufferedReader(new FileReader(knnResultFile));
		try {
			FileWriter knnRightWriter = new FileWriter(new File(knnRightFile));
			try {
				String line;
				while ((line = fileBR.readLine()) != null) {
					String fileKey = line.split(" ")[0];
					String rightCate = fileKey.split("_")[0];
					knnRightWriter.append(fileKey + " " + rightCate + "\n");
				}
				knnRightWriter.flush();
			} finally {
				knnRightWriter.close();
			}
		} finally {
			fileBR.close();
		}
	}


	public static void main(String[] args) throws Exception {

		// Number of experiments to run; the accuracy array has room for up to 10.
		int experimentCount = 1;
		double[] accuracyOfEveryExp = new double[10];
		double sum = 0;
		KNNClassifier knnClassifier = new KNNClassifier();
		NaiveBayesianClassifier nbClassifier = new NaiveBayesianClassifier();
		ComputeWordsVector computeWV = new ComputeWordsVector();

		// wordMap is the dictionary of all feature words: <word, occurrence count over all documents>
		Map<String,Double> wordMap = new TreeMap<String, Double>();
		wordMap = computeWV.countWords("E:\\DataMiningSample\\processedSample", wordMap);
		Map<String,Double> IDFPerWordMap = computeWV.computeIDF("E:\\DataMiningSample\\processedSampleOnlySpecial", wordMap);
		computeWV.printWordMap(wordMap);

		// For each experiment, first generate the document TF*IDF matrix files, then classify.
		for (int i = 0; i < experimentCount; i++) {

			computeWV.computeTFMultiIDF("E:/DataMiningSample/processedSampleOnlySpecial", 0.9, i, IDFPerWordMap, wordMap);

			String trainFiles = "E:\\DataMiningSample\\docVector\\wordTFIDFMapTrainSample"+i;
			String testFiles = "E:/DataMiningSample/docVector/wordTFIDFMapTestSample"+i;

			String knnResultFile = "E:/DataMiningSample/docVector/KNNClassifyResult"+i;
			String knnRightFile = "E:/DataMiningSample/docVector/KNNClassifyRight"+i;

			knnClassifier.doProcess(trainFiles,testFiles,knnResultFile);
			knnClassifier.createRightFile(knnResultFile,knnRightFile);

			// Accuracy (and confusion matrix) reuse the Naive Bayes classifier's method.
			accuracyOfEveryExp[i] = nbClassifier.computeAccuracy(knnRightFile, knnResultFile);
			sum += accuracyOfEveryExp[i];
			System.out.println("The accuracy for KNN Classifier in "+i+"th Exp is :" + accuracyOfEveryExp[i]);
		}
		double accuracyAvg = sum / experimentCount;
		System.out.println("The average accuracy for KNN Classifier in all Exps is :" + accuracyAvg);
	}

	/**
	 * Orders keys of a {@code <key, similarity>} map by descending value.
	 * <p>
	 * Keys with equal values are ordered by their natural String order, so two distinct keys
	 * never compare as equal. A TreeMap built with this comparator therefore keeps every entry.
	 * Keys absent from the base map sort after all present keys.
	 */
	static class ByValueComparator implements Comparator<Object>{

		HashMap<String,Double> base_map;

		public ByValueComparator(HashMap<String,Double> disMap) {
			this.base_map = disMap;
		}

		@Override
		public int compare(Object o1, Object o2) {

			String arg0 = o1.toString();
			String arg1 = o2.toString();
			Double value0 = base_map.get(arg0);
			Double value1 = base_map.get(arg1);
			if (value0 != null && value1 != null) {
				// Compare by value (descending); Double.compare avoids boxed == reference comparison.
				int byValue = Double.compare(value1, value0);
				if (byValue != 0) {
					return byValue;
				}
			} else if (value0 != null) {
				return -1;
			} else if (value1 != null) {
				return 1;
			}
			// Tie-break on the key so distinct keys are never treated as equal.
			return arg0.compareTo(arg1);
		}

	}

}

3、KNN分類結果

這裡只列出一個結果