1. 程式人生 > >【spark】採用MultilayerPerceptron對MNIST的0-9數字進行識別

【spark】採用MultilayerPerceptron對MNIST的0-9數字進行識別

神經網路介紹:

http://ufldl.stanford.edu/tutorial/supervised/MultiLayerNeuralNetworks/

說明:由於只採用一種網路層結構 (28 * 28, 100, 50, 10) 進行訓練,效果不是很好(程式碼註解中記錄的測試 precision 約為 0.231)。

package com.bbw5.ml.spark

import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.TrainValidationSplit
import com.bbw5.ml.spark.data.MNISTData
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.SparkConf
import org.apache.spark.ml.param.IntArrayParam
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel

/**
 * Trains a multilayer perceptron on a sample of the MNIST digits (0-9) with
 * Spark ML, selects the model through 3-fold cross-validation over a
 * single-point parameter grid, and reports the precision on a test sample.
 */
object MultilayerPerceptron4MNIST {

  /** Default directory the MNIST files are loaded from. */
  private val DefaultDataDir = "I:/DM-dataset/MNIST/"

  /** Default directory prefix under which the fitted model is saved. */
  private val DefaultModelDir = "D:/Develop/Model/"

  /**
   * Application entry point: creates the Spark contexts, runs [[tvSplit]] and
   * always stops the SparkContext afterwards, even when training fails.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("MultilayerPerceptron4MNIST")
    val sc = new SparkContext(sparkConf)
    try {
      val sqlContext = new SQLContext(sc)
      tvSplit(sc, sqlContext)
    } finally {
      // Release the context's resources; the original never stopped it.
      sc.stop()
    }
  }

  /**
   * Fits a [[MultilayerPerceptronClassifier]] with cross-validation, saves the
   * fitted model as an object file and prints predictions and precision for
   * the test sample.
   *
   * @param sc         active Spark context used to parallelize the data and save the model
   * @param sqlContext SQL context providing the implicits for the DataFrame conversion
   * @param dataDir    directory passed to `MNISTData.loadTrainData` / `loadTestData`
   * @param modelDir   path prefix (including trailing separator) for the saved model;
   *                   "MNIST-MPC-" plus the current timestamp is appended
   */
  def tvSplit(
      sc: SparkContext,
      sqlContext: SQLContext,
      dataDir: String = DefaultDataDir,
      modelDir: String = DefaultModelDir): Unit = {
    import sqlContext.implicits._

    // Only use part of the data, otherwise there is not enough memory:
    // keep the 10% side of a seeded 10/90 random split.
    // NOTE(review): the meaning of the second argument (10) is defined by
    // MNISTData, which is not visible here - confirm.
    val training = sc.parallelize(MNISTData.loadTrainData(dataDir, 10).toSeq, 4)
      .randomSplit(Array(0.1, 0.9), seed = 1234L)(0)
      .toDF("label", "features")
      .cache()
    val test = sc.parallelize(MNISTData.loadTestData(dataDir, 10).toSeq, 4)
      .randomSplit(Array(0.1, 0.9), seed = 1234L)(0)
      .toDF("label", "features")
      .cache()

    val mpc = new MultilayerPerceptronClassifier()

    // Single-point grid: 28 * 28 inputs, hidden layers of 100 and 50 units,
    // and 10 outputs (one per digit).
    val paramGrid = new ParamGridBuilder()
      .addGrid[Array[Int]](mpc.layers, Seq(Array(28 * 28, 100, 50, 10)))
      .addGrid(mpc.blockSize, Array(128))
      .addGrid(mpc.seed, Array(1234L))
      .addGrid(mpc.maxIter, Array(100))
      .build()

    val cv = new CrossValidator()
      .setEstimator(mpc)
      .setEvaluator(new MulticlassClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)

    // Run 3-fold cross-validation and choose the best set of parameters.
    val model = cv.fit(training)
    sc.parallelize(Seq(model))
      .saveAsObjectFile(s"${modelDir}MNIST-MPC-${System.currentTimeMillis()}")

    // extractParamMap is available on any fitted model, so no downcast is needed.
    println("model param:\n" + model.bestModel.extractParamMap)

    // Make predictions on the test data using the best model found above.
    val testDF = model.transform(test)
    testDF.select("label", "prediction").show()
    testDF.groupBy("label", "prediction").count().show()

    val predictionAndLabels = testDF.select("prediction", "label")
    // NOTE(review): the "precision" metric name belongs to the Spark 1.x API;
    // newer releases use "accuracy" - confirm against the Spark version in use.
    val evaluator = new MulticlassClassificationEvaluator().setMetricName("precision")
    // Observed result with this configuration: 0.231
    println("Precision:" + evaluator.evaluate(predictionAndLabels))
  }
}