1. 程式人生 > >Spark MLlib中ALS交替最小二乘法推薦演算法的使用

Spark MLlib中ALS交替最小二乘法推薦演算法的使用

本文首發於我的個人部落格QIMING.INFO,轉載請帶上鍊接及署名。

ALS(Alternating Least Square),交替最小二乘法。在機器學習中,特指使用最小二乘法的一種協同推薦演算法。本文通過程式碼來演示用spark執行ALS演算法的一個小例子。

演算法簡介

ALS演算法通過觀察到的所有使用者給商品的打分,來推斷每個使用者的喜好並向用戶推薦適合的商品。

其原理簡單說就是假設使用者評分矩陣是使用者特徵矩陣乘以物品特徵矩陣得到的,即:A(m*n)=U(m*k)*V(k*n),然後得到一個評分矩陣。具體原理請自行查閱,本文主要為使用。

通常,呼叫ALS演算法進行訓練時有4個重要引數,分別是ratings

rankiterations,和lambda

  • ratings指使用者提供的訓練資料,它包括使用者id集、商品id集以及相應的打分集;
  • rank表示隱含因素的數量,即特徵的數量,也就是分解矩陣的k值。
  • iterations表示最大迭代次數;
  • lambda表示正則因子,可省略,預設為0.01。

執行步驟

資料說明

資料格式為:使用者id,物品id,評分

[xuqm@cu01 ML_Data]$ cat input/test.data 
1,1,5.0
1,2,1.0
1,3,5.0
1,4,1.0
2,1,5.0
2,2,1.0
2,3,5.0
2,4,1.0
3,1,1.0
3,2,5.0
3
,3,1.0 3,4,5.0 4,1,1.0 4,2,5.0 4,3,1.0 4,4,5.0

程式碼及說明


package nwpuhpc.antirisk.ml

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.{SparkConf, SparkContext}

object ALSTest {

  // 構建Spark物件
  val conf = new SparkConf().setAppName("ALSTest"
) val sc = new SparkContext(conf) Logger.getRootLogger.setLevel(Level.WARN) // 讀取樣本資料 val data = sc.textFile("/home/xuqm/ML_Data/input/test.data") val ratings = data.map(_.split(',') match { case Array(user, item, rate) => Rating(user.toInt, item.toInt, rate.toDouble) }) // 拆分成訓練集和測試集 val dataParts = ratings.randomSplit(Array(0.8, 0.2)) val trainingRDD = dataParts(0).cache() val testRDD = dataParts(1) // 建立ALS交替最小二乘演算法模型並訓練 val rank = 10 val numIterations = 20 val model = ALS.train(trainingRDD, rank, numIterations, 0.01) // 取出測試集中的使用者id和商品id val usersProducts = testRDD.map { case Rating(user, product, rate) => (user, product) } // 用訓練好的模型預測測試集的結果 val predictions = model.predict(usersProducts).map { case Rating(user, product, rate) => ((user, product), rate) } val ratesAndPreds = testRDD.map { case Rating(user, product, rate) => ((user, product), rate) }.join(predictions) // 輸出誤差 val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => val err = (r1 - r2) err * err }.mean() println("Mean Squared Error = " + MSE) // 列印輸出預測值 println("User" + "\t" + "Products" + "\t" + "Rate" + "\t" + "Prediction") ratesAndPreds.collect.foreach( rating => { println(rating._1._1 + "\t" + rating._1._2 + "\t" + rating._2._1 + "\t" + rating._2._2) } ) }

結果展示

可以看出,誤差不是很大。