1. 程式人生 > Spark2 機器學習之決策樹分類 (Decision Tree Classifier)

Spark2 機器學習之決策樹分類 (Decision Tree Classifier)

show(10,truncate=false) +-------+------+----+------------+--------+-------------+---------+----------+------+ |affairs|gender|age |yearsmarried|children|religiousness|education|occupation|rating| +-------+------+----+------------+--------+-------------+---------+----------+------+ |0.0 |male |37.0|10.0
|no |3.0 |18.0 |7.0 |4.0 | |0.0 |female|27.0|4.0 |no |4.0 |14.0 |6.0 |4.0 | |0.0 |female|32.0|15.0 |yes |1.0 |12.0 |1.0 |4.0 | |0.0 |male |57.0|15.0 |yes |5.0 |18.0 |6.0 |5.0 | |0.0 |male |22.0
|0.75 |no |2.0 |17.0 |6.0 |3.0 | |0.0 |female|32.0|1.5 |no |2.0 |17.0 |5.0 |5.0 | |0.0 |female|22.0|0.75 |no |2.0 |12.0 |1.0 |3.0 | |0.0 |male |57.0|15.0 |yes |2.0 |14.0 |4.0 |4.0 | |0.0 |female|32.0
|15.0 |yes |4.0 |16.0 |1.0 |2.0 | |0.0 |male |22.0|1.5 |no |4.0 |14.0 |4.0 |5.0 | +-------+------+----+------------+--------+-------------+---------+----------+------+ only showing top 10 rows // 檢視資料分佈情況 data.describe("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating").show(10,truncate=false) +-------+------------------+------+-----------------+-----------------+--------+------------------+-----------------+-----------------+------------------+ |summary|affairs |gender|age |yearsmarried |children|religiousness |education |occupation |rating | +-------+------------------+------+-----------------+-----------------+--------+------------------+-----------------+-----------------+------------------+ |count |601 |601 |601 |601 |601 |601 |601 |601 |601 | |mean |1.4559068219633944|null |32.48752079866888|8.17769550748752 |null |3.1164725457570714|16.16638935108153|4.194675540765391|3.9317803660565724| |stddev |3.298757728494681 |null |9.28876170487667 |5.571303149963791|null |1.1675094016730692|2.402554565766698|1.819442662708579|1.1031794920503795| |min |0.0 |female|17.5 |0.125 |no |1.0 |9.0 |1.0 |1.0 | |max |12.0 |male |57.0 |15.0 |yes |5.0 |20.0 |7.0 |5.0 | +-------+------------------+------+-----------------+-----------------+--------+------------------+-----------------+-----------------+------------------+ data.createOrReplaceTempView("data") // 字元型別轉換成數值 val labelWhere = "case when affairs=0 then 0 else cast(1 as double) end as label" labelWhere: String = case when affairs=0 then 0 else cast(1 as double) end as label val genderWhere = "case when gender='female' then 0 else cast(1 as double) end as gender" genderWhere: String = case when gender='female' then 0 else cast(1 as double) end as gender val childrenWhere = "case when children='no' then 0 else cast(1 as double) end as children" childrenWhere: String = case when children='no' then 0 else cast(1 as double) end as children val dataLabelDF = 
spark.sql(s"select $labelWhere, $genderWhere,age,yearsmarried,$childrenWhere,religiousness,education,occupation,rating from data") dataLabelDF: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 7 more fields] val featuresArray = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating") featuresArray: Array[String] = Array(gender, age, yearsmarried, children, religiousness, education, occupation, rating) // 欄位轉換成特徵向量 val assembler = new VectorAssembler().setInputCols(featuresArray).setOutputCol("features") assembler: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_6e2c6bdd631e val vecDF: DataFrame = assembler.transform(dataLabelDF) vecDF: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 8 more fields] vecDF.show(10,truncate=false) +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ |label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ |0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]| |0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] | |0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]| |0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]| |0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]| |0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] | |0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]| |0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]| |0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]| |0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 
|[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ only showing top 10 rows // 索引標籤,將元資料新增到標籤列中 val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF) labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_d00cad619cd5 labelIndexer.transform(vecDF).show(10,truncate=false) +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+ |label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features |indexedLabel| +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+ |0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|0.0 | |0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |0.0 | |0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|0.0 | |0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|0.0 | |0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|0.0 | |0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |0.0 | |0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|0.0 | |0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|0.0 | |0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|0.0 | |0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |0.0 | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+ only showing top 10 rows // 自動識別分類的特徵,並對它們進行索引 // 具有大於8個不同的值的特徵被視為連續。 val featureIndexer = new 
VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(8).fit(vecDF) featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_8fbcad97fb60 featureIndexer.transform(vecDF).show(10,truncate=false) +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+----------------------------------+ |label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features |indexedFeatures | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+----------------------------------+ |0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|[1.0,37.0,6.0,0.0,2.0,5.0,6.0,3.0]| |0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |[0.0,27.0,4.0,0.0,3.0,2.0,5.0,3.0]| |0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|[0.0,32.0,7.0,1.0,0.0,1.0,0.0,3.0]| |0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|[1.0,57.0,7.0,1.0,4.0,5.0,5.0,4.0]| |0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|[1.0,22.0,2.0,0.0,1.0,4.0,5.0,2.0]| |0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |[0.0,32.0,3.0,0.0,1.0,4.0,4.0,4.0]| |0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|[0.0,22.0,2.0,0.0,1.0,1.0,0.0,2.0]| |0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|[1.0,57.0,7.0,1.0,1.0,2.0,3.0,3.0]| |0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|[0.0,32.0,7.0,1.0,3.0,3.0,0.0,1.0]| |0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |[1.0,22.0,3.0,0.0,3.0,2.0,3.0,4.0]| 
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+----------------------------------+ only showing top 10 rows // 將資料分為訓練和測試集(30%進行測試);此處未指定 seed,每次執行的切分結果(以及下方的準確率)都會不同,需要重現結果時可改用 randomSplit(Array(0.7, 0.3), seed) val Array(trainingData, testData) = vecDF.randomSplit(Array(0.7, 0.3)) trainingData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 8 more fields] testData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 8 more fields] // 訓練決策樹模型 val dt = new DecisionTreeClassifier() .setLabelCol("indexedLabel") .setFeaturesCol("indexedFeatures") .setImpurity("entropy") // 不純度 .setMaxBins(100) // 離散化"連續特徵"的最大分箱數(須 >= 2,且不得小於任一類別型特徵的類別數) .setMaxDepth(5) // 樹的最大深度 .setMinInfoGain(0.01) // 節點分裂所需的最小資訊增益,須 >= 0.0 .setMinInstancesPerNode(10) // 分裂後每個子節點至少須包含的樣本數,不足則放棄該分裂 .setSeed(123456) // 將索引標籤轉換回原始標籤 val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels) labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_2598e79a1d08 // Chain indexers and tree in a Pipeline. val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) // Train model. This also runs the indexers. val model = pipeline.fit(trainingData) // 作出預測 val predictions = model.transform(testData) predictions: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 
14 more fields] // 選擇幾個示例行展示 predictions.select("predictedLabel", "label", "features").show(10,truncate=false) +--------------+-----+-------------------------------------+ |predictedLabel|label|features | +--------------+-----+-------------------------------------+ |0.0 |0.0 |[0.0,22.0,0.125,0.0,2.0,14.0,4.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.125,0.0,2.0,16.0,6.0,3.0]| |0.0 |0.0 |[0.0,22.0,0.125,0.0,4.0,12.0,4.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.417,0.0,1.0,17.0,6.0,4.0]| |0.0 |0.0 |[0.0,22.0,0.75,0.0,2.0,16.0,5.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,1.0,14.0,1.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,2.0,14.0,5.0,4.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,2.0,16.0,5.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,3.0,16.0,6.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,4.0,17.0,5.0,5.0] | +--------------+-----+-------------------------------------+ only showing top 10 rows // 選擇(預測標籤,實際標籤),並計算測試誤差。 val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) accuracy: Double = 0.7032967032967034 println("Test Error = " + (1.0 - accuracy)) Test Error = 0.29670329670329665 // 這裡的stages(2)中的“2”對應pipeline中的“dt”,將model強制轉換為DecisionTreeClassificationModel型別 val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] treeModel: org.apache.spark.ml.classification.DecisionTreeClassificationModel = DecisionTreeClassificationModel (uid=dtc_7a8baf97abe7) of depth 5 with 33 nodes treeModel.getLabelCol res53: String = indexedLabel treeModel.getFeaturesCol res54: String = indexedFeatures treeModel.featureImportances res55: org.apache.spark.ml.linalg.Vector = (8,[0,2,3,4,5,6,7],[0.0640344247735859,0.1052957011097811,0.05343872372010684,0.17367191628391196,0.20372870264756315,0.2063093687074741,0.1935211627575769]) treeModel.getPredictionCol res56: String = prediction treeModel.getProbabilityCol res57: String = probability treeModel.numClasses res58: Int = 2 
treeModel.numFeatures res59: Int = 8 treeModel.depth res60: Int = 5 treeModel.numNodes res61: Int = 33 treeModel.getImpurity res62: String = entropy treeModel.getMaxBins res63: Int = 100 treeModel.getMaxDepth res64: Int = 5 treeModel.getMaxMemoryInMB res65: Int = 256 treeModel.getMinInfoGain res66: Double = 0.01 treeModel.getMinInstancesPerNode res67: Int = 10 // 檢視決策樹。閱讀說明:feature i 對應 featuresArray 中的第 i 個欄位(由 0 起算,例如 feature 2 = yearsmarried);類別型特徵條件中 {…} 內的數字是 VectorIndexer 產生的類別索引,而非原始值(例如前面表格中 yearsmarried 的 15.0 被索引為 7.0);Predict 的值是 indexedLabel 的索引,可經 labelConverter 轉回原始標籤 println("Learned classification tree model:\n" + treeModel.toDebugString) Learned classification tree model: DecisionTreeClassificationModel (uid=dtc_7a8baf97abe7) of depth 5 with 33 nodes If (feature 2 in {0.0,1.0,2.0,3.0}) If (feature 5 in {3.0,6.0}) Predict: 0.0 Else (feature 5 not in {3.0,6.0}) If (feature 4 in {3.0}) Predict: 0.0 Else (feature 4 not in {3.0}) If (feature 3 in {0.0}) If (feature 6 in {0.0,4.0,5.0}) Predict: 0.0 Else (feature 6 not in {0.0,4.0,5.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 0.0 Else (feature 2 not in {0.0,1.0,2.0,3.0}) If (feature 4 in {0.0,1.0,3.0,4.0}) If (feature 7 in {0.0,1.0,2.0}) If (feature 6 in {0.0,1.0,6.0}) If (feature 4 in {1.0,4.0}) Predict: 0.0 Else (feature 4 not in {1.0,4.0}) Predict: 0.0 Else (feature 6 not in {0.0,1.0,6.0}) If (feature 7 in {0.0,2.0}) Predict: 0.0 Else (feature 7 not in {0.0,2.0}) Predict: 1.0 Else (feature 7 not in {0.0,1.0,2.0}) If (feature 5 in {0.0,1.0}) Predict: 0.0 Else (feature 5 not in {0.0,1.0}) If (feature 6 in {0.0,1.0,2.0,5.0,6.0}) Predict: 0.0 Else (feature 6 not in {0.0,1.0,2.0,5.0,6.0}) Predict: 0.0 Else (feature 4 not in {0.0,1.0,3.0,4.0}) If (feature 5 in {0.0,1.0,2.0,3.0,5.0,6.0}) If (feature 0 in {0.0}) If (feature 7 in {3.0}) Predict: 0.0 Else (feature 7 not in {3.0}) Predict: 0.0 Else (feature 0 not in {0.0}) If (feature 7 in {0.0,2.0,4.0}) Predict: 0.0 Else (feature 7 not in {0.0,2.0,4.0}) Predict: 1.0 Else (feature 5 not in {0.0,1.0,2.0,3.0,5.0,6.0}) Predict: 1.0