1. 程式人生 > >Spark2.0 協同過濾推薦

Spark2.0 協同過濾推薦

所有 隱藏 level n-1 模型 研究 win imp jet

ALS矩陣分解

http://blog.csdn.net/oucpowerman/article/details/49847979
http://www.open-open.com/lib/view/open1457672855046.html
一個的打分矩陣 A 可以用兩個小矩陣和的乘積來近似,描述一個人的喜好經常是在一個抽象的低維空間上進行的,並不需要把其喜歡的事物一一列出。再抽象一些,把人們的喜好和電影的特征都投到這個低維空間,一個人的喜好映射到了一個低維向量,一個電影的特征變成了緯度相同的向量,那麽這個人和這個電影的相似度就可以表述成這兩個向量之間的內積。
我們把打分理解成相似度,那麽“打分矩陣A(m?n)

”就可以由“用戶喜好特征矩陣U(m?k)”和“產品特征矩陣V(n?k)”的乘積。
矩陣分解過程中所用的優化方法分為兩種:交叉最小二乘法(alternative least squares)和隨機梯度下降法(stochastic gradient descent)。

參數選取

    • 分塊數:分塊是為了並行計算,默認為10。
    • 正則化參數:默認為1。
    • 秩:模型中隱藏因子的個數
    • 顯示偏好信息-false,隱式偏好信息-true,默認false(顯示)
    • alpha:只用於隱式的偏好數據,偏好值可信度底線。
    • 非負限定
    • numBlocks is the number of blocks the users and items will be
      partitioned into in order to parallelize computation (defaults to
      10).
    • rank is the number of latent factors in the model (defaults to 10).
    • maxIter is the maximum number of iterations to run (defaults to 10).
    • regParam specifies the regularization parameter in ALS (defaults to 1.0).
    • implicitPrefs specifies whether to use the explicit feedback ALS variant or one adapted for implicit feedback data (defaults to false
      which means using explicit feedback).
    • alpha is a parameter applicable to the implicit feedback variant of ALS that governs the baseline confidence in preference
      observations (defaults to 1.0).
    • nonnegative specifies whether or not to use nonnegative constraints for least squares (defaults to false).
 ALS als = new ALS()
          .setMaxIter(10)//最大叠代次數,設置太大發生java.lang.StackOverflowError
          .setRegParam(0.16)//正則化參數
          .setAlpha(1.0)
          .setImplicitPrefs(false)
          .setNonnegative(false)
          .setNumBlocks(10)
          .setRank(10)
          .setUserCol("userId")
          .setItemCol("movieId")
          .setRatingCol("rating");

需要註意的問題:
對於用戶和物品項ID ,基於DataFrame API 只支持integers,因此最大值限定在integers範圍內。

The DataFrame-based API for ALS currently only supports integers for 
user and item ids. Other numeric types are supported for the user and 
item id columns, but the ids must be within the integer value range.
//循環正則化參數,每次由Evaluator給出RMSError
      List<Double> RMSE=new ArrayList<Double>();//構建一個List保存所有的RMSE
      for(int i=0;i<20;i++){//進行20次循環
          double lambda=(i*5+1)*0.01;//RegParam按照0.05增加
          ALS als = new ALS()
          .setMaxIter(5)//最大叠代次數
          .setRegParam(lambda)//正則化參數
          .setUserCol("userId")
          .setItemCol("movieId")
          .setRatingCol("rating");
          ALSModel model = als.fit(training);         
          // Evaluate the model by computing the RMSE on the test data
          Dataset<Row> predictions = model.transform(test);
          //RegressionEvaluator.setMetricName可以定義四種評估器
          //"rmse" (default): root mean squared error
          //"mse": mean squared error
          //"r2": R^2^ metric 
          //"mae": mean absolute error        
          RegressionEvaluator evaluator = new RegressionEvaluator()
          .setMetricName("rmse")//RMS Error
          .setLabelCol("rating")
          .setPredictionCol("prediction");
          Double rmse = evaluator.evaluate(predictions);
          RMSE.add(rmse);
          System.out.println("RegParam "+0.01*i+" RMSE " + rmse+"\n");        
      } 
      //輸出所有結果
      for (int j = 0; j < RMSE.size(); j++) {
          Double lambda=(j*5+1)*0.01;
          System.out.println("RegParam= "+lambda+"  RMSE= " + RMSE.get(j)+"\n");    
    }
通過設計一個循環,可以研究最合適的參數,部分結果如下:
RegParam= 0.01  RMSE= 1.956
RegParam= 0.06  RMSE= 1.166
RegParam= 0.11  RMSE= 0.977
RegParam= 0.16  RMSE= 0.962//具備最小的RMSE,參數最合適
RegParam= 0.21  RMSE= 0.985
RegParam= 0.26  RMSE= 1.021
RegParam= 0.31  RMSE= 1.061
RegParam= 0.36  RMSE= 1.102
RegParam= 0.41  RMSE= 1.144
RegParam= 0.51  RMSE= 1.228
RegParam= 0.56  RMSE= 1.267
RegParam= 0.61  RMSE= 1.300
//將RegParam固定在0.16,繼續研究叠代次數的影響
輸出如下的結果,在單機環境中,叠代次數設置過大,會出現一個java.lang.StackOverflowError異常。是由於當前線程的棧滿了引起的。
numMaxIteration= 1  RMSE= 1.7325
numMaxIteration= 4  RMSE= 1.0695
numMaxIteration= 7  RMSE= 1.0563
numMaxIteration= 10  RMSE= 1.055
numMaxIteration= 13  RMSE= 1.053
numMaxIteration= 16  RMSE= 1.053
//測試Rank隱含語義個數
Rank =1  RMSErr = 1.1584
Rank =3  RMSErr = 1.1067
Rank =5  RMSErr = 0.9366
Rank =7  RMSErr = 0.9745
Rank =9  RMSErr = 0.9440
Rank =11  RMSErr = 0.9458
Rank =13  RMSErr = 0.9466
Rank =15  RMSErr = 0.9443
Rank =17  RMSErr = 0.9543

//可以用SPARK-SQL自己定義評估算法(如下面定義了一個平均絕對值誤差計算過程)
// Register the DataFrame as a SQL temporary view
predictions.createOrReplaceTempView("tmp_predictions");                                     
Dataset<Row> absDiff=spark.sql("select abs(prediction-rating) as diff from tmp_predictions");                   
absDiff.createOrReplaceTempView("tmp_absDiff");
spark.sql("select mean(diff) as absMeanDiff from tmp_absDiff").show();      

完整代碼

可以在 http://spark.apache.org/docs/latest/ml-collaborative-filtering.html找到

package my.spark.ml.practice.classification;

import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class myCollabFilter2 {  

    public static void main(String[] args) {
        SparkSession spark=SparkSession
                .builder()
                .appName("CoFilter")
                .master("local[4]")
                .config("spark.sql.warehouse.dir","file///:G:/Projects/Java/Spark/spark-warehouse" )
                .getOrCreate();

        String path="G:/Projects/CgyWin64/home/pengjy3/softwate/spark-2.0.0-bin-hadoop2.6/"
                + "data/mllib/als/sample_movielens_ratings.txt";

        //屏蔽日誌
                Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
                Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);   
        //-------------------------------1.0 準備DataFrame----------------------------
        //..javaRDD()函數將DataFrame轉換為RDD
        //然後對RDD進行Map 每一行String->Rating
        JavaRDD<Rating> ratingRDD=spark.read().textFile(path).javaRDD()
                .map(new Function<String, Rating>() {

                    @Override
                    public Rating call(String str) throws Exception {                       
                        return Rating.parseRating(str);
                    }
                });
        //System.out.println(ratingRDD.take(10).get(0).getMovieId());

        //由JavaRDD(每一行都是一個實例化的Rating對象)和Rating Class創建DataFrame
        Dataset<Row> ratings=spark.createDataFrame(ratingRDD, Rating.class);
        //ratings.show(30);

        //將數據隨機分為訓練集和測試集
        double[] weights=new double[] {0.8,0.2};
        long seed=1234;
        Dataset<Row> [] split=ratings.randomSplit(weights, seed);
        Dataset<Row> training=split[0];
        Dataset<Row> test=split[1];         

        //------------------------------2.0 ALS算法和訓練數據集,產生推薦模型-------------
        for(int rank=1;rank<20;rank++)
        {
            //定義算法
            ALS als=new ALS()
                    .setMaxIter(5)////最大叠代次數,設置太大發生java.lang.StackOverflowError
                    .setRegParam(0.16)              
                    .setUserCol("userId")               
                    .setRank(rank)
                    .setItemCol("movieId")
                    .setRatingCol("rating");
            //訓練模型
            ALSModel model=als.fit(training);
            //---------------------------3.0 模型評估:計算RMSE,均方根誤差---------------------
            Dataset<Row> predictions=model.transform(test);
            //predictions.show();
            RegressionEvaluator evaluator=new RegressionEvaluator()
                    .setMetricName("rmse")
                    .setLabelCol("rating")
                    .setPredictionCol("prediction");
            Double rmse=evaluator.evaluate(predictions);
            System.out.println("Rank =" + rank+"  RMSErr = " + rmse);               
        }       
    }
}

Spark2.0 協同過濾推薦