1. 程式人生 > >Spark2.0機器學習系列之1:基於Pipeline、交叉驗證、ParamMap的模型選擇和超引數調優

Spark2.0機器學習系列之1:基於Pipeline、交叉驗證、ParamMap的模型選擇和超引數調優

Spark中的CrossValidation

  • Spark中採用是k折交叉驗證 (k-fold cross validation)。舉個例子,例如10折交叉驗證(10-fold cross validation),將資料集分成10份,輪流將其中9份做訓練1份做驗證,10次的結果的均值作為對演算法精度的估計。
  • 10折交叉檢驗最常見,是因為通過利用大量資料集、使用不同學習技術進行的大量試驗,表明10折是獲得最好誤差估計的恰當選擇,而且也有一些理論根據可以證明這一點。但這並非最終結論,爭議仍然存在。而且似乎5折或者20折與10折所得出的結果也相差無幾。
  • 交叉檢驗常用於分析模型的泛化能力,提高模型的穩定。相對於手工探索式的引數除錯,交叉驗證更具備統計學上的意義。
  • 在Spark中,Cross Validation和ParamMap(“引數組合”的Map)結合使用。具體做法是,針對某有特定的Param組合,CrossValidator計算K (K 折交叉驗證)個評估分數的平均值。然後和其它“引數組合”CrossValidator計算結果比較,完成所有的比較後,將最優的“引數組合”挑選出來,這“最優的一組引數”將用在整個訓練資料集上重新訓練(re-fit),得到最終的Model。
  • 也就是說,通過交叉驗證,找到了最佳的”引數組合“,利用這組引數,在整個訓練集上可以訓練(fit)出一個泛化能力強,誤差相對最小的的最佳模型。
  • 很顯然,交叉驗證計算代價很高,假設有三個引數:引數alpha有3中選擇,引數beta有4種選擇,引數gamma有4中選擇,進行10折計算,那麼將進行(3×4×4)×10=480次模型訓練。

Spark documnets 原文:
(1)CrossValidator begins by splitting the dataset into a set of folds which are used as separate training and test datasets. E.g., with k=3folds, CrossValidator will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. To evaluate a particular ParamMap, CrossValidator computes the average evaluation metric for the 3 Models produced by fitting the Estimator on the 3 different (training, test) dataset pairs.
(2)After identifying the best ParamMap, CrossValidator finally re-fits the Estimator using the best ParamMap and the entire dataset.
(3)Using CrossValidator to select from a grid of parameters.Note that cross-validation over a grid of parameters is expensive. E.g., in the example below, the parameter grid has 3 values for hashingTF.numFeatures and 2 values for lr.regParam, and CrossValidator uses 2 folds. This multiplies out to (3×2)×2=12different models being trained. In realistic settings, it can be common to try many more parameters and use more folds (k=3 and k=10 are common). In other words, using CrossValidator can be very expensive. However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning.

計算流程

//Spark Version 2.0
package my.spark.ml.practice;

import java.io.IOException;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

/**ALS演算法協同過濾推薦演算法
 * 使用Spark 2.0 基於Pipeline,ParamMap,CrossValidation
 * 對超引數進行調優,並進行模型選擇
 * @Peng Jiayong
 */

public class MyCrossValidation {
  public static void main(String[] args) throws IOException{
      SparkSession spark=SparkSession
              .builder()
              .appName("myCrossValidation")
              .master("local[4]")
              .getOrCreate();
    //遮蔽日誌
      Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
      Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF); 
    //載入資料
      JavaRDD<Rating> ratingsRDD = spark
              .read().textFile("/home/hadoop/spark/spark-2.0.0-bin-hadoop2.6" +
                    "/data/mllib/als/sample_movielens_ratings.txt").javaRDD()
              .map(new Function<String, Rating>() {
                  public Rating call(String str) {
                      return Rating.parseRating(str);
                  }
              });
      //將整個資料集劃分為訓練集和測試集
      //注意training集將用於Cross Validation,而test集將用於最終模型的評估
      //在traning集中,在Croos Validation時將進一步劃分為K份,每次留一份作為
      //Validation,注意區分:ratings.randomSplit()分出的Test集和K 折留
      //下驗證的那一份完全不是一個概念,也起著完全不同的作用,一定不要相混淆
      Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
      Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
      Dataset<Row> training = splits[0];
      Dataset<Row> test = splits[1];

      // Build the recommendation model using ALS on the training data
      ALS als=new ALS()
              .setMaxIter(8)
              .setRank(20).setRegParam(0.8)
              .setUserCol("userId")
              .setItemCol("movieId")
              .setRatingCol("rating")
              .setPredictionCol("predict_rating");
      /*
       * (1)秩Rank:模型中隱含因子的個數:低階近似矩陣中隱含特在個數,因子一般多一點比較好,
       * 但是會增大記憶體的開銷。因此常在訓練效果和系統開銷之間進行權衡,通常取值在10-200之間。
       * (2)最大迭代次數:執行時的迭代次數,ALS可以做到每次迭代都可以降低評級矩陣的重建誤差,
       * 一般少數次迭代便能收斂到一個比較合理的好模型。
       * 大部分情況下沒有必要進行太對多次迭代(10次左右一般就挺好了)
       * (3)正則化引數regParam:和其他機器學習演算法一樣,控制模型的過擬合情況。
       * 該值與資料大小,特徵,係數程度有關。此引數正是交叉驗證需要驗證的引數之一。
       */
      // Configure an ML pipeline, which consists of one stage
      //一般會包含多個stages
      Pipeline pipeline=new Pipeline().
              setStages(new PipelineStage[] {als});
      // We use a ParamGridBuilder to construct a grid of parameters to search over.
      ParamMap[] paramGrid=new ParamGridBuilder()
      .addGrid(als.rank(),new int[]{5,10,20})
      .addGrid(als.regParam(),new double[]{0.05,0.10,0.15,0.20,0.40,0.80})
      .build();

      // CrossValidator 需要一個Estimator,一組Estimator ParamMaps, 和一個Evaluator.
      // (1)Pipeline作為Estimator;
      // (2)定義一個RegressionEvaluator作為Evaluator,並將評估標準設定為“rmse”均方根誤差
      // (3)設定ParamMap
      // (4)設定numFolds    

      CrossValidator cv=new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new RegressionEvaluator()
              .setLabelCol("rating")
              .setPredictionCol("predict_rating")
              .setMetricName("rmse"))
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(5);

      // 執行交叉檢驗,自動選擇最佳的引數組合
      CrossValidatorModel cvModel=cv.fit(training);
      //儲存模型
      cvModel.save("/home/hadoop/spark/cvModel_als.modle");

      //System.out.println("numFolds: "+cvModel.getNumFolds());
      //Test資料集上結果評估  
      Dataset<Row> predictions=cvModel.transform(test);
      RegressionEvaluator evaluator = new RegressionEvaluator()
      .setMetricName("rmse")//RMS Error
      .setLabelCol("rating")
      .setPredictionCol("predict_rating");
      Double rmse = evaluator.evaluate(predictions);
      System.out.println("RMSE @ test dataset " + rmse);
      //Output: RMSE @ test dataset 0.943644792277118
  }   
}
備註:程式執行需要定義Rating Class
在下面連結裡可以找到:
http://spark.apache.org/docs/latest/ml-collaborative-filtering.html