Spark 官方示例:兩種方法實現隨機森林模型(ML/MLlib)
阿新 • 發佈:2019-01-30
在 Spark 2.0 以上版本中,機器學習演算法有兩套實現庫:MLlib 與 ML。以隨機森林為例,分別是 org.apache.spark.mllib.tree.RandomForest
和 org.apache.spark.ml.classification.RandomForestClassifier(訓練後得到 RandomForestClassificationModel)。
兩種庫的使用方法也不同:MLlib 是基於 RDD 的 API(RDD-based API),自 Spark 2.0 起已進入維護模式;
ML 則是基於 DataFrame 資料結構與 ML Pipeline 的 API,是目前的主要 API。
詳見 https://spark.apache.org/docs/latest/ml-guide.html
因此兩者的官方示例也有很大區別,下面分別給出原始碼和註釋:
MLlib的模型實現
// scalastyle:off println
package org.apache.spark.examples.mllib

import org.apache.spark.{SparkConf, SparkContext}
// $example on$
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
// $example off$

/**
 * Random-forest classification with the RDD-based `spark.mllib` API.
 *
 * Loads `data/mllib/sample_libsvm_data.txt` (LIBSVM format), trains a forest on a
 * random ~70% of it, prints the error on the remaining ~30% and the learned trees,
 * then saves the model to `target/tmp/myRandomForestClassificationModel` and loads
 * it back.
 */
object RandomForestClassificationExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("RandomForestClassificationExample")
    val sc = new SparkContext(conf)

    // Always stop the SparkContext, even when a step below throws, so its
    // resources are released (the pasted version never stopped it at all).
    try {
      // $example on$
      // Load and parse the data file.
      val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
      // Split the data into training and test sets (30% held out for testing).
      // No seed is passed, so the split (and the reported error) varies between runs.
      val splits = data.randomSplit(Array(0.7, 0.3))
      val (trainingData, testData) = (splits(0), splits(1))

      // Train a RandomForest model.
      // Empty categoricalFeaturesInfo indicates all features are continuous.
      val numClasses = 2
      val categoricalFeaturesInfo = Map[Int, Int]()
      val numTrees = 3 // Use more in practice.
      val featureSubsetStrategy = "auto" // Let the algorithm choose.
      val impurity = "gini"
      val maxDepth = 4
      val maxBins = 32

      val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
        numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

      // Evaluate model on test instances and compute test error.
      val labelAndPreds = testData.map { point =>
        val prediction = model.predict(point.features)
        (point.label, prediction)
      }
      val testErr =
        labelAndPreds.filter { case (label, prediction) => label != prediction }.count().toDouble /
          testData.count()
      println(s"Test Error = $testErr")
      println(s"Learned classification forest model:\n${model.toDebugString}")

      // Save and load model.
      // NOTE(review): saving presumably fails if this directory is left over from a
      // previous run - delete it before re-running.
      model.save(sc, "target/tmp/myRandomForestClassificationModel")
      val sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel")
      // $example off$
    } finally {
      sc.stop()
    }
  }
}
// scalastyle:on println
ML的模型實現
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
// $example off$
import org.apache.spark.sql.SparkSession

/**
 * Random-forest classification with the DataFrame-based `spark.ml` Pipeline API.
 *
 * Loads `data/mllib/sample_libsvm_data.txt` as a DataFrame, indexes labels and
 * categorical features, trains a 10-tree forest on a random ~70% of the rows, shows
 * sample predictions, and prints the test error on the remaining ~30% plus the
 * learned trees.
 */
object RandomForestClassifierExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("RandomForestClassifierExample")
      .getOrCreate()

    // Always stop the session, even when a step below throws.
    try {
      // $example on$
      // Load and parse the data file, converting it to a DataFrame.
      val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

      // Index labels, adding metadata to the label column.
      // Fit on whole dataset to include all labels in index.
      val labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .fit(data)
      // Automatically identify categorical features, and index them.
      // Set maxCategories so features with > 4 distinct values are treated as continuous.
      val featureIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .setMaxCategories(4)
        .fit(data)

      // Split the data into training and test sets (30% held out for testing).
      val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

      // Train a RandomForest model.
      val rf = new RandomForestClassifier()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("indexedFeatures")
        .setNumTrees(10)

      // Convert indexed labels back to original labels.
      // NOTE(review): `labels` is deprecated in Spark 3.x in favour of `labelsArray(0)`.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)

      // Chain indexers and forest in a Pipeline.
      val pipeline = new Pipeline()
        .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

      // Train model. This also runs the indexers.
      val model = pipeline.fit(trainingData)

      // Make predictions.
      val predictions = model.transform(testData)

      // Select example rows to display.
      predictions.select("predictedLabel", "label", "features").show(5)

      // Select (prediction, true label) and compute test error.
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val accuracy = evaluator.evaluate(predictions)
      println(s"Test Error = ${1.0 - accuracy}")

      // Find the fitted forest by type rather than by its index in the stage list,
      // so reordering the pipeline stages does not break this lookup.
      val rfModel = model.stages
        .collectFirst { case m: RandomForestClassificationModel => m }
        .getOrElse(throw new IllegalStateException(
          "Pipeline model has no RandomForestClassificationModel stage"))
      println(s"Learned classification forest model:\n${rfModel.toDebugString}")
      // $example off$
    } finally {
      spark.stop()
    }
  }
}
// scalastyle:on println
TIPS:
想看 https://spark.apache.org/docs 中示例程式碼的完整版本嗎?一種方法是到 GitHub 上的 apache/spark 專案查找,另一種方法是進入 Spark 的安裝目錄,所有示例原始碼都在 spark/examples/src/main/scala/ 目錄下,例如:
如ML的演算法scala實現:spark/examples/src/main/scala/org/apache/spark/examples/ml
MLlib的演算法scala實現:spark/examples/src/main/scala/org/apache/spark/examples/mllib