葡萄酒邏輯迴歸分類(scala實現)
阿新 • 發佈：2018-12-24
葡萄酒分類(scala實現)
分類方法:邏輯迴歸
其中，標籤（label）的含義為：
- 0 代表壞葡萄酒
- 1 代表好葡萄酒
訓練集中質量評分（quality）大於或等於 7 的樣本被視為好葡萄酒（標籤 1），低於 7 的被視為壞葡萄酒（標籤 0）。
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.regression.LinearRegressionModel
import org.apache.spark.sql.{Row, SparkSession}

import scala.util.Try

/**
 * Wine classification with Spark ML logistic regression.
 *
 * Label convention: 0.0 = bad wine, 1.0 = good wine. A wine whose quality score is below 7 is
 * labelled bad; a score of 7 or higher is labelled good.
 *
 * The model is trained on a `;`-separated wine-quality CSV, then applied to a few hand-written
 * test rows (first with their labels, then with the label column dropped).
 *
 * Usage: the CSV location may be passed as the first command-line argument; when absent, the
 * original hard-coded local path is used.
 */
object LogicRegressWineClassifyDemo {

  /** Data location used when no path is given on the command line. */
  private val DefaultDataDir = "file:///D:/downloads/bigdata/ml/winequality-white.csv"

  /** Quality score at or above which a wine is labelled good (1.0). */
  private val GoodQualityThreshold = 7.0

  /** One data row: 11 physico-chemical measurements followed by the quality score. */
  final case class Wine(
      FixedAcidity: Double,
      VolatileAcidity: Double,
      CitricAcid: Double,
      ResidualSugar: Double,
      Chlorides: Double,
      FreeSulfurDioxide: Double,
      TotalSulfurDioxide: Double,
      Density: Double,
      PH: Double,
      Sulphates: Double,
      Alcohol: Double,
      Quality: Double)

  /**
   * True for lines that are not data rows: blank lines, or a header of column names. A line is
   * treated as such when its first `;`-separated field is not a number. Data rows with a numeric
   * first field but malformed later fields are NOT skipped, so they still fail loudly in
   * [[parseWine]].
   */
  private def isHeaderOrBlank(line: String): Boolean =
    Try(line.split(";")(0).trim.toDouble).isFailure

  /**
   * Parses a `;`-separated data row into a [[Wine]] using the first 12 fields.
   *
   * @throws NumberFormatException if one of those fields is not a number
   * @throws ArrayIndexOutOfBoundsException if the row has fewer than 12 fields
   */
  private def parseWine(line: String): Wine = {
    val f = line.split(";")
    Wine(f(0).toDouble, f(1).toDouble, f(2).toDouble, f(3).toDouble, f(4).toDouble,
      f(5).toDouble, f(6).toDouble, f(7).toDouble, f(8).toDouble, f(9).toDouble,
      f(10).toDouble, f(11).toDouble)
  }

  /** Binary label: 0.0 if quality is below the threshold, otherwise 1.0. */
  private def label(w: Wine): Double =
    if (w.Quality < GoodQualityThreshold) 0D else 1D

  /** Feature vector: the 11 measurements in CSV column order (quality is excluded). */
  private def features(w: Wine): Vector =
    Vectors.dense(w.FixedAcidity, w.VolatileAcidity, w.CitricAcid, w.ResidualSugar,
      w.Chlorides, w.FreeSulfurDioxide, w.TotalSulfurDioxide, w.Density, w.PH,
      w.Sulphates, w.Alcohol)

  def main(args: Array[String]): Unit = {
    val dataDir = args.headOption.getOrElse(DefaultDataDir)
    val sess = SparkSession.builder().appName("ml").master("local[4]").getOrCreate()
    // Always release the Spark session, even if training or prediction fails.
    try {
      val sc = sess.sparkContext
      import sess.implicits._

      // Parse the CSV. Skip blank lines and a header row of column names; the UCI wine-quality
      // file starts with one (NOTE(review): confirm for your copy of the data).
      val wineDataRDD = sc.textFile(dataDir)
        .filter(line => !isHeaderOrBlank(line))
        .map(parseWine)

      // Convert the RDD into a (label, features) DataFrame for Spark ML.
      val trainingDF = wineDataRDD
        .map(w => (label(w), features(w)))
        .toDF("label", "features")

      // Logistic (not linear) regression: at most 10 iterations, regularization parameter 0.01.
      val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
      val model = lr.fit(trainingDF)

      // Hand-written labelled test rows.
      val testDF = sess.createDataFrame(Seq(
        (1.0, Vectors.dense(6.1, 0.32, 0.24, 1.5, 0.036, 43, 140, 0.9894, 3.36, 0.64, 10.7)),
        (0.0, Vectors.dense(5.2, 0.44, 0.04, 1.4, 0.036, 38, 124, 0.9898, 3.29, 0.42, 12.4)),
        (0.0, Vectors.dense(7.2, 0.32, 0.47, 5.1, 0.044, 19, 65, 0.9951, 3.38, 0.36, 9)),
        (0.0, Vectors.dense(6.4, 0.595, 0.14, 5.2, 0.058, 15, 97, 0.991, 3.03, 0.41, 12.6))
      )).toDF("label", "features")

      // Display the test data.
      testDF.show()
      println("========================")

      // Predict on the labelled test data so predictions can be compared with the labels.
      testDF.createOrReplaceTempView("test")
      val tested = model.transform(testDF).select("features", "label", "prediction")
      tested.show()
      println("========================")

      // Predict on the same rows with the label column dropped (unlabelled input).
      val predictDF = sess.sql("SELECT features FROM test")
      val predicted = model.transform(predictDF).select("features", "prediction")
      predicted.show()
    } finally {
      sess.stop()
    }
  }
}