【spark】採用MultilayerPerceptron對MNIST的0-9數字進行識別
阿新 • • 發佈:2019-01-08
神經網路介紹:
http://ufldl.stanford.edu/tutorial/supervised/MultiLayerNeuralNetworks/
由於只採用一種網路層結構 (28 * 28, 100, 50, 10) 進行訓練,效果不是很好
package com.bbw5.ml.spark

import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.TrainValidationSplit
import com.bbw5.ml.spark.data.MNISTData
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.SparkConf
import org.apache.spark.ml.param.IntArrayParam
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel

/**
 * Trains a multilayer perceptron classifier on a sample of the MNIST digit
 * data with Spark ML, selects the model via cross-validation, and prints the
 * evaluation result on a held-out test sample.
 */
object MultilayerPerceptron4MNIST {

  /**
   * Application entry point: creates the Spark contexts, runs [[tvSplit]],
   * and always stops the SparkContext afterwards.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("MultilayerPerceptron4MNIST")
    val sc = new SparkContext(sparkConf)
    try {
      val sqlContext = new org.apache.spark.sql.SQLContext(sc)
      tvSplit(sc, sqlContext)
    } finally {
      // Release driver/executor resources even when training fails.
      sc.stop()
    }
  }

  /**
   * Fits a MultilayerPerceptronClassifier with a 3-fold CrossValidator over a
   * single-point parameter grid, saves the fitted model as an object file,
   * and prints predictions plus the "precision" metric for the test sample.
   *
   * Note: despite the method name, this uses CrossValidator, not
   * TrainValidationSplit.
   *
   * @param sc         SparkContext used to parallelize the loaded data and to
   *                   write the model object file
   * @param sqlContext SQLContext whose implicits provide the `toDF` conversion
   */
  def tvSplit(sc: SparkContext, sqlContext: SQLContext): Unit = {
    val dataDir = "I:/DM-dataset/MNIST/"
    import sqlContext.implicits._

    // Use only part of the data (the first piece of a 0.1 / 0.9 random split);
    // otherwise there is not enough memory.
    // NOTE(review): the meaning of the second argument (10) to
    // MNISTData.loadTrainData / loadTestData is not visible here -- confirm.
    val training = sc.parallelize(MNISTData.loadTrainData(dataDir, 10).toSeq, 4)
      .randomSplit(Array(0.1, 0.9), seed = 1234L)(0)
      .toDF("label", "features")
      .cache()
    val test = sc.parallelize(MNISTData.loadTestData(dataDir, 10).toSeq, 4)
      .randomSplit(Array(0.1, 0.9), seed = 1234L)(0)
      .toDF("label", "features")
      .cache()

    val mpc = new MultilayerPerceptronClassifier()

    // The grid holds exactly one combination: layers (28 * 28, 100, 50, 10),
    // block size 128, seed 1234 and at most 100 iterations.
    val paramGrid = new ParamGridBuilder()
      .addGrid[Array[Int]](mpc.layers, ArrayBuffer(Array[Int](28 * 28, 100, 50, 10)))
      .addGrid(mpc.blockSize, Array(128))
      .addGrid(mpc.seed, Array(1234L))
      .addGrid(mpc.maxIter, Array(100))
      .build()

    val cv = new CrossValidator()
      .setEstimator(mpc)
      .setEvaluator(new MulticlassClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)

    // Run 3-fold cross-validation and choose the best set of parameters.
    val model = cv.fit(training)

    // Persist the fitted model by serializing it into a single-element RDD.
    sc.parallelize(Seq(model))
      .saveAsObjectFile(s"D:/Develop/Model/MNIST-MPC-${System.currentTimeMillis()}")

    val bestModel = model.bestModel.asInstanceOf[MultilayerPerceptronClassificationModel]
    println("model param:\n" + bestModel.extractParamMap)

    // Make predictions on test data. model is the model with the combination
    // of parameters that performed best.
    val testDF = model.transform(test)
    testDF.select("label", "prediction").show()
    testDF.groupBy("label", "prediction").count().show()

    val predictionAndLabels = testDF.select("prediction", "label")
    val evaluator = new MulticlassClassificationEvaluator().setMetricName("precision")
    // The original author noted 0.231 here -- presumably the value observed on their run.
    println("Precision:" + evaluator.evaluate(predictionAndLabels))
  }
}