
Spark Decision Tree -- Regression Model


package Spark_MLlib

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
case class data_schema(features: Vector, label: String)

object 決策樹__回歸模型 {
  val spark = SparkSession.builder().master("local").getOrCreate()
  import spark.implicits._

  def main(args: Array[String]): Unit = {
    val data = spark.sparkContext.textFile("file:///home/soyo/桌面/spark編程測試數據/soyo2.txt")
      .map(_.split(","))
      .map(x => data_schema(Vectors.dense(x(0).toDouble, x(1).toDouble, x(2).toDouble, x(3).toDouble), x(4)))
      .toDF()
    // Index the string labels and the feature vector (features with at most 4 distinct
    // values are treated as categorical), and prepare a converter back to string labels
    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
    val featuresIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(data)
    val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
    val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3))
    // Configure the decision tree regressor
    val dtRegressor = new DecisionTreeRegressor().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
    // Build the machine learning pipeline
    val pipelineRegressor = new Pipeline().setStages(Array(labelIndexer, featuresIndexer, dtRegressor, labelConverter))
    // Train the decision tree regression model
    val modelRegressor = pipelineRegressor.fit(trainData)
    // Make predictions on the test set
    val prediction = modelRegressor.transform(testData)
    prediction.show(150)
    // Evaluate the decision tree regression model
    // setMetricName: choose the evaluation metric (root mean squared error, mean squared error, etc.); valid values: rmse, mse, r2, mae
    val evaluatorRegressor = new RegressionEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("rmse")
    val Root_Mean_Squared_Error = evaluatorRegressor.evaluate(prediction)
    println("RMSE: " + Root_Mean_Squared_Error)
    val treeModelRegressor = modelRegressor.stages(2).asInstanceOf[DecisionTreeRegressionModel]
    val schema_decisionTree = treeModelRegressor.toDebugString
    println("Decision tree regression model structure: " + schema_decisionTree)
  }
}
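As a quick follow-up (not part of the original program), the fitted pipeline can also score new rows built with the same case class. Below is a minimal sketch with made-up feature values, reusing data_schema, Vectors, spark.implicits._ and modelRegressor from above; note that the label column must hold one of the values seen when labelIndexer was fit, because the fitted StringIndexer is the first pipeline stage.

// Minimal sketch: score one hand-built sample with the fitted pipeline.
// Feature values are illustrative; the label must be one the StringIndexer has seen.
val newSample = Seq(data_schema(Vectors.dense(6.0, 3.0, 4.8, 1.8), "spark")).toDF()
modelRegressor.transform(newSample)
  .select("features", "prediction", "predictedLabel")
  .show()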
Spark source code: regarding setMetricName("")
@Since("2.0.0")
  override def evaluate(dataset: Dataset[_]): Double = {
    val schema = dataset.schema
    SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
    SchemaUtils.checkNumericType(schema, $(labelCol))

    val predictionAndLabels = dataset
      .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
      .rdd
      .map { case Row(prediction: Double, label: Double) => (prediction, label) }
    val metrics = new RegressionMetrics(predictionAndLabels)
    val metric = $(metricName) match {
      case "rmse" => metrics.rootMeanSquaredError
      case "mse" => metrics.meanSquaredError
      case "r2" => metrics.r2
      case "mae" => metrics.meanAbsoluteError
    }
    metric
  }
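As the match expression above shows, metricName accepts rmse, mse, r2 and mae. Below is a minimal sketch of switching metrics on the evaluator built earlier, assuming evaluatorRegressor and prediction from the program are still in scope.

// Minimal sketch: reuse the same RegressionEvaluator for the other metrics.
val mae = evaluatorRegressor.setMetricName("mae").evaluate(prediction)
val mse = evaluatorRegressor.setMetricName("mse").evaluate(prediction)
val r2  = evaluatorRegressor.setMetricName("r2").evaluate(prediction)
println(s"MAE: $mae  MSE: $mse  R2: $r2")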

Result:

+-----------------+------+------------+-----------------+----------+--------------+
| features| label|indexedLabel| indexedFeatures|prediction|predictedLabel|
+-----------------+------+------------+-----------------+----------+--------------+
|[4.6,3.1,1.5,0.2]|hadoop| 1.0|[4.6,3.1,1.5,0.2]| 1.0| hadoop|
|[4.6,3.4,1.4,0.3]|hadoop| 1.0|[4.6,3.4,1.4,0.3]| 1.0| hadoop|
|[4.7,3.2,1.3,0.2]|hadoop| 1.0|[4.7,3.2,1.3,0.2]| 1.0| hadoop|
|[4.8,3.0,1.4,0.1]|hadoop| 1.0|[4.8,3.0,1.4,0.1]| 1.0| hadoop|
|[5.1,3.3,1.7,0.5]|hadoop| 1.0|[5.1,3.3,1.7,0.5]| 1.0| hadoop|
|[5.1,3.7,1.5,0.4]|hadoop| 1.0|[5.1,3.7,1.5,0.4]| 1.0| hadoop|
|[5.4,3.9,1.3,0.4]|hadoop| 1.0|[5.4,3.9,1.3,0.4]| 1.0| hadoop|
|[5.5,2.3,4.0,1.3]| spark| 0.0|[5.5,2.3,4.0,1.3]| 0.0| spark|
|[5.5,3.5,1.3,0.2]|hadoop| 1.0|[5.5,3.5,1.3,0.2]| 1.0| hadoop|
|[5.6,2.7,4.2,1.3]| spark| 0.0|[5.6,2.7,4.2,1.3]| 0.0| spark|
|[5.6,3.0,4.1,1.3]| spark| 0.0|[5.6,3.0,4.1,1.3]| 0.0| spark|
|[5.6,3.0,4.5,1.5]| spark| 0.0|[5.6,3.0,4.5,1.5]| 0.0| spark|
|[5.7,2.6,3.5,1.0]| spark| 0.0|[5.7,2.6,3.5,1.0]| 0.0| spark|
|[5.7,4.4,1.5,0.4]|hadoop| 1.0|[5.7,4.4,1.5,0.4]| 1.0| hadoop|
|[5.8,2.7,3.9,1.2]| spark| 0.0|[5.8,2.7,3.9,1.2]| 0.0| spark|
|[5.8,2.7,4.1,1.0]| spark| 0.0|[5.8,2.7,4.1,1.0]| 0.0| spark|
|[5.8,2.8,5.1,2.4]| Scala| 2.0|[5.8,2.8,5.1,2.4]| 2.0| Scala|
|[5.8,4.0,1.2,0.2]|hadoop| 1.0|[5.8,4.0,1.2,0.2]| 1.0| hadoop|
|[5.9,3.0,4.2,1.5]| spark| 0.0|[5.9,3.0,4.2,1.5]| 0.0| spark|
|[5.9,3.0,5.1,1.8]| Scala| 2.0|[5.9,3.0,5.1,1.8]| 2.0| Scala|
|[5.9,3.2,4.8,1.8]| spark| 0.0|[5.9,3.2,4.8,1.8]| 2.0| Scala|
|[6.1,2.6,5.6,1.4]| Scala| 2.0|[6.1,2.6,5.6,1.4]| 2.0| Scala|
|[6.1,2.8,4.0,1.3]| spark| 0.0|[6.1,2.8,4.0,1.3]| 0.0| spark|
|[6.3,2.9,5.6,1.8]| Scala| 2.0|[6.3,2.9,5.6,1.8]| 2.0| Scala|
|[6.3,3.4,5.6,2.4]| Scala| 2.0|[6.3,3.4,5.6,2.4]| 2.0| Scala|
|[6.4,2.7,5.3,1.9]| Scala| 2.0|[6.4,2.7,5.3,1.9]| 2.0| Scala|
|[6.4,3.1,5.5,1.8]| Scala| 2.0|[6.4,3.1,5.5,1.8]| 2.0| Scala|
|[6.4,3.2,4.5,1.5]| spark| 0.0|[6.4,3.2,4.5,1.5]| 0.0| spark|
|[6.5,2.8,4.6,1.5]| spark| 0.0|[6.5,2.8,4.6,1.5]| 0.0| spark|
|[6.5,3.0,5.5,1.8]| Scala| 2.0|[6.5,3.0,5.5,1.8]| 2.0| Scala|
|[6.7,3.0,5.2,2.3]| Scala| 2.0|[6.7,3.0,5.2,2.3]| 2.0| Scala|
|[6.7,3.1,4.7,1.5]| spark| 0.0|[6.7,3.1,4.7,1.5]| 0.0| spark|
|[6.8,3.0,5.5,2.1]| Scala| 2.0|[6.8,3.0,5.5,2.1]| 2.0| Scala|
|[6.9,3.1,5.4,2.1]| Scala| 2.0|[6.9,3.1,5.4,2.1]| 2.0| Scala|
|[7.0,3.2,4.7,1.4]| spark| 0.0|[7.0,3.2,4.7,1.4]| 0.0| spark|
|[7.1,3.0,5.9,2.1]| Scala| 2.0|[7.1,3.0,5.9,2.1]| 2.0| Scala|
|[7.2,3.0,5.8,1.6]| Scala| 2.0|[7.2,3.0,5.8,1.6]| 0.0| spark|
|[7.2,3.2,6.0,1.8]| Scala| 2.0|[7.2,3.2,6.0,1.8]| 2.0| Scala|
|[7.2,3.6,6.1,2.5]| Scala| 2.0|[7.2,3.6,6.1,2.5]| 2.0| Scala|
|[7.4,2.8,6.1,1.9]| Scala| 2.0|[7.4,2.8,6.1,1.9]| 2.0| Scala|
|[7.7,2.6,6.9,2.3]| Scala| 2.0|[7.7,2.6,6.9,2.3]| 2.0| Scala|
|[7.7,2.8,6.7,2.0]| Scala| 2.0|[7.7,2.8,6.7,2.0]| 2.0| Scala|
+-----------------+------+------------+-----------------+----------+--------------+

RMSE: 0.43643578047198484
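For reference, this value is consistent with the prediction table above: 2 of the 42 test rows are mispredicted, each off by 2.0 on the indexed label, giving sqrt((2² + 2²) / 42) ≈ 0.4364.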
Decision tree regression model structure: DecisionTreeRegressionModel (uid=dtr_6015411b1a3d) of depth 4 with 11 nodes
  If (feature 3 <= 1.7)
   If (feature 2 <= 1.9)
    Predict: 1.0
   Else (feature 2 > 1.9)
    If (feature 2 <= 4.9)
     If (feature 3 <= 1.6)
      Predict: 0.0
     Else (feature 3 > 1.6)
      Predict: 2.0
    Else (feature 2 > 4.9)
     If (feature 3 <= 1.5)
      Predict: 2.0
     Else (feature 3 > 1.5)
      Predict: 0.0
  Else (feature 3 > 1.7)
   Predict: 2.0
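Besides toDebugString, the fitted DecisionTreeRegressionModel exposes a few summary properties. A minimal sketch, reusing treeModelRegressor from the program above:

// Minimal sketch: summary properties of the fitted tree model.
println("depth = " + treeModelRegressor.depth)
println("number of nodes = " + treeModelRegressor.numNodes)
println("feature importances = " + treeModelRegressor.featureImportances)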
