1. 程式人生 > >葡萄酒邏輯迴歸分類(scala實現)

葡萄酒邏輯迴歸分類(scala實現)

葡萄酒分類(scala實現)

分類方法:邏輯迴歸

 

其中,分類標籤(label)的取值含義如下:

  •       0代表壞葡萄酒
  •       1代表好葡萄酒

訓練集中質量評分在 7.0 以上(含 7.0)的樣本被視為好葡萄酒(標籤 1),低於 7.0 的樣本被視為壞葡萄酒(標籤 0)。


	import org.apache.spark.ml.classification.LogisticRegression
	import org.apache.spark.ml.param.ParamMap
	import org.apache.spark.ml.linalg.{Vector, Vectors}
	import org.apache.spark.ml.regression.LinearRegressionModel
	import org.apache.spark.sql.{Row, SparkSession}

	/**
	 * Demo: binary classification of wines with Spark ML logistic regression.
	 *
	 * Each input record has 11 numeric features followed by a quality score.
	 * Records whose quality is below 7 are labelled 0.0 (bad wine); all others
	 * are labelled 1.0 (good wine). The fitted model is then applied to a small
	 * hand-written test set, once with labels and once without.
	 */
	object LogicRegressWineClassifyDemo {
		import scala.util.Try

		/** Data file used when no path is passed on the command line. */
		private val DefaultDataPath = "file:///D:/downloads/bigdata/ml/winequality-white.csv"

		/** Columns read from each record: 11 features plus the quality score. */
		private val ColumnCount = 12

		/** Quality scores below this value are labelled 0.0, the rest 1.0. */
		private val GoodQualityThreshold = 7.0

		/**
		 * One record of the wine data set.
		 *
		 * Declared at object level rather than inside `main` so that Spark closures
		 * reference a stable, top-level class.
		 */
		private final case class Wine(FixedAcidity: Double, VolatileAcidity: Double,
			CitricAcid: Double, ResidualSugar: Double, Chlorides: Double,
			FreeSulfurDioxide: Double, TotalSulfurDioxide: Double, Density: Double,
			PH: Double, Sulphates: Double, Alcohol: Double, Quality: Double)

		/**
		 * Parses one semicolon-separated line into a [[Wine]].
		 *
		 * @param line raw text line from the data file
		 * @return `Some(wine)` when the line has at least 12 fields and the first 12
		 *         all parse as doubles; `None` otherwise (for example a header row
		 *         or a blank line), so such rows are skipped instead of failing
		 *         the whole job
		 */
		private def parseWine(line: String): Option[Wine] = {
			val fields = line.split(";")
			if (fields.length < ColumnCount) None
			else Try(fields.take(ColumnCount).map(_.toDouble)).toOption.map { v =>
				Wine(v(0), v(1), v(2), v(3), v(4), v(5), v(6), v(7), v(8), v(9), v(10), v(11))
			}
		}

		/**
		 * Trains the model and prints predictions for the built-in test rows.
		 *
		 * @param args optional; `args(0)` overrides the default data file path
		 */
		def main(args: Array[String]): Unit = {
			val dataPath = args.headOption.getOrElse(DefaultDataPath)
			val sess = SparkSession.builder().appName("ml").master("local[4]").getOrCreate()
			try {
				val sc = sess.sparkContext

				// Load the raw text and keep only the rows that parse as wine records.
				val wineDataRDD = sc.textFile(dataPath).flatMap(parseWine)

				import sess.implicits._

				// Convert the RDD into a (label, features) DataFrame for training.
				val trainingDF = wineDataRDD.map(w => (if (w.Quality < GoodQualityThreshold) 0D else 1D,
					Vectors.dense(w.FixedAcidity, w.VolatileAcidity, w.CitricAcid,
						w.ResidualSugar, w.Chlorides, w.FreeSulfurDioxide, w.TotalSulfurDioxide,
						w.Density, w.PH, w.Sulphates, w.Alcohol))).toDF("label", "features")

				// Logistic regression estimator: at most 10 iterations, regularization parameter 0.01.
				val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
				val model = lr.fit(trainingDF)

				// Hand-written labelled test rows, same feature order as the training data.
				val testDF = sess.createDataFrame(Seq(
					(1.0, Vectors.dense(6.1, 0.32, 0.24, 1.5, 0.036, 43, 140, 0.9894, 3.36, 0.64, 10.7)),
					(0.0, Vectors.dense(5.2, 0.44, 0.04, 1.4, 0.036, 38, 124, 0.9898, 3.29, 0.42, 12.4)),
					(0.0, Vectors.dense(7.2, 0.32, 0.47, 5.1, 0.044, 19, 65, 0.9951, 3.38, 0.36, 9)),
					(0.0, Vectors.dense(6.4, 0.595, 0.14, 5.2, 0.058, 15, 97, 0.991, 3.03, 0.41, 12.6))
				)).toDF("label", "features")

				// Show the test data.
				testDF.show()

				println("========================")
				// Predict on the labelled test data so predictions can be compared with the known labels.
				testDF.createOrReplaceTempView("test")
				val tested = model.transform(testDF).select("features", "label", "prediction")
				tested.show()

				println("========================")
				// Predict on the same rows with the label column dropped (unlabelled input).
				val predictDF = sess.sql("SELECT features FROM test")
				val predicted = model.transform(predictDF).select("features", "prediction")
				predicted.show()
			} finally {
				// Always release the local Spark resources, even when a stage fails.
				sess.stop()
			}
		}
	}