A Beginner's Guide to Processing Larger-than-Memory Data with Spark ML
This is another translation of a KDnuggets article, in which Microsoft's Dmitry Petrov shows how to use Spark ML to process data larger than a machine's memory. (Link to the original.)
The emphasis here is on datasets that far exceed the memory of a single machine. Traditionally such analysis required a distributed system (e.g. Hadoop); this article shows how to do it on a single machine with Spark. I also did quite a bit of migration work myself, so this post counts as at least partly original.
The goal of the example is to build a model that predicts a post's tags from its title and body. To keep the code simple, the article concatenates these two fields into a single text column instead of handling them separately. (Translator's note: text in the title should obviously carry more weight for predicting tags, so in real work the two columns should be treated separately.)
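The combination is plain string concatenation plus some cleanup: strip HTML tags from the body, join it with the title, and normalize whitespace. A minimal standalone sketch of that step — `mergePost` is a hypothetical helper name, but the crude `<\S+>` tag-stripping regex is the same one the listing below uses:

```scala
// Hypothetical helper: merge a post's title and HTML body into one plain-text column
def mergePost(title: String, body: String): String = {
  val bodyPlain = "<\\S+>".r.replaceAllIn(body, " ") // crude HTML tag stripper
  (title + " " + bodyPlain)
    .replaceAll("\n", " ")   // drop line breaks
    .replaceAll("( )+", " ") // collapse repeated spaces
    .trim
}

println(mergePost("A title", "<p>Some body\ntext</p>")) // → "A title Some body text"
```

Note that such a regex is only a rough approximation of HTML stripping; it can also swallow inline text squeezed between adjacent tags, which is acceptable for this tutorial but not for production cleaning.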
It is easy to see the value of such a model for a site like stackoverflow.com: a user types a question and the site automatically suggests tags. Assume we want to surface as many of the correct tags as possible and let the user delete the irrelevant ones. Under that assumption, recall becomes the most important metric for judging the model.
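As a reminder of why recall fits this "suggest generously, let the user prune" setup: recall is the fraction of truly relevant items (here, positive posts) that the model actually catches. A minimal sketch over (prediction, label) pairs, assuming 1.0 marks the positive class as in the code below:

```scala
// recall = true positives / (true positives + false negatives)
def recall(scored: Seq[(Double, Double)]): Double = {
  val truePos  = scored.count { case (pred, label) => label == 1.0 && pred == 1.0 }
  val falseNeg = scored.count { case (pred, label) => label == 1.0 && pred == 0.0 }
  truePos.toDouble / (truePos + falseNeg)
}

// four positive posts, three of which the model caught
val scored = Seq((1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0))
println(recall(scored)) // 0.75
```

The listing below reports area under the ROC curve instead, which summarizes the recall/false-positive trade-off across all thresholds.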
First we need such a dataset. The article uses the Stack Overflow posts.xml dump hosted on archive.org (the link is omitted here). The author also provides a small sample file for practice: https://www.dropbox.com/s/n2skgloqoadpa30/Posts.small.xml?dl=0. Note that neither site is reachable from mainland China, so I have uploaded the small file to a cloud drive: http://pan.baidu.com/s/1jGJFtQI
Next, set up the Spark environment. The original example runs on Spark 1.5.1; I actually used 1.5.2 with Hadoop 2.6. For installation steps, see my other post: "Hadoop cluster setup scripts and design notes (N): quickly building a simplified Hadoop + Spark on YARN cluster".
The original consists of a series of Scala commands run in spark-shell. Out of a software engineer's habit, I moved them into an Eclipse Scala project, and hit a number of pitfalls along the way; every place that needed a change is explained in a comment in the code below.
I then exported the project as a jar and ran it with spark-submit. That surfaced one more strange problem, which is now solved; see my other posts.
import scala.xml._ // needed for XML.loadString below; must NOT stay commented out
// Spark data manipulation libraries
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import org.apache.spark._
// Spark machine learning libraries
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.ml.Pipeline

object postsClassifier {
  def main(args: Array[String]) {
    // In an IDE project we must create SparkConf/SparkContext/SQLContext ourselves;
    // spark-shell provides sc and sqlContext automatically.
    val conf = new SparkConf().setAppName("BinaryClassifier")
    val sc = new SparkContext(conf)
    val fileName = "hdfs://Master1:9000/xml/Posts.xml"
    val textFile = sc.textFile(fileName)
    // Drop the XML prolog and the <posts> wrapper so each remaining line is one <row .../> element
    val postsXml = textFile.map(_.trim)
      .filter(!_.startsWith("<?xml version="))
      .filter(_ != "<posts>")
      .filter(_ != "</posts>")
    val postsRDD = postsXml.map { s =>
      val xml = XML.loadString(s)
      val id = (xml \ "@Id").text
      val tags = (xml \ "@Tags").text
      val title = (xml \ "@Title").text
      val body = (xml \ "@Body").text
      // Strip HTML tags from the body, then merge title and body into one text column
      val bodyPlain = ("<\\S+>".r).replaceAllIn(body, " ")
      val text = (title + " " + bodyPlain).replaceAll("\n", " ").replaceAll("( )+", " ")
      Row(id, tags, text)
    }
    val schemaString = "Id Tags Text"
    val schema = StructType(
      schemaString.split(" ").map(fieldName =>
        StructField(fieldName, StringType, true)))
    val sqlContext = new SQLContext(sc) // created manually; the shell's sqlContext does not exist here
    val postsDf = sqlContext.createDataFrame(postsRDD, schema)
    postsDf.show()
    // Label = 1.0 if the post's Tags string contains the target tag
    val targetTag = "java"
    val myudf: (String => Double) = (str: String) =>
      if (str.contains(targetTag)) 1.0 else 0.0
    val sqlfunc = udf(myudf)
    val postsLabeled = postsDf.withColumn("Label", sqlfunc(col("Tags")))
    // DataFrame.filter accepts a SQL expression string; the whole expression must be
    // a double-quoted Scala string, not the shell-style 'Label > 0.0 of the original
    val positive = postsLabeled.filter("Label > 0.0")
    val negative = postsLabeled.filter("Label < 1.0")
    // 90% random sample of each class for training
    val positiveTrain = positive.sample(false, 0.9)
    val negativeTrain = negative.sample(false, 0.9)
    val training = positiveTrain.unionAll(negativeTrain)
    // The test set is everything not sampled into training: left-outer-join on Id and keep
    // rows with no match. Note the triple === (Column equality), not ==, and the
    // DataFrame-qualified column names in select() to avoid ambiguity after the join.
    val negativeTrainTmp = negativeTrain
      .withColumnRenamed("Label", "Flag").select("Id", "Flag")
    val negativeTest = negative.join(negativeTrainTmp, negative("Id") === negativeTrainTmp("Id"), "leftouter")
      .filter("Flag is null")
      .select(negative("Id"), negative("Tags"), negative("Text"), negative("Label"))
    val positiveTrainTmp = positiveTrain
      .withColumnRenamed("Label", "Flag")
      .select("Id", "Flag")
    val positiveTest = positive.join(positiveTrainTmp, positive("Id") === positiveTrainTmp("Id"), "leftouter")
      .filter("Flag is null")
      .select(positive("Id"), positive("Tags"), positive("Text"), positive("Label"))
    val testing = negativeTest.unionAll(positiveTest)
    // Pipeline: tokenize -> hash words into a fixed-size feature vector -> logistic regression
    val numFeatures = 64000
    val numEpochs = 30
    val regParam = 0.02
    val tokenizer = new Tokenizer().setInputCol("Text")
      .setOutputCol("Words")
    val hashingTF = new HashingTF()
      .setNumFeatures(numFeatures)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("Features")
    val lr = new LogisticRegression().setMaxIter(numEpochs)
      .setRegParam(regParam).setFeaturesCol("Features")
      .setLabelCol("Label").setRawPredictionCol("Score")
      .setPredictionCol("Prediction")
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))
    val model = pipeline.fit(training)
    // Try the model on a single hand-written post
    val testTitle =
      "Easiest way to merge a release into one JAR file"
    val testText = testTitle + """Is there a tool or script which easily merges a bunch of href="http:/en.wikipedia.org/wiki/JAR_%28file_format %29" JAR files into one JAR file? A bonus would be to easily set the main-file manifest and make it executable. I would like to run it with something like: As far as I can tell, it has no dependencies which indicates that it shouldn't be an easy single-file tool, but the downloaded ZIP file contains a lot of libraries."""
    val testDF = sqlContext
      .createDataFrame(Seq((99.0, testText)))
      .toDF("Label", "Text")
    val result = model.transform(testDF)
    // Column 6 is "Prediction": Label, Text, Words, Features, Score, Probability, Prediction
    val prediction = result.collect()(0)(6)
      .asInstanceOf[Double]
    println("Prediction: " + prediction)
    // Evaluate on the held-out test set
    val testingResult = model.transform(testing)
    val testingResultScores = testingResult
      .select("Prediction", "Label").rdd
      .map(r => (r(0).asInstanceOf[Double], r(1).asInstanceOf[Double]))
    val bc =
      new BinaryClassificationMetrics(testingResultScores)
    val roc = bc.areaUnderROC
    println("Area under the ROC: " + roc)
  }
}
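The one non-obvious stage in the pipeline above is HashingTF: instead of building a vocabulary, it hashes each token into one of numFeatures buckets and counts hits, so the feature space stays fixed-size no matter how large the corpus grows. A rough standalone sketch of the idea — `hashTF` is a hypothetical name, and plain `hashCode` stands in for Spark's actual hash function:

```scala
// Sketch of feature hashing: map each word to a bucket index and count occurrences.
def hashTF(words: Seq[String], numFeatures: Int): Map[Int, Int] = {
  words
    .map(w => ((w.hashCode % numFeatures) + numFeatures) % numFeatures) // non-negative bucket index
    .groupBy(identity)
    .mapValues(_.size)
    .toMap
}

val counts = hashTF("merge a jar file into one jar file".split(" ").toSeq, 64000)
// "jar" and "file" should each land in one bucket with count 2 (assuming no collision)
println(counts)
```

The price of this trick is that unrelated words can collide into the same bucket; with 64000 buckets and a tutorial-sized vocabulary, that cost is usually negligible.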