1. 程式人生 > Spark 二項邏輯回歸__二分類

Spark 二項邏輯回歸__二分類

tag tostring ont sch ray park pip threshold map

package Spark_MLlib

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions
/** One input record: a dense feature vector and its string class label. */
case class data_schema(features: Vector, label: String)

/**
 * Binomial (two-class) logistic regression example built on the Spark ML pipeline API.
 *
 * The program reads a comma-separated text file with four numeric columns followed by a
 * string label, drops every row labelled 'soyo2', trains a logistic-regression pipeline
 * on a random 70/30 split, prints evaluation figures and the training summary, and finally
 * picks the classification threshold that maximises the F-Measure.
 */
object 二項邏輯回歸__二分類 {
  val spark: SparkSession = SparkSession.builder().master("local").getOrCreate()

  // Brings in the implicit conversions used below: RDD -> DataFrame (toDF) and $"column".
  import spark.implicits._

  /**
   * Entry point: runs the whole load / train / evaluate / summarise flow.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // The path must be a single-line literal; a line break inside it is a syntax error
    // and would otherwise become part of the file name.
    val df = spark.sparkContext
      .textFile("file:///home/soyo/桌面/spark編程測試數據/soyo.txt")
      .map(_.split(","))
      .map(x => data_schema(
        Vectors.dense(x(0).toDouble, x(1).toDouble, x(2).toDouble, x(3).toDouble),
        x(4)))
      .toDF()
    df.show(130)

    df.createOrReplaceTempView("data_schema")
    // The label value is a SQL string literal and needs plain ASCII single quotes.
    val df_data = spark.sql("select * from data_schema where label != 'soyo2'")
    // df_data.map(x => x(1) + ":" + x(0)).collect().foreach(println)
    df_data.show()

    // Both indexers are fitted on the full filtered data set before the split.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(df_data)
    // Builds category indices for the features inside the feature vector.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .fit(df_data)

    // No seed is given, so the split (and therefore the results) differs between runs.
    val Array(trainData, testData) = df_data.randomSplit(Array(0.7, 0.3))

    // setRegParam: regularisation strength; setElasticNetParam: elastic-net mixing value.
    // setFamily("multinomial") would request multinomial regression; leaving the family
    // unset gives binomial regression for this two-class label.
    val lr = new LogisticRegression()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setMaxIter(10)
      .setRegParam(0.5)
      .setElasticNetParam(0.8)

    // Maps the numeric prediction back to the original string label.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictionLabel")
      .setLabels(labelIndexer.labels)

    val lrPipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))
    val lrPipeline_Model = lrPipeline.fit(trainData)
    val lrPrediction = lrPipeline_Model.transform(testData)
    lrPrediction.show(false)
    // lrPrediction.take(100).foreach(println)

    // Model evaluation. The evaluator's default metric is F1, so the metric is set
    // explicitly: the value below is reported as accuracy and used for the error rate.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val lrAccuracy = evaluator.evaluate(lrPrediction)
    println("準確率為: " + lrAccuracy)
    val lrError = 1 - lrAccuracy
    println("錯誤率為: " + lrError)

    // Stage index 2 is the logistic-regression stage (see setStages above).
    val LRmodel = lrPipeline_Model.stages(2).asInstanceOf[LogisticRegressionModel]
    println("二項邏輯回歸模型系數的向量: " + LRmodel.coefficients)
    println("二項邏輯回歸模型的截距: " + LRmodel.intercept)
    println("類的數量(標簽可以使用的值): " + LRmodel.numClasses)
    println("模型所接受的特征的數量: " + LRmodel.numFeatures)

    // Per the original author's note (Spark 2.2.0): the model summary is only available
    // for binomial logistic regression, not for multinomial.
    println(LRmodel.hasSummary)
    val trainingSummary = LRmodel.summary

    // Objective (loss) value per training iteration; lower is better.
    val objectiveHistory = trainingSummary.objectiveHistory
    objectiveHistory.foreach(println)

    // Cast to the binary summary to reach the ROC / F-Measure statistics.
    val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]

    // Area under the ROC curve; the closer to 1, the better the model.
    val area_ROC = binarySummary.areaUnderROC
    println("ROC 曲線下的面積為: " + area_ROC)

    // fMeasureByThreshold: DataFrame with the two columns (threshold, F-Measure), beta = 1.0.
    val fMeasure = binarySummary.fMeasureByThreshold
    // count() avoids collecting every row to the driver just to learn the row count.
    println("fMeasure的行數: " + fMeasure.count())
    fMeasure.show(100)

    val maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0)
    println("最大的F-Measure的值為: " + maxFMeasure)

    // Best threshold = the threshold of the row holding the maximum F-Measure.
    val bestThreshold = fMeasure
      .where($"F-Measure" === maxFMeasure)
      .select("threshold")
      .head()
      .getDouble(0)
    println("最優的閥值為:" + bestThreshold)

    // Note: functions.max("threshold") would NOT yield the best threshold, only the largest:
    //   val s = fMeasure.select(functions.max("threshold")).head().getDouble(0)
    //   println(s)

    // Applies the chosen threshold to the fitted model; nothing in this program
    // uses the model afterwards.
    LRmodel.setThreshold(bestThreshold)
  }
}

結果:


+-----------------+-----+------------+------------------+--------------------------------------------+----------------------------------------+----------+---------------+
|features |label|indexedLabel|indexedFeatures |rawPrediction |probability |prediction|predictionLabel|
+-----------------+-----+------------+------------------+--------------------------------------------+----------------------------------------+----------+---------------+
|[4.4,2.9,1.4,0.2]|soyo1|0.0 |[4.4,2.9,1.4,1.0] |[0.0690256519103008,-0.0690256519103008] |[0.5172495646670774,0.48275043533292256]|0.0 |soyo1 |
|[4.4,3.0,1.3,0.2]|soyo1|0.0 |[4.4,3.0,1.3,1.0] |[0.07401171769156373,-0.07401171769156373] |[0.518494487869238,0.481505512130762] |0.0 |soyo1 |
|[4.6,3.1,1.5,0.2]|soyo1|0.0 |[4.6,3.1,1.5,1.0] |[0.06403958612903785,-0.06403958612903785] |[0.5160044273015656,0.48399557269843435]|0.0 |soyo1 |
|[4.6,3.2,1.4,0.2]|soyo1|0.0 |[4.6,3.2,1.4,1.0] |[0.0690256519103008,-0.0690256519103008] |[0.5172495646670774,0.48275043533292256]|0.0 |soyo1 |
|[4.6,3.6,1.0,0.2]|soyo1|0.0 |[4.6,3.6,1.0,1.0] |[0.08896991503535255,-0.08896991503535255] |[0.5222278183980882,0.4777721816019118] |0.0 |soyo1 |
|[4.8,3.0,1.4,0.1]|soyo1|0.0 |[4.8,3.0,1.4,0.0] |[0.0690256519103008,-0.0690256519103008] |[0.5172495646670774,0.48275043533292256]|0.0 |soyo1 |
|[4.9,2.5,4.5,1.7]|soyo3|1.0 |[4.9,2.5,4.5,9.0] |[-0.08554238730885033,0.08554238730885033] |[0.47862743439605193,0.5213725656039481]|1.0 |soyo3 |
|[5.0,3.0,1.6,0.2]|soyo1|0.0 |[5.0,3.0,1.6,1.0] |[0.059053520347774904,-0.059053520347774904]|[0.5147590911988562,0.48524090880114373]|0.0 |soyo1 |
|[5.1,3.5,1.4,0.3]|soyo1|0.0 |[5.1,3.5,1.4,2.0] |[0.0690256519103008,-0.0690256519103008] |[0.5172495646670774,0.48275043533292256]|0.0 |soyo1 |
|[5.1,3.8,1.6,0.2]|soyo1|0.0 |[5.1,3.8,1.6,1.0] |[0.059053520347774904,-0.059053520347774904]|[0.5147590911988562,0.48524090880114373]|0.0 |soyo1 |
|[5.3,3.7,1.5,0.2]|soyo1|0.0 |[5.3,3.7,1.5,1.0] |[0.06403958612903785,-0.06403958612903785] |[0.5160044273015656,0.48399557269843435]|0.0 |soyo1 |
|[5.4,3.7,1.5,0.2]|soyo1|0.0 |[5.4,3.7,1.5,1.0] |[0.06403958612903785,-0.06403958612903785] |[0.5160044273015656,0.48399557269843435]|0.0 |soyo1 |
|[5.4,3.9,1.7,0.4]|soyo1|0.0 |[5.4,3.9,1.7,3.0] |[0.05406745456651198,-0.05406745456651198] |[0.5135135717949689,0.486486428205031] |0.0 |soyo1 |
|[5.7,3.8,1.7,0.3]|soyo1|0.0 |[5.7,3.8,1.7,2.0] |[0.05406745456651198,-0.05406745456651198] |[0.5135135717949689,0.486486428205031] |0.0 |soyo1 |
|[5.8,2.8,5.1,2.4]|soyo3|1.0 |[5.8,2.8,5.1,16.0]|[-0.11545878199642795,0.11545878199642795] |[0.4711673274353307,0.5288326725646694] |1.0 |soyo3 |
|[5.8,4.0,1.2,0.2]|soyo1|0.0 |[5.8,4.0,1.2,1.0] |[0.07899778347282668,-0.07899778347282668] |[0.5197391814925231,0.480260818507477] |0.0 |soyo1 |
|[6.1,3.0,4.9,1.8]|soyo3|1.0 |[6.1,3.0,4.9,10.0]|[-0.10548665043390212,0.10548665043390212] |[0.4736527642876721,0.5263472357123279] |1.0 |soyo3 |
|[6.3,2.7,4.9,1.8]|soyo3|1.0 |[6.3,2.7,4.9,10.0]|[-0.10548665043390212,0.10548665043390212] |[0.4736527642876721,0.5263472357123279] |1.0 |soyo3 |
|[6.3,2.9,5.6,1.8]|soyo3|1.0 |[6.3,2.9,5.6,10.0]|[-0.14038911090274264,0.14038911090274264] |[0.46496025354157383,0.5350397464584261]|1.0 |soyo3 |
|[6.5,3.0,5.5,1.8]|soyo3|1.0 |[6.5,3.0,5.5,10.0]|[-0.13540304512147971,0.13540304512147971] |[0.4662008623530858,0.5337991376469143] |1.0 |soyo3 |
+-----------------+-----+------------+------------------+--------------------------------------------+----------------------------------------+----------+---------------+
only showing top 20 rows

準確率為: 1.0
錯誤率為: 0.0
二項邏輯回歸模型系數的向量: [0.0,0.0,0.0498606578126294,-0.0]
二項邏輯回歸模型的截距: -0.13883057284798195
類的數量(標簽可以使用的值): 2
模型所接受的特征的數量: 4
true
0.6927819059876479
0.6921535505946383
0.6902127176671448
0.6898394130469451
0.689535794969328
0.6894009255584304
0.6893497986701255
0.689265433291139
0.6887228224555286
0.6895877386375889
0.6872109190567809
ROC 曲線下的面積為: 1.0
fMeasure的行數: 26
+-------------------+-------------------+
| threshold| F-Measure|
+-------------------+-------------------+
| 0.5511227178429281|0.05128205128205127|
| 0.5486545095952616| 0.1|
| 0.547419499422364|0.14634146341463414|
| 0.5449477416103359| 0.1904761904761905|
| 0.5412359859690851| 0.2727272727272727|
| 0.5399976958289747|0.34782608695652173|
| 0.5387589116841329|0.38297872340425526|
| 0.5375196486465557| 0.4799999999999999|
| 0.5362799218518347| 0.5098039215686275|
| 0.5350397464584261| 0.6428571428571429|
| 0.5337991376469143| 0.6896551724137931|
| 0.5325581106192748| 0.7333333333333334|
| 0.5313166805981351| 0.7741935483870968|
| 0.5300748628260323| 0.8125000000000001|
| 0.5288326725646694| 0.9142857142857143|
| 0.5275901250941695| 0.958904109589041|
| 0.5263472357123279| 0.972972972972973|
| 0.5251040197338624| 1.0|
| 0.4889779551275146| 0.9743589743589743|
| 0.486486428205031| 0.9500000000000001|
|0.48524090880114373| 0.8941176470588235|
|0.48399557269843435| 0.7916666666666666|
|0.48275043533292256| 0.7307692307692308|
| 0.481505512130762| 0.6909090909090909|
| 0.480260818507477| 0.6846846846846847|
|0.47901636986720014| 0.6785714285714285|
+-------------------+-------------------+

最大的F-Measure的值為: 1.0
最優的閥值為:0.5251040197338624

Spark 二項邏輯回歸__二分類