1. 程式人生 > >Spark機器學習(6):決策樹算法

Spark機器學習(6):決策樹算法

projects 信息 txt .cn import n) .com util seq

1. 決策樹基本知識

決策樹就是通過一系列規則對數據進行分類的一種算法,可以分為分類樹和回歸樹兩類,分類樹處理離散變量的,回歸樹是處理連續變量。

樣本一般都有很多個特征,有的特征對分類起很大的作用,有的特征對分類作用很小,甚至沒有作用。如決定是否對一個人貸款是,這個人的信用記錄、收入等就是主要的判斷依據,而性別、婚姻狀況等等就是次要的判斷依據。決策樹構建的過程,就是根據特征的決定性程度,先使用決定性程度高的特征分類,再使用決定性程度低的特征分類,這樣構建出一棵倒立的樹,就是我們需要的決策樹模型,可以用來對數據進行分類。

決策樹學習的過程可以分為三個步驟:1)特征選擇,即從眾多特征中選擇出一個作為當前節點的分類標準;2)決策樹生成,從上到下構建節點;3)剪枝,為了預防和消除過擬合,需要對決策樹剪枝。

2. 決策樹算法

主要的決策樹算法包括ID3、C4.5和CART。

ID3把信息增益作為選擇特征的標準。由於取值較多的特征(如學號)的信息增益比較大,這種算法會偏向於取值較多的特征。而且該算法只能用於離散型的數據,優點是不需要剪枝。

C4.5和ID3比較類似,區別在於使用信息增益比替代信息增益作為選擇特征的標準,因此比ID3更加科學,並且可以用於連續型的數據,但是需要剪枝。

CART(Classification And Regression Tree)采用的是Gini作為選擇的標準。Gini越大,說明不純度越大,這個特征就越不好。

3. MLlib的決策樹算法

MLlib的決策樹算法使用的隨機森林RandomForest的方法,不過並不是真正的隨機森林,因為實際上只有一棵決策樹。

直接上代碼:

import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

/**
  * Created by Administrator on 2017/7/6.
  
*/ object DecisionTreeTest { def main(args: Array[String]): Unit = { // 設置運行環境 val conf = new SparkConf().setAppName("Decision Tree") .setMaster("spark://master:7077").setJars(Seq("E:\\Intellij\\Projects\\MachineLearning\\MachineLearning.jar")) val sc = new SparkContext(conf) Logger.getRootLogger.setLevel(Level.WARN) // 讀取樣本數據並解析 val dataRDD = MLUtils.loadLibSVMFile(sc, "hdfs://master:9000/ml/data/sample_dt_data.txt") // 樣本數據劃分,訓練樣本占0.8,測試樣本占0.2 val dataParts = dataRDD.randomSplit(Array(0.8, 0.2)) val trainRDD = dataParts(0) val testRDD = dataParts(1) // 決策樹參數 val numClasses = 5 val categoricalFeaturesInfo = Map[Int, Int]() val impurity = "gini" val maxDepth = 5 val maxBins = 32 // 建立決策樹模型並訓練 val model = DecisionTree.trainClassifier(trainRDD, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins) // 對測試樣本進行測試 val predictionAndLabel = testRDD.map { point => val score = model.predict(point.features) (score, point.label, point.features) } val showPredict = predictionAndLabel.take(50) println("Prediction" + "\t" + "Label" + "\t" + "Data") for (i <- 0 to showPredict.length - 1) { println(showPredict(i)._1 + "\t" + showPredict(i)._2 + "\t" + showPredict(i)._3) } // 誤差計算 val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / testRDD.count() println("Accuracy = " + accuracy) // 保存模型 val ModelPath = "hdfs://master:9000/ml/model/Decision_Tree_Model" model.save(sc, ModelPath) val sameModel = DecisionTreeModel.load(sc, ModelPath) }

運行結果:

技術分享

Spark機器學習(6):決策樹算法