Apache Spark MLlib Study Notes (6): Source Code Analysis of the MLlib Decision Tree Algorithms, Part 2

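The previous part showed that building a classification decision tree model calls the trainClassifier method. This part analyzes trainClassifier and the code around it.
Open the source file at the following path:
/home/yangqiao/codes/spark/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
We focus on DecisionTree.scala first.
First, find the trainClassifier method. Its code is as follows: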

  def trainClassifier(
      input: RDD[LabeledPoint],
      numClasses: Int,
      categoricalFeaturesInfo: Map[Int, Int],
      impurity: String,
      maxDepth: Int,
      maxBins: Int): DecisionTreeModel = {
    val impurityType = Impurities.fromString(impurity)
    train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
      categoricalFeaturesInfo)
  }
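The only work trainClassifier does itself is turning the impurity string into an impurity object before delegating to train. The following is a minimal, self-contained sketch of that lookup, assuming a string-matching design similar to Impurities.fromString. ImpurityKind, Gini, Entropy, Variance and ImpurityLookup here are hypothetical stand-ins, not the MLlib classes.

```scala
// Hypothetical stand-ins for MLlib's impurity objects (not the real Spark classes).
sealed trait ImpurityKind
case object Gini extends ImpurityKind
case object Entropy extends ImpurityKind
case object Variance extends ImpurityKind

object ImpurityLookup {
  // Map a user-supplied name to an impurity. Unknown names are rejected
  // immediately, before any (potentially expensive) training would start.
  def fromString(name: String): ImpurityKind = name match {
    case "gini"     => Gini
    case "entropy"  => Entropy
    case "variance" => Variance
    case other =>
      throw new IllegalArgumentException(s"Did not recognize impurity name: $other")
  }
}
```

Note that resolving the name is separate from checking that it suits the task: for classification only gini and entropy make sense, and MLlib's Strategy validation rejects variance in that case.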

As shown, trainClassifier in turn calls train. Here is the train method:

  def train(
      input: RDD[LabeledPoint],
      algo: Algo,
      impurity: Impurity,
      maxDepth: Int,
      numClasses: Int,
      maxBins: Int,
      quantileCalculationStrategy: QuantileStrategy,
      categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
    val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
      quantileCalculationStrategy, categoricalFeaturesInfo)
    new DecisionTree(strategy).run(input)
  }
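Bundling the arguments into one Strategy object keeps the many train overloads short and gives a single place to validate settings: the DecisionTree constructor calls strategy.assertValid() before training starts. The sketch below illustrates that parameter-object pattern in plain Scala. TreeParams is a hypothetical, much smaller stand-in for MLlib's Strategy, and its checks are modeled on, not copied from, the real assertValid.

```scala
// Hypothetical, simplified parameter object; MLlib's real Strategy has more fields.
final case class TreeParams(
    algo: String,                          // "classification" or "regression"
    impurity: String,                      // "gini", "entropy" or "variance"
    maxDepth: Int,
    numClasses: Int,
    maxBins: Int,
    categoricalFeaturesInfo: Map[Int, Int]) {

  // Fail fast on inconsistent settings; require throws IllegalArgumentException.
  def assertValid(): Unit = {
    require(maxDepth >= 0, s"maxDepth must be >= 0, got $maxDepth")
    require(maxBins >= 2, s"maxBins must be >= 2, got $maxBins")
    algo match {
      case "classification" =>
        require(numClasses >= 2, s"numClasses must be >= 2, got $numClasses")
        require(Set("gini", "entropy").contains(impurity),
          s"invalid impurity for classification: $impurity")
      case "regression" =>
        require(impurity == "variance", s"invalid impurity for regression: $impurity")
      case other =>
        throw new IllegalArgumentException(s"unknown algo: $other")
    }
    // A categorical feature needs at least two categories to be split at all.
    categoricalFeaturesInfo.foreach { case (feature, arity) =>
      require(arity >= 2, s"feature $feature has invalid arity $arity")
    }
  }
}
```

Validating once, at construction time, means a bad setting surfaces as an immediate error on the driver rather than partway through a distributed job.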

First, here is what each of the parameters above means:

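@param input: the training dataset, an RDD of LabeledPoint. For classification, labels take values in {0, 1, ..., numClasses-1}.
@param algo: Classification or Regression.
@param impurity: the impurity measure used to score candidate splits (and hence the information gain): gini or entropy for classification, variance for regression.
@param maxDepth: maximum depth of the tree. Depth 0 means a single leaf node (the root only); depth 1 means one root node with two leaf children.
@param numClasses: number of classes for classification (default 2).
@param maxBins: maximum number of bins used to discretize continuous features when searching for splits. It must be at least as large as the number of categories of every categorical feature.
@param quantileCalculationStrategy: the algorithm used to compute quantiles (candidate split thresholds); trainClassifier passes Sort.
@param categoricalFeaturesInfo: a map storing the arity of the categorical features. An entry (n -> k) means feature n is categorical with k categories, indexed 0, 1, ..., k-1. Features not in the map are treated as continuous.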
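For reference, with p_k the fraction of samples in class k, Gini impurity is 1 - Σ p_k², entropy is -Σ p_k log₂ p_k, and the variance impurity used for regression is the mean squared deviation of the labels from their mean. A split is scored by its information gain: the parent's impurity minus the sample-weighted impurity of its two children. The plain-Scala sketch below computes the three measures directly. ImpurityMeasures is a hypothetical name; MLlib derives the same quantities from aggregated statistics rather than raw labels.

```scala
object ImpurityMeasures {
  // Gini impurity 1 - sum_k p_k^2, computed from per-class sample counts.
  def gini(counts: Seq[Double]): Double = {
    val total = counts.sum
    if (total == 0) 0.0
    else 1.0 - counts.map { c => val p = c / total; p * p }.sum
  }

  // Entropy in bits, -sum_k p_k * log2(p_k); empty classes contribute nothing.
  def entropy(counts: Seq[Double]): Double = {
    val total = counts.sum
    if (total == 0) 0.0
    else counts.filter(_ > 0).map { c =>
      val p = c / total
      -p * math.log(p) / math.log(2)
    }.sum
  }

  // Population variance of regression labels (the "variance" impurity).
  def variance(labels: Seq[Double]): Double = {
    if (labels.isEmpty) 0.0
    else {
      val mean = labels.sum / labels.size
      labels.map(y => (y - mean) * (y - mean)).sum / labels.size
    }
  }
}
```

For a perfectly balanced two-class node with counts (5, 5), Gini is 0.5 and entropy is 1 bit, the maximum for two classes; a pure node scores 0 under both.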

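This code shows that all the parameters are first wrapped in a Strategy object, which is passed to the DecisionTree constructor, and then the run method is called. Here is the DecisionTree class with its run method: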

class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {

  strategy.assertValid()

  /**
   * Method to train a decision tree model over an RDD
   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
   * @return DecisionTreeModel that can be used for prediction
   */
  def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
    // Note: random seed will not be used since numTrees = 1.
    val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
    val rfModel = rf.run(input)
    rfModel.trees(0)
  }
}
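The comment in run() says the random seed is unused because numTrees = 1. The reason lies in RandomForest.run, which sets withReplacement = numTrees > 1 and bags the input with BaggedPoint.convertToBaggedRDD. As far as that bagging logic goes, a single tree with the default subsamplingRate of 1.0 should simply see every training point exactly once, so no random draws are needed. The sketch below models that per-point weighting in plain Scala. BaggingSketch, poisson and sampleWeights are hypothetical names, and the Knuth-style Poisson sampler is a stand-in for whatever sampler MLlib actually uses.

```scala
import scala.util.Random

object BaggingSketch {
  // Draw from Poisson(lambda) with Knuth's multiplication method (fine for small lambda).
  def poisson(lambda: Double, rng: Random): Int = {
    val limit = math.exp(-lambda)
    var k = 0
    var p = rng.nextDouble()
    while (p > limit) {
      k += 1
      p *= rng.nextDouble()
    }
    k
  }

  // Weights of ONE training point: entry t is how many times the point
  // appears in tree t's bootstrap sample.
  def sampleWeights(numTrees: Int, subsamplingRate: Double, rng: Random): Array[Int] = {
    require(numTrees >= 1, s"numTrees must be >= 1, got $numTrees")
    val withReplacement = numTrees > 1 // same rule as RandomForest.run
    if (withReplacement) Array.fill(numTrees)(poisson(subsamplingRate, rng))
    else if (subsamplingRate == 1.0) Array(1) // single tree, full data: no sampling
    else Array(if (rng.nextDouble() < subsamplingRate) 1 else 0)
  }
}
```

Poisson counts with mean subsamplingRate approximate sampling with replacement without materializing resampled copies of the data, which is why each point only carries a small array of counts.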

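So the program ultimately delegates to RandomForest: in Spark MLlib a decision tree is treated as a special case of a random forest that contains exactly one tree (numTrees = 1). That is why run returns rfModel.trees(0), the first and only tree in the model. Next, open the RandomForest source file at:
/home/yangqiao/codes/spark/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
Since RandomForest's run method is what gets called, let's look at run: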

def run(input: RDD[LabeledPoint]): RandomForestModel = {

    val timer = new TimeTracker()

    timer.start("total")

    timer.start("init")

    val retaggedInput = input.retag(classOf[LabeledPoint])
    val metadata =
      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
    logDebug("algo = " + strategy.algo)
    logDebug("numTrees = " + numTrees)
    logDebug("seed = " + seed)
    logDebug("maxBins = " + metadata.maxBins)
    logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
    logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
    logDebug("subsamplingRate = " + strategy.subsamplingRate)

    // Find the splits and the corresponding bins (interval between the splits) using a sample
    // of the input data.
    timer.start("findSplitsBins")
    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
    timer.stop("findSplitsBins")
    logDebug("numBins: feature: number of bins")
    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
        s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
      }.mkString("\n"))

    // Bin feature values (TreePoint representation).
    // Cache input RDD for speedup during multiple passes.
    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)

    val withReplacement = if (numTrees > 1) true else false

    val baggedInput
      = BaggedPoint.convertToBaggedRDD(treeInput,
          strategy.subsamplingRate, numTrees,
          withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)

    // depth of the decision tree
    val maxDepth = strategy.maxDepth
    require(maxDepth <= 30,
      s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")

    // Max memory usage for aggregates
    // TODO: Calculate memory usage more precisely.
    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
    val maxMemoryPerNode = {
      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
        // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
        Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
          .take(metadata.numFeaturesPerNode).map(_._2))
      } else {
        None
      }
      RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
    }
    require(maxMemoryPerNode <= maxMemoryUsage,
      s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
      " which is too small for the given features." +
      s"  Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")

    timer.stop("init")

    /*
     * The main idea here is to perform group-wise training of the decision tree nodes thus
     * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
     * Each data sample is handled by a particular node (or it reaches a leaf and is not used
     * in lower levels).
     */

    // Create an RDD of node Id cache.
    // At first, all the rows belong to the root nodes (node Id == 1).
    val nodeIdCache = if (strategy.useNodeIdCache) {
      Some(NodeIdCache.init(
        data = baggedInput,
        numTrees = numTrees,
        checkpointInterval = strategy.checkpointInterval,
        initVal = 1))
    } else {
      None
    }

    // FIFO queue of nodes to train: (treeIndex, node)
    val nodeQueue = new mutable.Queue[(Int, Node)]()

    val rng = new scala.util.Random()
    rng.setSeed(seed)

    // Allocate and queue root nodes.
    val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

    while (nodeQueue.nonEmpty) {
      // Collect some nodes to split, and choose features for each node (if subsampling).
      // Each group of nodes may come from one or multiple trees, and at multiple levels.
      val (nodesForGroup, treeToNodeToIndexInfo) =
        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
      // Sanity check (should never occur):
      assert(nodesForGroup.size > 0,
        s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")

      // Choose node splits, and enqueue new nodes as needed.
      timer.start("findBestSplits")
      DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
      timer.stop("findBestSplits")
    }

    baggedInput.unpersist()

    timer.stop("total")

    logInfo("Internal timing for DecisionTree:")
    logInfo(s"$timer")

    // Delete any remaining checkpoints used for node Id cache.
    if (nodeIdCache.nonEmpty) {
      try {
        nodeIdCache.get.deleteAllCheckpoints()
      } catch {
        case e: IOException =>
          logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
      }
    }

    val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
    new RandomForestModel(strategy.algo, trees)
  }
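Two details of the loop above fit together. Nodes are numbered heap-style: the root is node 1 (as the nodeIdCache comment says), and node i has children 2i and 2i + 1. The maxDepth <= 30 limit most likely follows from this, since the largest index on level 30 is 2^31 - 1 = Int.MaxValue. Training then proceeds breadth-first through the FIFO nodeQueue, taking as many nodes per pass over the data as the memory budget allows. The sketch below simulates that in plain Scala. NodeQueueSketch, trainGroups and the flat costPerNode are hypothetical simplifications of selectNodesToSplit and findBestSplits, which estimate memory per node from the bin counts and stop splitting on more criteria than depth alone.

```scala
import scala.collection.mutable

object NodeQueueSketch {
  // Heap-style numbering: the root is node 1 and node i has children 2i and 2i + 1.
  def leftChild(i: Int): Int = i << 1
  def rightChild(i: Int): Int = (i << 1) + 1

  // Depth of node i is the position of its highest set bit (root = depth 0).
  def depth(i: Int): Int = 31 - Integer.numberOfLeadingZeros(i)

  // Largest node index on level d is 2^(d+1) - 1; for d = 30 this equals Int.MaxValue.
  def maxIndexAtDepth(d: Int): Long = (1L << (d + 1)) - 1

  // Drain a FIFO queue of (treeIndex, nodeIndex) in groups whose total cost stays
  // within the budget (each group takes at least one node). As a stand-in for
  // findBestSplits, every node shallower than maxDepth is split into two children.
  def trainGroups(
      numTrees: Int,
      maxDepth: Int,
      costPerNode: Long,
      budget: Long): List[List[(Int, Int)]] = {
    require(maxDepth <= 30, s"maxDepth <= 30 required, got $maxDepth")
    val queue = mutable.Queue[(Int, Int)]()
    (0 until numTrees).foreach(t => queue.enqueue((t, 1))) // every root is node 1
    val groups = mutable.ListBuffer[List[(Int, Int)]]()
    while (queue.nonEmpty) {
      val group = mutable.ListBuffer[(Int, Int)]()
      var used = 0L
      while (queue.nonEmpty && (group.isEmpty || used + costPerNode <= budget)) {
        group += queue.dequeue()
        used += costPerNode
      }
      group.foreach { case (tree, node) =>
        if (depth(node) < maxDepth) {
          queue.enqueue((tree, leftChild(node)))
          queue.enqueue((tree, rightChild(node)))
        }
      }
      groups += group.toList
    }
    groups.toList
  }
}
```

For one tree with maxDepth = 2 and room for two nodes per pass, the groups come out as [1], [2, 3], [4, 5], [6, 7]: seven nodes trained in four passes over the data instead of seven.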

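We can see that the input data is first processed by the buildMetadata method of the DecisionTreeMetadata class, so the next step is to analyze what buildMetadata does. That will be covered in the next part.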
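Whatever the metadata step decides in detail, run() then uses it for binning: findSplitsBins chooses split thresholds from a sample of the input, and TreePoint.convertToTreeRDD replaces each raw feature value with a bin index, so every later pass only aggregates statistics indexed by small integers. The plain-Scala sketch below shows that idea for continuous features, assuming simple thresholds at evenly spaced ranks of a sorted sample. BinningSketch, thresholds, binIndex and binPoint are hypothetical names, and MLlib's own threshold selection differs in detail.

```scala
object BinningSketch {
  // Choose up to numBins - 1 thresholds at evenly spaced ranks of the sorted sample,
  // so each bin holds roughly the same number of sample points.
  def thresholds(sample: Seq[Double], numBins: Int): Array[Double] = {
    require(sample.nonEmpty && numBins >= 2, "need a non-empty sample and numBins >= 2")
    val sorted = sample.sorted.toArray
    val n = sorted.length
    (1 until numBins)
      .map(k => sorted(math.max(0, k * n / numBins - 1)))
      .distinct
      .toArray
  }

  // Bin index = position of the first threshold >= value ("value <= threshold" goes
  // left); values above every threshold fall into the last bin.
  def binIndex(value: Double, cuts: Array[Double]): Int = {
    val i = cuts.indexWhere(t => value <= t)
    if (i < 0) cuts.length else i
  }

  // Replace each raw feature value of one point by its bin index (cf. TreePoint).
  def binPoint(features: Array[Double], cutsPerFeature: Array[Array[Double]]): Array[Int] =
    features.zip(cutsPerFeature).map { case (v, cuts) => binIndex(v, cuts) }
}
```

With the sample 1.0 to 8.0 and four bins, the thresholds are 2.0, 4.0 and 6.0, and the values fall into bins 0, 0, 1, 1, 2, 2, 3, 3.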