1. 程式人生 > mllib實踐(二)之LinearRegression實踐(DataFrame方式,普通標籤格式轉DataFrame)(整合網際網路上多個例項)

mllib實踐(二)之LinearRegression實踐(DataFrame方式,普通標籤格式轉DataFrame)(整合網際網路上多個例項)

package mllib;

import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.mllib.regression.{ LabeledPoint, LinearRegressionWithSGD }
import org.apache.spark.sql.{SparkSession,DataFrame,SQLContext}
import org.apache.spark.sql.Row
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.feature.VectorAssembler

/**
 * Linear-regression example built on the DataFrame-based `spark.ml` API.
 *
 * Reads a text file whose lines look like `y,w1 w2` (label, a comma, then two
 * space-separated feature values), converts it to a DataFrame with columns
 * `y`, `w1`, `w2`, assembles `w1`/`w2` into a `features` vector, fits an
 * elastic-net regularised `LinearRegression`, prints the training summary and
 * finally prints rounded predictions.
 *
 * Command-line arguments (both optional):
 *   - `args(0)`: path of the training file (defaults to [[App.DefaultDataPath]]).
 *   - `args(1)`: path of the file to predict on (defaults to the training path).
 */
object App {

  /** Input file used when no path is supplied on the command line. */
  private val DefaultDataPath = "/home/hadoop/mllibdata/kimi.txt"

  /** Columns that are assembled into the `features` vector. */
  private val FeatureCols = Array("w1", "w2")

  /**
   * Parses one `y,w1 w2` line into `(y, w1, w2)`.
   *
   * Extra comma-separated fields or extra feature values are ignored, as only
   * the first label and the first two features are read.
   *
   * @param line a non-blank input line
   * @return the label followed by the two feature values
   * @throws IllegalArgumentException if the line does not contain a label and
   *                                  at least two feature values, or if a value
   *                                  is not a valid number
   *                                  (`NumberFormatException` is a subclass)
   */
  private def parseLine(line: String): (Double, Double, Double) =
    line.trim.split(',') match {
      case Array(label, features, _*) =>
        features.trim.split("\\s+") match {
          case Array(w1, w2, _*) => (label.trim.toDouble, w1.toDouble, w2.toDouble)
          case _ =>
            throw new IllegalArgumentException(
              s"Expected two space-separated features after the comma: '$line'")
        }
      case _ =>
        throw new IllegalArgumentException(s"Expected a line of the form 'y,w1 w2': '$line'")
    }

  /**
   * Loads a label/feature text file as a DataFrame with columns `y`, `w1`, `w2`.
   *
   * Blank lines are skipped; every other line is parsed with [[parseLine]].
   *
   * @param spark the active session
   * @param path  location of the text file
   * @return a DataFrame with three `Double` columns: `y`, `w1`, `w2`
   */
  private def readLabeled(spark: SparkSession, path: String): DataFrame = {
    import spark.implicits._
    spark.read.textFile(path)
      .filter(line => line.trim.nonEmpty)
      .map(line => parseLine(line))
      .toDF("y", "w1", "w2")
  }

  def main(args: Array[String]): Unit = {
    val trainPath = args.headOption.getOrElse(DefaultDataPath)
    val predictPath = args.lift(1).getOrElse(trainPath)

    // A single SparkSession is the only entry point needed; creating a separate
    // SparkContext/SQLContext first only makes the builder reuse the existing
    // context and warn that some configuration may not take effect.
    val spark = SparkSession.builder()
      .master("local")
      .appName("kimiYang")
      .config("spark.some.config.option", "some-value")
      .getOrCreate()

    try {
      // VectorAssembler holds no fitted state, so one instance serves both the
      // training and the prediction data.
      val assembler = new VectorAssembler()
        .setInputCols(FeatureCols)
        .setOutputCol("features")

      val labeled = readLabeled(spark, trainPath)
      labeled.show()
      val trainingDF: DataFrame = assembler.transform(labeled)

      // Linear regression with intercept; regParam is the regularisation
      // strength and elasticNetParam mixes the L1/L2 penalties.
      val lr = new LinearRegression()
        .setFeaturesCol("features")
        .setLabelCol("y")
        .setFitIntercept(true)
        .setMaxIter(10)
        .setRegParam(0.3)
        .setElasticNetParam(0.8)

      // Fit the model on the training set.
      val model = lr.fit(trainingDF)

      // Print every parameter of the fitted model (the value used to be
      // computed and silently discarded).
      println(s"Model params:\n${model.extractParamMap()}")
      println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept}")

      // Evaluate the fit on the training data.
      val trainingSummary = model.summary
      println(s"numIterations: ${trainingSummary.totalIterations}")
      // Objective value per iteration.
      println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
      trainingSummary.residuals.show()
      println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") // root mean squared error
      println(s"r2: ${trainingSummary.r2}") // coefficient of determination

      // Build the prediction set the same way as the training set and score it.
      val predictDF: DataFrame = assembler.transform(readLabeled(spark, predictPath))
      val predictions: DataFrame = model.transform(predictDF)
      println("輸出預測結果")
      val predictResult: DataFrame =
        predictions.selectExpr("features", "y", "round(prediction,1) as prediction")
      // Collect to the driver before printing: Dataset.foreach runs on the
      // executors, so its output is only visible on the driver console in
      // local mode and the row order is not guaranteed.
      predictResult.collect().foreach(println)
    } finally {
      // Always release Spark resources, even when a step above fails.
      spark.stop()
    }
  }
}

 

輸出結果:

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
18/10/17 11:24:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
18/10/17 11:24:56 WARN Utils: Your hostname, dblab-VirtualBox resolves to a loopback address: 127.0.1.1; using 10.0.2.4 instead (on interface enp0s3)
18/10/17 11:24:56 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
18/10/17 11:24:58 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
18/10/17 11:24:59 WARN SparkSession$Builder: Using an existing SparkSession; some configuration may not take effect.
+----+---+---+
|   y| w1| w2|
+----+---+---+
| 5.0|1.0|1.0|
| 7.0|2.0|1.0|
| 9.0|3.0|2.0|
|11.0|4.0|1.0|
|19.0|5.0|3.0|
|18.0|6.0|2.0|
+----+---+---+

18/10/17 11:25:12 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
18/10/17 11:25:12 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
Coefficients: [2.1834836708105665,2.0791699982781844] Intercept: 0.392523821699377
numIterations: 6
objectiveHistory: List(0.5, 0.3905266284749149, 0.07485179411449853, 0.07261563579181778, 0.06909771042678862, 0.06909771037948713)
+-------------------+
|          residuals|
+-------------------+
| 0.3448225092118724|
|0.16133883840130547|
|-2.1013148306874463|
|-0.2056285032198275|
| 1.4525478294132341|
|0.34823415688085646|
+-------------------+

RMSE: 1.0672317857734335
r2: 0.9592005844334871
輸出預測結果
[[1.0,1.0],5.0,4.7]
[[2.0,1.0],7.0,6.8]
[[3.0,2.0],9.0,11.1]
[[4.0,1.0],11.0,11.2]
[[5.0,3.0],19.0,17.5]
[[6.0,2.0],18.0,17.7]