1. 程式人生 > >【學習排序】 Learning to Rank 中Listwise關於ListNet演算法講解及實現

【學習排序】 Learning to Rank 中Listwise關於ListNet演算法講解及實現

    程式碼如下:
package listNet_xiuzhang;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.InputStreamReader;

public class listNet {
	
	//檔案總行數(標記數)
	private static int sumLabel;                   
	//特徵值 46個 (標號1-46)
	private static double feature[][] = new double[100000][48];                
	//特徵值權重 46個 (標號1-46)
	private static double weight [] = new double[48];
	//相關度 其值有0-2三個級別 從1開始記錄
	private static int label [] = new int[1000000];
	//查詢id 從1開始記錄
	private static int qid [] = new int[1000000];
	//每個Qid的doc數量
	private static int doc_ofQid[] = new int[100000]; 

	private static int ITER_NUM=30;     //迭代次數
	private static int weidu=46;        //特徵數
	private static int qid_Num=0;       //Qid數量
	private static int tempQid=-1;      //臨時Qid數
	private static int tempDoc=0;       //臨時doc數
	
	/**  
	 * 函式功能 讀取檔案
	 * 引數 String filePath 檔案路徑
	 */
	public static void ReadTxtFile(String filePath) {
        try {
        	String encoding="GBK";
        	File file=new File(filePath);
        	if(file.isFile() && file.exists()) { //判斷檔案是否存在
        		InputStreamReader read = new InputStreamReader(new FileInputStream(file), encoding); 
                BufferedReader bufferedReader = new BufferedReader(read);
                String lineTxt = null;
                sumLabel =1; //初始化從1記錄
                //按行讀取資料並分解資料
                while((lineTxt = bufferedReader.readLine()) != null) {
                	String str = null;
                	int lengthLine = lineTxt.length();
                	//獲取資料 字串空格分隔
                	String arrays[] = lineTxt.split(" ");
                	for(int i=0; i<arrays.length; i++) {
                		//獲取每行樣本的Label值
                		if(i==0) {
                			label[sumLabel] = Integer.parseInt(arrays[0]);
                		} 
                		else if(i>=weidu+2){ //讀取至#跳出 0-label 1-qid 2:47-特徵
                			continue;
                		}
                		else {
                			String subArrays[] = arrays[i].split(":"); //特徵:特徵值
                			if(i==1) { //獲取qid		
                				//判斷是否是新的Qid
                				if(tempQid != Integer.parseInt(subArrays[1])) { 
                					if(tempQid != -1){ //不是第一次出現新Qid
                						//賦值上一個為qid_Num對應的tempDoc個文件
                						doc_ofQid[qid_Num]=tempDoc;    
                						tempDoc=0;
                					}
                					//當tempQid不等於當前qid時下標加1 
                					//相等則直接跳至Doc加1直到不等
                					qid_Num++;
                					tempQid=Integer.parseInt(subArrays[1]);    					
                				}
                				tempDoc++; //新的文件 
                				qid[sumLabel] = Integer.parseInt(subArrays[1]);
                			} 
                			else { //獲取46維特徵值
                				int number = Integer.parseInt(subArrays[0]); //判斷特徵
                				double value = Double.parseDouble(subArrays[1]);
                				feature[sumLabel][number] = value; //number陣列標號:1-46
                			}
                		}
                	}
                	sumLabel++;
                }
                doc_ofQid[qid_Num]=tempDoc;
                read.close();
        	} else {
        		System.out.println("找不到指定的檔案\n");
        	}
        } catch (Exception e) {
            System.out.println("讀取檔案內容出錯");
            e.printStackTrace();
        }
    }

	/**
	 * 學習排序
	 * 訓練模型得到46維權重
	 */
	public static void LearningToRank() {
		
		//變數
		double index [] = new double[1000000];
		double tao [] = new double[1000000];
		double yita=0.00003;
		//初始化
		for(int i=0;i<weidu+2;i++) { //從1到136為權重,0和137無用
			weight[i] = (double) 1.0; //權重初值
		}
		System.out.println("training...");				
		//計算權重 學習演算法
		for(int iter = 0; iter<ITER_NUM; iter++) //迭代ITER_NUM次
		{ 
			System.out.println("---迭代次數:"+iter);
			int now_doc=0; //全域性文件索引
			for(int i=1; i<=qid_Num; i++) //總樣qid數  相當於兩層迴圈T和m 
			{ 
				double delta_w[] = new double[weidu+2]; //46個梯度組成的向量
				int doc_of_i=doc_ofQid[i]; //該Qid的文件數
				//得分f(w),一個QID有多個文件,一個文件為一個分,所以一個i對應一個分數陣列
				double fw[] = new double[doc_of_i+2];
				
				/* 第一步 算得分陣列fw fin */
				for(int k=1;k<=doc_of_i;k++) { //初始化
					fw[k]=0.0;
				}
				for(int k=1;k<=doc_of_i;k++) { //每個文件的得分
					for(int p=1;p<=weidu;p++) {
						fw[k]=fw[k]+weight[p]*feature[now_doc+k][p]; //算出這個文件的分數
					}
				}
				
				/*
				 * 第二步  算梯度delta_w向量
				 * a=Σp*x,a是向量  
				 * b=Σexpf(x),b是數字
				 * c=expf(x)*x,c是向量
				 * 最終結果delta_w是向量
				 */
				double[] a=new double[weidu+2],c=new double[weidu+2];
				for(int k=0;k<weidu+2;k++){a[k]=0.0;} //初始化
				for(int k=0;k<weidu+2;k++){c[k]=0.0;} //初始化
				double b=0.0;
				//算a:----
				for(int k=1; k<=doc_of_i; k++) {
					double p=1.0; //先不topK
					double[] temp=new double[48];
					for(int q=1;q<=weidu;q++) {
						//算P: ----第q個向量排XX的概率是多少
						//分母:
						double fenmu=0.0;
						for(int m=1;m<=doc_of_i;m++) {
							fenmu=fenmu+Math.exp(fw[m]); //所有文件得分
						}
						//top-1  exp(s1) / exp(s1)+exp(s2)+..+exp(sn)
						for(int m=1;m<=doc_of_i;m++) {
							p=p*(Math.exp(fw[m])/fenmu);
						}
						//算積
						temp[q]=temp[q]+p*feature[now_doc+k][q];
					}
					for(int q=1; q<=weidu; q++){			
						a[q]=a[q]+temp[q];
					}	
				} //End a
				//算b:---- fin.
				for(int k=1; k<=doc_of_i; k++){
					b=b+Math.exp(fw[k]);
				}
				//算c:----
				for(int k=1; k<=doc_of_i; k++){
					double[] temp=new double[weidu+2];
					for(int q=1; q<=weidu; q++){			
						temp[q]=temp[q]+Math.exp(fw[k])*feature[now_doc+k][q];
					}
					for(int q=1; q<=weidu; q++){			
						c[q]=c[q]+temp[q];
					}	
				}
				//算梯度:delta_x=-a+1/b*c
				for(int q=1; q<=weidu; q++){
					delta_w[q]= (-1)*a[q] + ((1.0/b)*c[q]);
				}
				//**********
				
				/* 第三步 更新權重 fin. */
				for(int k=1; k<=weidu; k++){
					weight[k]=weight[k]-yita*delta_w[k];
				}
				now_doc=now_doc+doc_of_i; //更新當前文件索引
			}
		} //End 迭代次數
		
		//輸出權重
		for(int i=1;i<=weidu;i++) //從1到136為權重,0和137無用
		{
			System.out.println(i+"wei:"+weight[i]);
		}
	}
	
	/**
	 * 輸出權重到檔案fileModel
	 * @param fileModel
	 */
	public static void WriteFileModel(String fileModel) {
		//輸出權重到檔案
		try {
			System.out.println("write start.總行數:"+sumLabel);
			FileWriter fileWriter = new FileWriter(fileModel);
			//寫資料
			fileWriter.write("## ListNet");
			fileWriter.write("\r\n");
			fileWriter.write("## Epochs = "+ITER_NUM);
			fileWriter.write("\r\n");
			fileWriter.write("## No. of features = 46");
			fileWriter.write("\r\n");
			fileWriter.write("1 2 3 4 5 6 7 8 9 10 ...  39 40 41 42 43 44 45 46");
			fileWriter.write("\r\n");
			fileWriter.write("0");
			fileWriter.write("\r\n");
			for(int k=0; k<weidu; k++){
				fileWriter.write("0 "+k+" "+weight[k+1]);
				fileWriter.write("\r\n");
			}
			fileWriter.close();
			System.out.println("write fin.");
		} catch(Exception e) {
			System.out.println("寫檔案內容出錯");
            e.printStackTrace();
		}
	}
	
	/**
	 * 預測排序
	 * 正規應對test.txt檔案進行打分排序
	 * 但我們是在Hadoop實現該打分排序步驟 此函式僅測試train.txt打分
	 */
	public static void PredictRank(String fileScore) {
		//輸出得分
		try {
			System.out.println("write start.總行數:"+sumLabel);
			String encoding = "GBK";
			FileWriter fileWriter = new FileWriter(fileScore);
			//寫資料
			for(int k=1; k<sumLabel; k++){
				double score=0.0;
				for(int j=1;j<=weidu;j++){
					score=score+weight[j]*feature[k][j];
				}
				fileWriter.write("qid:"+qid[k]+" score:"+score+" label:"+label[k]);
				fileWriter.write("\r\n");
			}	
			fileWriter.close();
			System.out.println("write fin.");	
		} catch(Exception e) {
			System.out.println("寫檔案內容出錯");
            e.printStackTrace();
		}
	}
	
	/**
	 * 主函式
	 */
	public static void main(String args[]) {
		String fileInput = "Fold1\\train.txt";       //訓練
		String fileModel = "model_weight.txt";       //輸出權重模型
		String fileScore = "score_listNet.txt";      //輸出得分
		//第1步 讀取檔案並解析資料
		System.out.println("read...");
		ReadTxtFile(fileInput);
		System.out.println("read and write well.");
		//第2步 排序計算
		LearningToRank();
		//第3步 輸出模型
		WriteFileModel(fileModel);
		//第4步 打分預測排序
		PredictRank(fileScore);
	  }
	
	/*
	 * End
	 */
	
}