1. 程式人生 > >Spark 多項式邏輯回歸__二分類

Spark 多項式邏輯回歸__二分類

implicit frame sele 索引 ans gpa def 隱式 sse

package Spark_MLlib

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
/**
 * Fits a logistic-regression pipeline (family = "multinomial") on a two-class
 * subset of a comma-separated data set and prints the evaluation result and
 * the fitted model parameters.
 *
 * Each input line is expected to hold four numeric feature values followed by
 * a string label, separated by commas. Rows whose label is `soyo2` are removed
 * before training, leaving two classes.
 */
object 多項式邏輯回歸__二分類 {

  val spark: SparkSession = SparkSession.builder().master("local").getOrCreate()

  import spark.implicits._ // enables the implicit RDD -> DataFrame conversion (toDF)

  /**
   * Program entry point: loads the data, trains the pipeline, evaluates it on
   * a random 30% hold-out split and prints the model parameters.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    try {
      // Parse "f0,f1,f2,f3,label" lines into (features, label) rows.
      // Blank lines (e.g. a trailing newline in the file) are skipped so they
      // do not fail on the field access below.
      val df = spark.sparkContext
        .textFile("file:///home/soyo/桌面/spark編程測試數據/soyo.txt")
        .filter(_.trim.nonEmpty)
        .map(_.split(","))
        .map(x => (Vectors.dense(x(0).toDouble, x(1).toDouble, x(2).toDouble, x(3).toDouble), x(4)))
        .toDF("features", "label")
      df.show(130)

      df.createOrReplaceTempView("data_schema")
      // The label value must be an ASCII single-quoted SQL string literal;
      // typographic quotes are not parsed as a string by Spark SQL.
      val df_data = spark.sql("select * from data_schema where label !='soyo2'")
      df_data.show()

      // Map string labels to numeric indices.
      val labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .fit(df_data)

      // Build category indices for categorical features inside the feature vector.
      // NOTE(review): with the default maxCategories, features that have only a
      // few distinct values are treated as categorical and re-encoded as indices;
      // call setMaxCategories if continuous features should stay untouched.
      val featureIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .fit(df_data)

      val Array(trainData, testData) = df_data.randomSplit(Array(0.7, 0.3))

      // Elastic-net mixing parameter 0.8. setFamily("multinomial") forces the
      // multinomial formulation even though only two classes remain; without
      // it Spark would fit a binomial model for two classes.
      val lr = new LogisticRegression()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("indexedFeatures")
        .setMaxIter(10)
        .setRegParam(0.3)
        .setElasticNetParam(0.8)
        .setFamily("multinomial")

      // Convert numeric predictions back to the original string labels.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictionLabel")
        .setLabels(labelIndexer.labels)

      val lrPipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))
      val lrPipeline_Model = lrPipeline.fit(trainData)
      val lrPrediction = lrPipeline_Model.transform(testData)
      lrPrediction.show(false)

      // Model evaluation. The evaluator's default metric is F1, so accuracy is
      // requested explicitly to match what is printed below.
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val lrAccuracy = evaluator.evaluate(lrPrediction)
      println("準確率為: " + lrAccuracy)
      val lrError = 1 - lrAccuracy
      println("錯誤率為: " + lrError)

      // Locate the fitted logistic-regression stage by type instead of by
      // position plus an unchecked cast.
      lrPipeline_Model.stages
        .collectFirst { case m: LogisticRegressionModel => m }
        .foreach { LRmodel =>
          println("二項邏輯回歸模型系數矩陣: " + LRmodel.coefficientMatrix)
          println("二項邏輯回歸模型的截距向量: " + LRmodel.interceptVector)
          println("類的數量(標簽可以使用的值): " + LRmodel.numClasses)
          println("模型所接受的特征的數量: " + LRmodel.numFeatures)
        }
    } finally {
      // Release the local Spark context even when the job fails.
      spark.stop()
    }
  }
}

結果:


+-----------------+-----+
| features|label|
+-----------------+-----+
|[5.1,3.5,1.4,0.2]|soyo1|
|[4.9,3.0,1.4,0.2]|soyo1|
|[4.7,3.2,1.3,0.2]|soyo1|
|[4.6,3.1,1.5,0.2]|soyo1|
|[5.0,3.6,1.4,0.2]|soyo1|
|[5.4,3.9,1.7,0.4]|soyo1|
|[4.6,3.4,1.4,0.3]|soyo1|
|[5.0,3.4,1.5,0.2]|soyo1|
|[4.4,2.9,1.4,0.2]|soyo1|
|[4.9,3.1,1.5,0.1]|soyo1|
|[5.4,3.7,1.5,0.2]|soyo1|
|[4.8,3.4,1.6,0.2]|soyo1|
|[4.8,3.0,1.4,0.1]|soyo1|
|[4.3,3.0,1.1,0.1]|soyo1|
|[5.8,4.0,1.2,0.2]|soyo1|
|[5.7,4.4,1.5,0.4]|soyo1|
|[5.4,3.9,1.3,0.4]|soyo1|
|[5.1,3.5,1.4,0.3]|soyo1|
|[5.7,3.8,1.7,0.3]|soyo1|
|[5.1,3.8,1.5,0.3]|soyo1|
+-----------------+-----+
only showing top 20 rows

+-----------------+-----+------------+------------------+------------------------------------------+----------------------------------------+----------+---------------+
|features |label|indexedLabel|indexedFeatures |rawPrediction |probability |prediction|predictionLabel|
+-----------------+-----+------------+------------------+------------------------------------------+----------------------------------------+----------+---------------+
|[4.6,3.1,1.5,0.2]|soyo1|0.0 |[4.6,3.1,1.5,1.0] |[0.3841092104753886,-0.384109210475388] |[0.6831353764654857,0.3168646235345142] |0.0 |soyo1 |
|[4.6,3.2,1.4,0.2]|soyo1|0.0 |[4.6,3.2,1.4,1.0] |[0.4118074545189242,-0.41180745451892353] |[0.6950031457169539,0.3049968542830461] |0.0 |soyo1 |
|[4.6,3.4,1.4,0.3]|soyo1|0.0 |[4.6,3.4,1.4,2.0] |[0.41345332780578103,-0.41345332780578037]|[0.6957004614212158,0.30429953857878417]|0.0 |soyo1 |
|[4.7,3.2,1.6,0.2]|soyo1|0.0 |[4.7,3.2,1.6,1.0] |[0.39085103161962165,-0.390851031619621] |[0.6860468315498303,0.31395316845016974]|0.0 |soyo1 |
|[4.9,3.0,1.4,0.2]|soyo1|0.0 |[4.9,3.0,1.4,1.0] |[0.37736738933115554,-0.377367389331155] |[0.6802095073085258,0.3197904926914742] |0.0 |soyo1 |
|[4.9,3.1,1.5,0.1]|soyo1|0.0 |[4.9,3.1,1.5,0.0] |[0.4169034023763003,-0.4169034023762997] |[0.697159256477463,0.302840743522537] |0.0 |soyo1 |
|[5.0,3.0,1.6,0.2]|soyo1|0.0 |[5.0,3.0,1.6,1.0] |[0.356410966431853,-0.35641096643185244] |[0.6710244037082002,0.32897559629179984]|0.0 |soyo1 |
|[5.0,3.4,1.5,0.2]|soyo1|0.0 |[5.0,3.4,1.5,1.0] |[0.4357693082570414,-0.4357693082570408] |[0.705065751202206,0.2949342487977939] |0.0 |soyo1 |
|[5.0,3.4,1.6,0.4]|soyo1|0.0 |[5.0,3.4,1.6,3.0] |[0.35970271300556683,-0.35970271300556617]|[0.6724760743873281,0.3275239256126718] |0.0 |soyo1 |
|[5.1,3.4,1.5,0.2]|soyo1|0.0 |[5.1,3.4,1.5,1.0] |[0.4357693082570414,-0.4357693082570408] |[0.705065751202206,0.2949342487977939] |0.0 |soyo1 |
|[5.4,3.4,1.7,0.2]|soyo1|0.0 |[5.4,3.4,1.7,1.0] |[0.4148128853577389,-0.41481288535773825] |[0.6962757951954652,0.3037242048045349] |0.0 |soyo1 |
|[5.6,2.8,4.9,2.0]|soyo3|1.0 |[5.6,2.8,4.9,12.0]|[-0.3845461875044362,0.38454618750443703] |[0.3166754764713344,0.6833245235286656] |1.0 |soyo3 |
|[5.7,3.8,1.7,0.3]|soyo1|0.0 |[5.7,3.8,1.7,2.0] |[0.45089882383236457,-0.4508988238323638] |[0.7113187796385543,0.2886812203614457] |0.0 |soyo1 |
|[5.7,4.4,1.5,0.4]|soyo1|0.0 |[5.7,4.4,1.5,3.0] |[0.5423812503940613,-0.5423812503940606] |[0.7473941839256351,0.25260581607436505]|0.0 |soyo1 |
|[5.8,2.8,5.1,2.4]|soyo3|1.0 |[5.8,2.8,5.1,16.0]|[-0.5366793780073855,0.5366793780073863] |[0.2547648665744027,0.7452351334255972] |1.0 |soyo3 |
|[6.0,2.2,5.0,1.5]|soyo3|1.0 |[6.0,2.2,5.0,7.0] |[-0.3343736350128348,0.33437363501283546] |[0.3387774047228901,0.6612225952771099] |1.0 |soyo3 |
|[6.2,2.8,4.8,1.8]|soyo3|1.0 |[6.2,2.8,4.8,10.0]|[-0.3084795922529615,0.30847959225296234] |[0.3504733529544735,0.6495266470455265] |1.0 |soyo3 |
|[6.3,2.9,5.6,1.8]|soyo3|1.0 |[6.3,2.9,5.6,10.0]|[-0.3750852512562874,0.3750852512562882] |[0.3207841503157466,0.6792158496842534] |1.0 |soyo3 |
|[6.3,3.3,6.0,2.5]|soyo3|1.0 |[6.3,3.3,6.0,17.0]|[-0.5776773099857371,0.577677309985738] |[0.23951239936093965,0.7604876006390604]|1.0 |soyo3 |
|[6.3,3.4,5.6,2.4]|soyo3|1.0 |[6.3,3.4,5.6,16.0]|[-0.485750239692336,0.4857502396923369] |[0.2745815258875292,0.7254184741124707] |1.0 |soyo3 |
+-----------------+-----+------------+------------------+------------------------------------------+----------------------------------------+----------+---------------+
only showing top 20 rows

準確率為: 1.0
錯誤率為: 0.0
二項邏輯回歸模型系數矩陣: 0.0 0.17220032593884316 -0.1047821144965127 -0.03279419190091169
0.0 -0.172200325938843 0.10478211449651276 0.03279419190091169
二項邏輯回歸模型的截距向量: [0.04025556371065551,-0.04025556371065551]
類的數量(標簽可以使用的值): 2
模型所接受的特征的數量: 4

Spark 多項式邏輯回歸__二分類