
Spark Random Forest: Algorithm Principles, Source Code Analysis and a Hands-On Example

Figure 1. Activity of Spark compared with other big data processing tools

Environment requirements

  1. Operating system: Linux. This article uses Ubuntu 10.04; use whichever Linux distribution you are most comfortable with.
  2. Java and Scala versions: Scala 2.10.4, Java 1.7.
  3. Spark cluster environment (3 nodes): Hadoop 2.4.1 + Spark 1.4.0. For how to set up the cluster, see the author's blog post: http://blog.csdn.net/lovehuangjiaju/article/details/46883973
  4. Environment for source code reading and the hands-on example: IntelliJ IDEA 14.1.4

Decision trees

The random forest algorithm is used extremely widely in machine learning, computer vision and other fields. It can be used both for classification and for regression (prediction). A random forest consists of many decision trees; compared with a single decision tree it classifies and predicts better and is far less prone to overfitting.

Random forests are built on decision trees, so before describing random forests themselves we first review how a decision tree works. A decision tree is an important classifier in data mining and machine learning: the algorithm builds a tree from training data and then uses it to classify unseen data efficiently. A matchmaking example illustrates what a decision tree is, how it is built, and how it is used for classification. Suppose a dating site, after examining its historical data, finds that the girls behaved as follows when deciding whether to meet a suitor:

Table 1. Historical dating data
No. | Owns property in the city | Marital history (divorced / single) | Annual income (10k RMB) | Meet (yes / no)
1 |  | Single | 12 |
2 |  | Single | 15 |
3 |  | Divorced | 10 |
4 |  | Single | 18 |
5 |  | Divorced | 25 |
6 |  | Single | 50 |
7 |  | Divorced | 35 |
8 |  | Divorced | 40 |
9 |  | Single | 60 |
10 |  | Divorced | 17 |

From the historical data in Table 1, a decision tree such as the following can be built:

Figure 2. Decision tree

If a newly registered user owns no property in the city, earns less than 350,000 RMB a year and has been divorced, the tree predicts that the girl will not meet him (a small code sketch of this tree follows the list of steps below). This simple example shows how useful decision trees can be in practice, and it also lets us summarize the steps for building one:

  1. Treat all records as a single node.
  2. Enumerate every possible way of splitting on every variable and find the best split point.
  3. Use that split point to divide the records into two child nodes C1 and C2.
  4. Repeat steps 2) and 3) on C1 and C2 until a stopping condition is met.
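To make the prediction above concrete, the decision tree of Figure 2 can be read as a chain of nested conditions. The sketch below is one plausible hand-coded version of that tree; the field names and the test income of 30 are illustrative assumptions, not taken from the article:

// One possible hand-written encoding of the tree in Figure 2 (illustrative only).
case class Suitor(hasHouse: Boolean, divorced: Boolean, income: Double)

def willMeet(s: Suitor): Boolean =
  if (s.hasHouse) true            // owns property in the city -> meet
  else if (s.income >= 35) true   // no property but high income -> meet
  else !s.divorced                // lower income: meet only if never divorced

// The newly registered user: no property, income below 35 (10k RMB), divorced.
println(willMeet(Suitor(hasHouse = false, divorced = true, income = 30)))  // false -> no meeting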

The most important part of building a decision tree is finding the best split point. What makes a split point the best? If a split point separates the records exactly into the target classes, it can be considered the best; the two groups it produces are then as "pure" as possible. In the example above, "owns property in the city" splits all records into two groups: everyone who answered "yes" goes into one group and everyone who answered "no" into the other. The "yes" group is completely pure, because every man who owns property in the city gets a meeting, regardless of whether he has been divorced or how much he earns. The "no" group still contains both meetings and refusals, so it is not pure by itself, but over the whole data set this split is the purest one available.

The example also shows that a decision tree can handle both continuous and categorical variables. A continuous variable such as annual income can be split with a ">=", ">", "<" or "<=" condition, while a categorical variable such as property ownership takes values from a finite set (here "yes" and "no") and is split with an "=" condition.
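In code, the two kinds of split conditions are naturally represented by different data: a continuous split carries a threshold, while a categorical split carries a set of category values. Spark's internal Split class follows the same idea; the sketch below is a simplified stand-in with made-up names, not Spark's actual API:

// Simplified sketch of the two kinds of split conditions (illustrative names only).
sealed trait SimpleSplit { def goesLeft(value: Double): Boolean }

// Continuous feature: records with value <= threshold go to the left child.
case class ContinuousSplit(threshold: Double) extends SimpleSplit {
  def goesLeft(value: Double): Boolean = value <= threshold
}

// Categorical feature: records whose category is in the chosen set go to the left child.
case class CategoricalSplit(categories: Set[Double]) extends SimpleSplit {
  def goesLeft(value: Double): Boolean = categories.contains(value)
}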

As mentioned above, the best split point is found by quantifying how pure the groups after the split are. Three purity measures are commonly used: Gini impurity, entropy and misclassification error, defined as follows:
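The standard definitions, for a node whose records fall into C classes, are:

Gini(T) = 1 - \sum_{i=1}^{C} P(i)^2

Entropy(T) = -\sum_{i=1}^{C} P(i) \log_2 P(i)

Error(T) = 1 - \max_{1 \le i \le C} P(i)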

In these formulas P(i) is the fraction of records that belong to class i. In the dating example the records fall into two classes, meet and not meet, with P(1) = 9/10 for "meet" and P(2) = 1/10 for "not meet". For all three measures a larger value means a less pure node and a smaller value a purer one. Gini impurity is the most commonly used in practice and is also the measure used in the examples later in this article.
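As a quick check in code (a few standalone lines, not from Spark), the Gini impurity of the dating data above works out to 1 - (0.9^2 + 0.1^2) = 0.18:

// Gini impurity of a node, given the class proportions at that node.
def gini(p: Seq[Double]): Double = 1.0 - p.map(x => x * x).sum

// The dating example: 9/10 of the records meet, 1/10 do not.
println(gini(Seq(0.9, 0.1)))  // 0.18 -- quite pure
println(gini(Seq(0.5, 0.5)))  // 0.5  -- the most impure two-class node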

Building a decision tree is a recursive process. Ideally every record would be classified exactly, i.e. every leaf of the tree would end up with a single well-defined class, but in practice this is rarely achievable, which can make it hard for the construction to terminate. Even when it does finish, the tree often has far too many nodes and therefore overfits. In practice a stopping condition is set and construction halts as soon as it is reached, but that alone does not eliminate overfitting. The typical symptom of overfitting is a very low error rate on the training data combined with a very high error rate on the test data. Its common causes are (1) noise in the training data and (2) training data that is not representative. Since an overfitted tree usually has too many nodes, the constructed tree is often pruned afterwards (prune tree), but pruning does not address the root cause; the random forest algorithm handles overfitting much better.

The random forest algorithm

A random forest is a forest of decision trees whose classification result is obtained by letting those trees vote. Randomness is injected in two directions while each tree is grown: in the row direction the training data for each tree is drawn by sampling with replacement (bootstrapping), and in the column direction a random subset of the features is drawn without replacement and the best split is chosen only within that subset. That is the basic principle of the random forest algorithm. Figure 3 shows how a random forest classifies: it is an ensemble model that is still based on decision trees internally, but unlike a single decision tree it classifies by the vote of many trees, which makes it much less prone to overfitting.
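The two sources of randomness and the final vote can be summarized in a few lines of Scala. This is a conceptual sketch only, not Spark's implementation; trainTree stands in for any single-tree learner and all names are made up:

import scala.util.Random

// Conceptual sketch of random forest training (not Spark's implementation).
def trainForest[T](data: IndexedSeq[T], treeCount: Int, rng: Random)
                  (trainTree: Seq[T] => (T => Int)): Seq[T => Int] =
  (1 to treeCount).map { _ =>
    // Row direction: bootstrap sample (sampling with replacement).
    val sample = IndexedSeq.fill(data.length)(data(rng.nextInt(data.length)))
    // Column direction: each tree (in Spark, each tree node) would also draw a random
    // subset of features before choosing its best split; omitted here.
    trainTree(sample)
  }

// Classification: majority vote over all trees.
def predictByVote[T](forest: Seq[T => Int], x: T): Int =
  forest.map(tree => tree(x)).groupBy(identity).maxBy(_._2.size)._1

Spark follows the same idea, but it draws the feature subset per tree node rather than per tree and grows all trees level by level, as described in the next section.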

Figure 3. Random forest

Optimizing random forests for a distributed environment

The random forest algorithm is easy to implement on a single machine, but in a distributed environment, and on Spark in particular, the traditional single-machine iteration has to be adapted, because the data themselves are distributed (as shown in Figure 5). A poorly designed algorithm generates a large amount of I/O, such as frequent network transfers of data, and efficiency suffers.

Figure 4. Data storage on a single machine
Figure 5. Data storage in a distributed environment

Implementing random forests on Spark therefore requires some optimization. Spark's random forest implementation applies three main optimization strategies:

  1. Sampling-based split-point statistics, as shown in Figure 6. On a single machine, a decision tree chooses split points for a continuous variable by sorting the feature values and taking points between adjacent values. That works on one machine, but doing the same in a distributed environment would generate a huge amount of network traffic, and at PB scale the algorithm would become extremely slow. To avoid this, the random forest implementation in Spark samples each partition when building the trees, computes statistics per partition, and derives the split points from those statistics.
  2. Feature binning, as shown in Figure 7. Building a decision tree is essentially a process of repeatedly partitioning feature values. A categorical feature with M values allows at most 2^(M-1)-1 partitions if the values are unordered, and at most M-1 if they are ordered (see the short sketch after this list). Take an age feature with the three values old, middle-aged and young: unordered, there are 2^(3-1)-1 = 3 partitions (old | middle-aged, young; old, middle-aged | young; old, young | middle-aged); ordered (old, middle-aged, young), there are only M-1 = 2 partitions (old | middle-aged, young; old, middle-aged | young). For a continuous feature, partitioning means dividing its range: each dividing point is a split and each resulting interval is a bin. In theory a continuous feature has infinitely many possible splits, and in a distributed environment all values cannot be enumerated, so the sampling-based split statistics of point (1) are used.
  3. Level-wise training, as shown in Figure 8. A single-machine decision tree is usually grown by recursion (essentially depth-first), moving the data as the tree grows so that records belonging to the same child node sit together. This cannot be done efficiently on distributed data structures; in fact it cannot be done at all, because the data are too large to be gathered in one place. In a distributed environment the tree nodes are therefore built level by level (essentially breadth-first), so the number of passes over the data equals the maximum depth over all trees. Each pass only computes the split statistics of every node in the current level; after the pass, each node decides from those statistics whether to split and how.
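The split counts quoted in point 2 are easy to check: an unordered categorical feature with m values allows 2^(m-1) - 1 binary partitions, while an ordered one allows m - 1. A tiny sketch:

// Number of candidate binary partitions of a categorical feature with m values.
def unorderedSplits(m: Int): Int = (1 << (m - 1)) - 1  // 2^(m-1) - 1
def orderedSplits(m: Int): Int   = m - 1

// The age feature with 3 values (old, middle-aged, young):
println(unorderedSplits(3))  // 3
println(orderedSplits(3))    // 2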
Figure 6. Sampling-based split-point statistics
Figure 7. Feature binning
Figure 8. Level-wise training

Random forest source code analysis

With the principles of decision trees and random forests and Spark's optimization strategies in mind, this section analyses the random forest source code in Spark MLlib. We first show the usage demo from the official documentation and then step into the corresponding methods to analyse how they are implemented.

Listing 1. Random forest usage demo
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
// Load the data
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Randomly split the data into two parts: one for training, one for testing
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// Random forest training parameters
// Number of classes
val numClasses = 2
// An empty categoricalFeaturesInfo means all features are continuous
val categoricalFeaturesInfo = Map[Int, Int]()
// Number of trees
val numTrees = 3
// Feature subset sampling strategy; "auto" lets the algorithm choose
val featureSubsetStrategy = "auto"
// Impurity measure
val impurity = "gini"
// Maximum tree depth
val maxDepth = 4
// Maximum number of bins per feature
val maxBins = 32
// Train the random forest classifier; trainClassifier returns a RandomForestModel
val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
 numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

// Evaluate the trained classifier on the test data and compute the error rate
val labelAndPreds = testData.map { point =>
 val prediction = model.predict(point.features)
 (point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification forest model:\n" + model.toDebugString)

// Persist the trained random forest model
model.save(sc, "myModelPath")
// Load the random forest model back into memory
val sameModel = RandomForestModel.load(sc, "myModelPath")

The example shows that, from a user's point of view, the key classes are org.apache.spark.mllib.tree.RandomForest and org.apache.spark.mllib.tree.model.RandomForestModel, which provide the random forest's trainClassifier and predict functions.

As the demo shows, training calls the trainClassifier method of the RandomForest companion object. Its source code (the preceding comments and parameter descriptions are kept for clarity) is as follows:

Listing 2. Core source code analysis 1
/**
 * Method to train a decision tree model for binary or multiclass classification.
 *
 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
 * Labels should take values {0, 1, ..., numClasses-1}.
 * @param numClasses number of classes for classification.
 * @param categoricalFeaturesInfo Map storing arity of categorical features.
 * E.g., an entry (n -> k) indicates that feature n is categorical
 * with k categories indexed from 0: {0, 1, ..., k-1}.
 * @param numTrees Number of trees in the random forest.
 * @param featureSubsetStrategy Number of features to consider for splits at each node.
 * Supported: "auto", "all", "sqrt", "log2", "onethird".
 * If "auto" is set, this parameter is set based on numTrees:
 * if numTrees == 1, set to "all";
 * if numTrees > 1 (forest) set to "sqrt".
 * @param impurity Criterion used for information gain calculation.
 * Supported values: "gini" (recommended) or "entropy".
 * @param maxDepth Maximum depth of the tree.
 * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
 * (suggested value: 4)
 * @param maxBins maximum number of bins used for splitting features
 * (suggested value: 100)
 * @param seed Random seed for bootstrapping and choosing feature subsets.
 * @return a random forest model that can be used for prediction
 */
 def trainClassifier(
 input: RDD[LabeledPoint],
 numClasses: Int,
 categoricalFeaturesInfo: Map[Int, Int],
 numTrees: Int,
 featureSubsetStrategy: String,
 impurity: String,
 maxDepth: Int,
 maxBins: Int,
 seed: Int = Utils.random.nextInt()): RandomForestModel = {
 val impurityType = Impurities.fromString(impurity)
 val strategy = new Strategy(Classification, impurityType, maxDepth,
 numClasses, maxBins, Sort, categoricalFeaturesInfo)
 //This calls the other, overloaded trainClassifier shown next
 trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed)
 }

The overloaded trainClassifier looks like this:

Listing 3. Core source code analysis 2
/**
 * Method to train a decision tree model for binary or multiclass classification.
 *
 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
 * Labels should take values {0, 1, ..., numClasses-1}.
 * @param strategy Parameters for training each tree in the forest.
 * @param numTrees Number of trees in the random forest.
 * @param featureSubsetStrategy Number of features to consider for splits at each node.
 * Supported: "auto", "all", "sqrt", "log2", "onethird".
 * If "auto" is set, this parameter is set based on numTrees:
 * if numTrees == 1, set to "all";
 * if numTrees > 1 (forest) set to "sqrt".
 * @param seed Random seed for bootstrapping and choosing feature subsets.
 * @return a random forest model that can be used for prediction
 */
 def trainClassifier(
 input: RDD[LabeledPoint],
 strategy: Strategy,
 numTrees: Int,
 featureSubsetStrategy: String,
 seed: Int): RandomForestModel = {
 require(strategy.algo == Classification,
 s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
 //Create the RandomForest object in this method
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
//then call its run method, passing in an RDD[LabeledPoint]; the method returns a RandomForestModel instance
 rf.run(input)
 }

Stepping into the run method of RandomForest:

Listing 4. Core source code analysis 3
/**
 * Method to train a decision tree model over an RDD
 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
 * @return a random forest model that can be used for prediction
 */
 def run(input: RDD[LabeledPoint]): RandomForestModel = {

 val timer = new TimeTracker()

 timer.start("total")

 timer.start("init")
 
val retaggedInput = input.retag(classOf[LabeledPoint])
//Build the decision tree metadata (split positions, number of bins, the feature values each bin contains, and so on)
 val metadata =
 DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
 logDebug("algo = " + strategy.algo)
 logDebug("numTrees = " + numTrees)
 logDebug("seed = " + seed)
 logDebug("maxBins = " + metadata.maxBins)
 logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
 logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
 logDebug("subsamplingRate = " + strategy.subsamplingRate)

 // Find the splits and the corresponding bins (interval between the splits) using a sample
 // of the input data.
timer.start("findSplitsBins")
//Find the split points (splits) and bin information (bins)
//For continuous features, sampling-based split statistics simplify the computation
//For a categorical feature with M values: if unordered, there are at most 2^(M-1)-1 splits
//if ordered, at most M-1 splits
 val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
 timer.stop("findSplitsBins")
 logDebug("numBins: feature: number of bins")
 logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
 s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
 }.mkString("\n"))

 // Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
//Convert to the TreePoint RDD representation; afterwards every sample has been assigned to its bin according to the split conditions
 val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)

 val withReplacement = if (numTrees > 1) true else false

// convertToBaggedRDD gives each tree its own (bootstrap) subset of the samples
 val baggedInput
 = BaggedPoint.convertToBaggedRDD(treeInput,
 strategy.subsamplingRate, numTrees,
 withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)

 // depth of the decision tree
 val maxDepth = strategy.maxDepth
 require(maxDepth <= 30,
 s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")

 // Max memory usage for aggregates
 // TODO: Calculate memory usage more precisely.
 val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
 val maxMemoryPerNode = {
 val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
 // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
 Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
 .take(metadata.numFeaturesPerNode).map(_._2))
 } else {
 None
 }
 //Memory needed per node for the aggregation
 RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
}
require(maxMemoryPerNode <= maxMemoryUsage,
 s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
 " which is too small for the given features." +
 s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")

 timer.stop("init")

 /*
 * The main idea here is to perform group-wise training of the decision tree nodes thus
 * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
 * Each data sample is handled by a particular node (or it reaches a leaf and is not used
 * in lower levels).
 */

 // Create an RDD of node Id cache.
// At first, all the rows belong to the root nodes (node Id == 1).
//Whether to cache node IDs; IDs start at 1, so 1 is a tree's root, its left child is 2, its right child is 3, and so on
 val nodeIdCache = if (strategy.useNodeIdCache) {
 Some(NodeIdCache.init(
 data = baggedInput,
 numTrees = numTrees,
 checkpointInterval = strategy.checkpointInterval,
 initVal = 1))
 } else {
 None
 }

 // FIFO queue of nodes to train: (treeIndex, node)
 val nodeQueue = new mutable.Queue[(Int, Node)]()

 val rng = new scala.util.Random()
 rng.setSeed(seed)

// Allocate and queue root nodes.
//Create the root node of each tree
val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
//Enqueue (tree index, root node of that tree); tree indices start at 0, node IDs at 1
 Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

 while (nodeQueue.nonEmpty) {
 // Collect some nodes to split, and choose features for each node (if subsampling).
 // Each group of nodes may come from one or multiple trees, and at multiple levels.
 // Collect, for every tree, the nodes that still need to be split
 val (nodesForGroup, treeToNodeToIndexInfo) =
 RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
 // Sanity check (should never occur):
 assert(nodesForGroup.size > 0,
 s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

 // Choose node splits, and enqueue new nodes as needed.
 timer.start("findBestSplits")
 //Find the best splits
 DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
 treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
 timer.stop("findBestSplits")
 }

 baggedInput.unpersist()

 timer.stop("total")

 logInfo("Internal timing for DecisionTree:")
 logInfo(s"$timer")

 // Delete any remaining checkpoints used for node Id cache.
 if (nodeIdCache.nonEmpty) {
 try {
 nodeIdCache.get.deleteAllCheckpoints()
 } catch {
 case e: IOException =>
 logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
 }
 }

 val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
 new RandomForestModel(strategy.algo, trees)
 }

}

The code above is the core run method of the RandomForest class. When determining the splits and bin information it calls DecisionTree.findSplitsBins. Stepping into that method, we find the following:

Listing 5. Core source code analysis 4
 /**
 * Returns splits and bins for decision tree calculation.
 * Continuous and categorical features are handled differently.
 *
 * Continuous features:
 * For each feature, there are numBins - 1 possible splits representing the possible binary
 * decisions at each node in the tree.
 * This finds locations (feature values) for splits using a subsample of the data.
 *
 * Categorical features:
 * For each feature, there is 1 bin per split.
 * Splits and bins are handled in 2 ways:
 * (a) "unordered features"
 * For multiclass classification with a low-arity feature
 * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
 * the feature is split based on subsets of categories.
 * (b) "ordered features"
 * For regression and binary classification,
 * and for multiclass classification with a high-arity feature,
 * there is one bin per category.
 *
 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
 * @param metadata Learning and dataset metadata
 * @return A tuple of (splits, bins).
 * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
 * of size (numFeatures, numSplits).
 * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
 * of size (numFeatures, numBins).
 */
 protected[tree] def findSplitsBins(
 input: RDD[LabeledPoint],
 metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {

 logDebug("isMulticlass = " + metadata.isMulticlass)

 val numFeatures = metadata.numFeatures

// Sample the input only if there are continuous features.
// Check whether any feature is continuous
 val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
 val sampledInput = if (hasContinuousFeatures) {
 // Calculate the number of samples for approximate quantile calculation.
 //Number of samples to draw; at least 10,000
 val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
 //Compute the sampling fraction
 val fraction = if (requiredSamples < metadata.numExamples) {
 requiredSamples.toDouble / metadata.numExamples
 } else {
 1.0
 }
 logDebug("fraction of data used for calculating quantiles = " + fraction)
 input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
} else {
 //If all features are categorical, build an empty array (no sampling needed)
 new Array[LabeledPoint](0)
 }

 //Split-point strategy; Spark currently implements only one: Sort
 metadata.quantileStrategy match {
 case Sort =>
 //One array of split positions per feature
 val splits = new Array[Array[Split]](numFeatures)
 //The bins corresponding to those split positions
 val bins = new Array[Array[Bin]](numFeatures)

 // Find all splits.
 // Iterate over all features.
 var featureIndex = 0
 //Iterate over all features
 while (featureIndex < numFeatures) {
 //Continuous feature
 if (metadata.isContinuous(featureIndex)) {
 val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
 // findSplitsForContinuousFeature returns all split positions of a continuous feature
 val featureSplits = findSplitsForContinuousFeature(featureSamples,
 metadata, featureIndex)

 val numSplits = featureSplits.length
 //A continuous feature has one more bin than it has splits
 val numBins = numSplits + 1
 logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")

 //Arrays of splits and bins for this feature
 splits(featureIndex) = new Array[Split](numSplits)
 bins(featureIndex) = new Array[Bin](numBins)

 var splitIndex = 0
 //Iterate over the splits
 while (splitIndex < numSplits) {
 //Take the value at this split; the samples are sorted, so this value acts as a threshold
 val threshold = featureSplits(splitIndex)
 //Store every split position of this feature
 splits(featureIndex)(splitIndex) =
 new Split(featureIndex, threshold, Continuous, List())
 splitIndex += 1
 }
 //The first bin uses a dummy low split (Double.MinValue) as its left boundary
 bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
 splits(featureIndex)(0), Continuous, Double.MinValue)

 splitIndex = 1
 //Build the remaining bins except the last one; each bin covers the interval between two consecutive split thresholds
 while (splitIndex < numSplits) {
 bins(featureIndex)(splitIndex) =
 new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
 Continuous, Double.MinValue)
 splitIndex += 1
 }
 //The last bin uses a dummy high split (Double.MaxValue) as its right boundary
 bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
 new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
 } else { //Categorical feature
 val numSplits = metadata.numSplits(featureIndex)
 val numBins = metadata.numBins(featureIndex)
 // Categorical feature
 //Number of categories of this feature
 val featureArity = metadata.featureArity(featureIndex)
 //Unordered feature
 if (metadata.isUnordered(featureIndex)) {
 // Unordered features
 // 2^(maxFeatureValue - 1) - 1 combinations
 splits(featureIndex) = new Array[Split](numSplits)
 var splitIndex = 0
 while (splitIndex < numSplits) {
 //Extract the categories for this split; the returned list contains one or more of the feature's discrete values
 val categories: List[Double] =
 extractMultiClassCategories(splitIndex + 1, featureArity)
 splits(featureIndex)(splitIndex) =
 new Split(featureIndex, Double.MinValue, Categorical, categories)
 splitIndex += 1
 }
 } else {
 //Ordered features need no precomputation here; bins correspond directly to feature values
 // Ordered features
 // Bins correspond to feature values, so we do not need to compute splits or bins
 // beforehand. Splits are constructed as needed during training.
 splits(featureIndex) = new Array[Split](0)
 }
 // For ordered features, bins correspond to feature values.
 // For unordered categorical features, there is no need to construct the bins.
 // since there is a one-to-one correspondence between the splits and the bins.
 bins(featureIndex) = new Array[Bin](0)
 }
 featureIndex += 1
 }
 (splits, bins)
 case MinMax =>
 throw new UnsupportedOperationException("minmax not supported yet.")
 case ApproxHist =>
 throw new UnsupportedOperationException("approximate histogram not supported yet.")
 }
}

Besides findSplitsBins there is another very important method, DecisionTree.findBestSplits(), which searches for the best split point. Its key step is the call to binsToBestSplit, whose code is shown below:

Listing 6. Core source code analysis 5
/**
 * Find the best split for a node.
 * @param binAggregates Bin statistics.
 * @return tuple for best split: (Split, information gain, prediction at node)
 */
 private def binsToBestSplit(
 binAggregates: DTStatsAggregator, // DTStatsAggregator holds an ImpurityAggregator, which provides the impurity computation
 splits: Array[Array[Split]],
 featuresForNode: Option[Array[Int]],
 node: Node): (Split, InformationGainStats, Predict) = {

 // calculate predict and impurity if current node is top node
 val level = Node.indexToLevel(node.id)
 var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
 None
 } else {
 Some((node.predict, node.impurity))
 }

// For each (feature, split), calculate the gain, and select the best (feature, split).
//For every feature and split, compute the information gain and choose the best (feature, split)
 val (bestSplit, bestSplitStats) =
 Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
 val featureIndex = if (featuresForNode.nonEmpty) {
 featuresForNode.get.apply(featureIndexIdx)
 } else {
 featureIndexIdx
 }
 val numSplits = binAggregates.metadata.numSplits(featureIndex)
 //Continuous feature
 if (binAggregates.metadata.isContinuous(featureIndex)) {
 // Cumulative sum (scanLeft) of bin statistics.
 // Afterwards, binAggregates for a bin is the sum of aggregates for
 // that bin + all preceding bins.
 val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
 var splitIndex = 0
 while (splitIndex < numSplits) {
 binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
 splitIndex += 1
 }
 // Find best split.
 val (bestFeatureSplitIndex, bestFeatureGainStats) =
 Range(0, numSplits).map { case splitIdx =>
 //Compute the impurity of the left and right child nodes
 val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
 val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
 rightChildStats.subtract(leftChildStats)
 //Compute the parent node's prediction and impurity from the combined child statistics
 predictWithImpurity = Some(predictWithImpurity.getOrElse(
 calculatePredictImpurity(leftChildStats, rightChildStats)))
 //Compute the information gain, used to decide whether this split is the best
 val gainStats = calculateGainForSplit(leftChildStats,
 rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
 (splitIdx, gainStats)
 }.maxBy(_._2.gain)
 (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
 } else if (binAggregates.metadata.isUnordered(featureIndex)) { //Unordered categorical feature
 // Unordered categorical feature
 val (leftChildOffset, rightChildOffset) =
 binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
 val (bestFeatureSplitIndex, bestFeatureGainStats) =
 Range(0, numSplits).map { splitIndex =>
 val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
 val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
 predictWithImpurity = Some(predictWithImpurity.getOrElse(
 calculatePredictImpurity(leftChildStats, rightChildStats)))
 val gainStats = calculateGainForSplit(leftChildStats,
 rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
 (splitIndex, gainStats)
 }.maxBy(_._2.gain)
 (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
 } else { //Ordered categorical feature
 // Ordered categorical feature
 val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
 val numBins = binAggregates.metadata.numBins(featureIndex)

 /* Each bin is one category (feature value).
 * The bins are ordered based on centroidForCategories, and this ordering determines which
 * splits are considered. (With K categories, we consider K - 1 possible splits.)
 *
 * centroidForCategories is a list: (category, centroid)
 */
 //Multiclass classification
 val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
 // For categorical variables in multiclass classification,
 // the bins are ordered by the impurity of their corresponding labels.
 Range(0, numBins).map { case featureValue =>
 val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
 val centroid = if (categoryStats.count != 0) {
 // The centroid is the impurity of this category's label distribution
 categoryStats.calculate()
 } else {
 Double.MaxValue
 }
 (featureValue, centroid)
 }
 } else { // Regression or binary classification
 // For categorical variables in regression and binary classification,
 // the bins are ordered by the centroid of their corresponding labels.
 Range(0, numBins).map { case featureValue =>
 val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
 val centroid = if (categoryStats.count != 0) {
 //The centroid is the mean label (the predicted value) of this category
 categoryStats.predict
 } else {
 Double.MaxValue
 }
 (featureValue, centroid)
 }
 }

 logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

 // bins sorted by centroids
 val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

 logDebug("Sorted centroids for categorical variable = " +
 categoriesSortedByCentroid.mkString(","))

 // Cumulative sum (scanLeft) of bin statistics.
 // Afterwards, binAggregates for a bin is the sum of aggregates for
 // that bin + all preceding bins.
 var splitIndex = 0
 while (splitIndex < numSplits) {
 val currentCategory = categoriesSortedByCentroid(splitIndex)._1
 val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
 //Merge the statistics of the two bins
 binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
 splitIndex += 1
 }
 // lastCategory = index of bin with total aggregates for this (node, feature)
 val lastCategory = categoriesSortedByCentroid.last._1
 // Find best split.
 //Choose the best split by information gain
 val (bestFeatureSplitIndex, bestFeatureGainStats) =
 Range(0, numSplits).map { splitIndex =>
 val featureValue = categoriesSortedByCentroid(splitIndex)._1
 val leftChildStats =
 binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
 val rightChildStats =
 binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
 rightChildStats.subtract(leftChildStats)
 predictWithImpurity = Some(predictWithImpurity.getOrElse(
 calculatePredictImpurity(leftChildStats, rightChildStats)))
 val gainStats = calculateGainForSplit(leftChildStats,
 rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
 (splitIndex, gainStats)
 }.maxBy(_._2.gain)
 val categoriesForSplit =
 categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
 val bestFeatureSplit =
 new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
 (bestFeatureSplit, bestFeatureGainStats)
 }
 }.maxBy(_._2.gain)

 (bestSplit, bestSplitStats, predictWithImpurity.get._1)
}

The code above covers the core of constructing a random forest. As mentioned earlier, the run method of RandomForest returns a RandomForestModel; that class looks like this:

Listing 7. Core source code analysis 6
/**
 * :: Experimental ::
 * Represents a random forest model.
 *
 * @param algo algorithm for the ensemble model, either Classification or Regression
 * @param trees tree ensembles
 */
// RandomForestModel extends TreeEnsembleModel
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
 extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
 combiningStrategy = if (algo == Classification) Vote else Average)
 with Saveable {

 require(trees.forall(_.algo == algo))
 //Persist the trained model
 override def save(sc: SparkContext, path: String): Unit = {
 TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
 RandomForestModel.SaveLoadV1_0.thisClassName)
 }

 override protected def formatVersion: String = RandomForestModel.formatVersion
}

object RandomForestModel extends Loader[RandomForestModel] {

 private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
 //Load a trained model into memory
 override def load(sc: SparkContext, path: String): RandomForestModel = {
 val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
 val classNameV1_0 = SaveLoadV1_0.thisClassName
 (loadedClassName, version) match {
 case (className, "1.0") if className == classNameV1_0 =>
 val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata)
 assert(metadata.treeWeights.forall(_ == 1.0))
 val trees =
 TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
 new RandomForestModel(Algo.fromString(metadata.algo), trees)
 case _ => throw new Exception(s"RandomForestModel.load did not recognize model" +
 s" with (className, format version): ($loadedClassName, $version). Supported:\n" +
 s" ($classNameV1_0, 1.0)")
 }
 }

 private object SaveLoadV1_0 {
 // Hard-code class name string in case it changes in the future
 def thisClassName: String = "org.apache.spark.mllib.tree.model.RandomForestModel"
 }

}

When a random forest is used for prediction, the predict method it calls is inherited from TreeEnsembleModel, which represents tree ensemble models in general; besides random forests this also covers Gradient-Boosted Trees (GBTs). Part of its core code is shown below:

Listing 8. Core source code analysis 7
/**
 * Represents a tree ensemble model.
 *
 * @param algo algorithm for the ensemble model, either Classification or Regression
 * @param trees tree ensembles
 * @param treeWeights tree ensemble weights
 * @param combiningStrategy strategy for combining the predictions, not used for regression.
 */
private[tree] sealed class TreeEnsembleModel(
 protected val algo: Algo,
 protected val trees: Array[DecisionTreeModel],
 protected val treeWeights: Array[Double],
 protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {

 require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.")
 //Other code omitted

 //The final classification is obtained by (weighted) voting
 /**
 * Classifies a single data point based on (weighted) majority votes.
 */
 private def predictByVoting(features: Vector): Double = {
 val votes = mutable.Map.empty[Int, Double]
 trees.view.zip(treeWeights).foreach { case (tree, weight) =>
 val prediction = tree.predict(features).toInt
 votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
 }
 votes.maxBy(_._2)._1
 }

 
 /**
 * Predict values for a single data point using the model trained.
 *
 * @param features array representing a single data point
 * @return predicted category from the trained model
 */
 //Different strategies use different prediction methods
 def predict(features: Vector): Double = {
 (algo, combiningStrategy) match {
 case (Regression, Sum) =>
 predictBySumming(features)
 case (Regression, Average) =>
 predictBySumming(features) / sumWeights
 case (Classification, Sum) => // binary classification
 val prediction = predictBySumming(features)
 // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
 if (prediction > 0.0) 1.0 else 0.0
 //Random forests use the predictByVoting method
 case (Classification, Vote) =>
 predictByVoting(features)
 case _ =>
 throw new IllegalArgumentException(
 "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " +
 s"($algo, $combiningStrategy).")
 }
 }

 // Implementation of predict for an RDD of data points
 /**
 * Predict values for the given data set.
 *
 * @param features RDD representing data points to be predicted
 * @return RDD[Double] where each entry contains the corresponding prediction
 */
 def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
 

 //Other code omitted
}

With this walk through the core code we have seen how the random forest algorithm works internally; the next section gives a practical example of using it.

A hands-on random forest example

This section illustrates a concrete application of random forests. Before granting a loan, a bank needs to assess a customer's ability to repay, but when the volume of customer data is large the pressure on loan officers becomes heavy, and computer-assisted decision making is attractive. Random forests fit this scenario well: the bank's historical data can be used to train a model, and the model then classifies new customers, filtering out a large share of those without the ability to repay and greatly reducing the loan officers' workload.

Figure 9. The Spark cluster in action

Suppose we have the following historical repayment records of credit customers:

Table 2. Historical repayment data of credit customers
Record ID | Owns property (yes / no) | Marital status (single, married, divorced) | Annual income (10k RMB) | Able to repay (yes / no)
10001 |  | Married | 10 |
10002 |  | Single | 8 |
10003 |  | Single | 13 |
…… | …. | ….. | …. | ……
11000 |  | Single | 8 |

Each of these historical repayment records is formatted as label index1:feature1 index2:feature2 index3:feature3. For example, the first record in the table above becomes 0 1:0 2:1 3:10, with the fields encoded as follows:

Able to repay (label): 0 = yes, 1 = no
Owns property (feature 1): 0 = no, 1 = yes
Marital status (feature 2): 0 = single, 1 = married, 2 = divorced
Annual income (feature 3): the actual amount

0 1:0 2:1 3:10
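For illustration only, a small helper (hypothetical, not part of the article's code) that produces such a line from one raw credit record could look like this:

// Hypothetical helper: encode one credit record in the label index:value format above.
// hasHouse: true/false; marital: 0 = single, 1 = married, 2 = divorced;
// canRepay: true -> label 0, false -> label 1.
def toLibSVMLine(canRepay: Boolean, hasHouse: Boolean, marital: Int, income: Int): String = {
  val label = if (canRepay) 0 else 1
  val house = if (hasHouse) 1 else 0
  s"$label 1:$house 2:$marital 3:$income"
}

println(toLibSVMLine(canRepay = true, hasHouse = false, marital = 1, income = 10))  // 0 1:0 2:1 3:10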

After all rows of the table have been converted, the training data is saved as sample_data.txt. The test data is:

Table 3. Test data
Owns property (yes / no) | Marital status (single, married, divorced) | Annual income (10k RMB)
No | Married | 12

If the random forest model has been trained correctly, the prediction for this record should be "able to repay". For convenience in later processing we save it as input.txt with the content:

0 1:0 2:1 3:12

Upload sample_data.txt and input.txt to the /data directory in HDFS with hadoop fs -put input.txt sample_data.txt /data, then run the code in Listing 9 to verify the model.

Listing 9. Deciding whether a customer is able to repay a loan
package cn.ml

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.configuration.Strategy

object RandomForestExample {
 def main(args: Array[String]) {
 val sparkConf = new SparkConf().setAppName("RandomForestExample").
          setMaster("spark://sparkmaster:7077")
 val sc = new SparkContext(sparkConf)

 val data: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "/data/sample_data.txt")

 val numClasses = 2 
 val featureSubsetStrategy = "auto"
 val numTrees = 3
 val model: RandomForestModel = RandomForest.trainClassifier(
                    data, Strategy.defaultStrategy("classification"), numTrees,
                    featureSubsetStrategy, new java.util.Random().nextInt())
 
 val input: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "/data/input.txt")

 val predictResult = input.map { point =>
 val prediction = model.predict(point.features)
 (point.label, prediction)
}
//Print the results (use this when running in spark-shell)
 predictResult.collect()
 //To save the results to HDFS instead: //predictResult.saveAsTextFile("/data/predictResult")
 sc.stop()

 }
}

The code can be packaged and submitted to the cluster with spark-submit, or run in spark-shell to inspect the result. Figure 10 shows the trained RandomForest model and Figure 11 the predictions it produces; the prediction matches what we expected.

Figure 10. The trained RandomForest model
Figure 11. Result returned by collect
