
Spark MLlib Source Code Analysis: Random Forest (Part 3)

6. Random Forest Training

6.1. Data Structures

6.1.1. Node

Each node in a tree is a Node structure:

class Node @Since("1.2.0") (
    @Since("1.0.0") val id: Int,
    @Since("1.0.0") var predict: Predict,
    @Since("1.2.0") var impurity: Double,
    @Since("1.0.0") var isLeaf: Boolean,
    @Since("1.0.0") var split: Option[Split],
    @Since("1.0.0") var leftNode: Option[Node],
    @Since("1.0.0") var rightNode: Option[Node],
    @Since("1.0.0") var stats: Option[InformationGainStats])

emptyNode initializes only the nodeIndex; all other fields take default values:

def emptyNode(nodeIndex: Int): Node = 
    new Node(nodeIndex, new Predict(Double.MinValue),
    -1.0, false, None, None, None, None)

Child node ids are computed from the node's own id:

  /**
   * Return the index of the left child of this node.
   */
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1

  /**
   * Return the index of the right child of this node.
   */
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1

The left child's id is simply the current id * 2, and the right child's is id * 2 + 1.
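As a quick check of the id scheme (a standalone sketch, not the MLlib code), the ids trace out a binary-heap layout rooted at 1:

// Standalone sketch of the id scheme: the tree is stored like a binary heap.
def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1

// Root is 1; its children are 2 and 3, whose children are 4..7, and so on:
assert(leftChildIndex(1) == 2 && rightChildIndex(1) == 3)
assert(leftChildIndex(3) == 6 && rightChildIndex(3) == 7)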


6.1.2. Entropy

6.1.2.1. Entropy

Entropy is an object; its most important member is the calculate function:

  /**
   * :: DeveloperApi ::
   * information calculation for multiclass classification
   * @param counts Array[Double] with counts for each label
   * @param totalCount sum of counts for all labels
   * @return information value, or 0 if totalCount = 0
   */
  @Since("1.1.0")
  @DeveloperApi
  override def calculate(counts: Array[Double], totalCount: Double): Double = {
    if (totalCount == 0) {
      return 0
    }
    val numClasses = counts.length
    var impurity = 0.0
    var classIndex = 0
    while (classIndex < numClasses) {
      val classCount = counts(classIndex)
      if (classCount != 0) {
        val freq = classCount / totalCount
        impurity -= freq * log2(freq)
      }
      classIndex += 1
    }
    impurity
  }

The entropy formula is

$$H = E[-\log_2 p_i] = -\sum_{i=1}^{n} p_i \log_2 p_i$$

The counts argument is therefore the occurrence count of each class: the function first computes each class's frequency, then accumulates the -p_i * log2(p_i) terms.
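For a quick sanity check, a hypothetical call with 3 samples of class 0 and 7 of class 1:

// H = -(0.3 * log2(0.3) + 0.7 * log2(0.7)) ≈ 0.881
val h = Entropy.calculate(Array(3.0, 7.0), 10.0)
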
6.1.2.2. EntropyAggregator
class EntropyAggregator(numClasses: Int)
  extends ImpurityAggregator(numClasses)

Its only member variable is the number of classes; the key method is update:

  /**
   * Update stats for one (node, feature, bin) with the given label.
   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
   * @param offset    Start index of stats for this (node, feature, bin).
   */
  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
    if (label >= statsSize) {
      throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
        s" but requires label < numClasses (= $statsSize).")
    }
    if (label < 0) {
      throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
        s" but requires label is non-negative.")
    }
    allStats(offset + label.toInt) += instanceWeight
  }

offset is the offset of the (feature, bin); adding the label to it gives this class's position in allStats, where its occurrence count is accumulated.

  /**
   * Get an [[ImpurityCalculator]] for a (node, feature, bin).
   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
   * @param offset    Start index of stats for this (node, feature, bin).
   */
  def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
    new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
  }

This slices from allStats the sub-array belonging to this (feature, bin); its length is statsSize, i.e. the number of classes.

6.1.2.3. EntropyCalculator
  /**
   * Calculate the impurity from the stored sufficient statistics.
   */
  def calculate(): Double = Entropy.calculate(stats, stats.sum)

Combining the functions above, the path for computing entropy is: call Entropy's getCalculator, which slices out the portion of allStats belonging to the split; the resulting calculator's calculate method then invokes Entropy.calculate on it to compute the entropy.
The prob function is also overridden here; it returns the probability of a given label. For example, with a count of 3 for label 0 and 7 for label 1, the probability of label 0 is 0.3.
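A minimal standalone sketch of that prob logic (assumed form; the real method lives in EntropyCalculator and reads the stored stats array):

// Probability of a label from per-class counts: count(label) / total.
def prob(stats: Array[Double], label: Double): Double = {
  val total = stats.sum
  if (total == 0) 0.0 else stats(label.toInt) / total
}

prob(Array(3.0, 7.0), 0.0) // 0.3, matching the example above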

6.1.3. DTStatsAggregator

A short digression on what must be counted when a node splits, since this motivates the design of DTStatsAggregator. Take information entropy as the example: when a node splits, we iterate over every split of every feature, and each split partitions the samples into two parts. To compute the entropy we need the class distribution of the left and right parts separately, from which we derive probabilities and then the entropy. This is why statsSize in the aggregator equals numClasses, and why allStats, which holds all the statistics, is in fact a record of class distributions.
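To make this concrete, here is a self-contained sketch (not the MLlib code) of what evaluating a single split boils down to, given the left/right class-count distributions:

// Standalone sketch: information gain of one split from class distributions.
def entropy(counts: Array[Double]): Double = {
  val total = counts.sum
  if (total == 0) return 0.0
  counts.filter(_ > 0).map { c =>
    val p = c / total
    -p * (math.log(p) / math.log(2)) // log base 2
  }.sum
}

val left = Array(8.0, 2.0)  // class counts left of the split
val right = Array(1.0, 9.0) // class counts right of the split
val n = left.sum + right.sum
val parent = left.zip(right).map { case (l, r) => l + r }
// gain = parent entropy minus the weighted entropies of the children
val gain = entropy(parent) -
  (left.sum / n) * entropy(left) - (right.sum / n) * entropy(right)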

class DTStatsAggregator(
    val metadata: DecisionTreeMetadata,
    featureSubset: Option[Array[Int]]) extends Serializable {

  /**
   * [[ImpurityAggregator]] instance specifying the impurity type.
   */
  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
    case Gini => new GiniAggregator(metadata.numClasses)
    case Entropy => new EntropyAggregator(metadata.numClasses)
    case Variance => new VarianceAggregator()
    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
  }

  /**
   * Number of elements (Double values) used for the sufficient statistics of each bin.
   */
  private val statsSize: Int = impurityAggregator.statsSize

  /**
   * Number of bins for each feature.  This is indexed by the feature index.
   */
  private val numBins: Array[Int] = {
    if (featureSubset.isDefined) {
      featureSubset.get.map(metadata.numBins(_))
    } else {
      metadata.numBins
    }
  }

  /**
   * Offset for each feature for calculating indices into the [[allStats]] array.
   */
  private val featureOffsets: Array[Int] = {
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
  }

  /**
   * Total number of elements stored in this aggregator
   */
  private val allStatsSize: Int = featureOffsets.last

  /**
   * Flat array of elements.
   * Index for start of stats for a (feature, bin) is:
   *   index = featureOffsets(featureIndex) + binIndex * statsSize
   * Note: For unordered features,
   *       the left child stats have binIndex in [0, numBins(featureIndex) / 2))
   *       and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
   */
  private val allStats: Array[Double] = new Array[Double](allStatsSize)

Each node has one DTStatsAggregator, whose constructor takes 2 arguments: the metadata and the feature subset used by the node. The other class members:
- impurityAggregator: Gini, Entropy and Variance are currently supported; below we use Entropy as the example, the others are analogous
- statsSize: the number of statistics per bin. For classification it equals numClasses, since each class is counted separately; for regression it is 3, holding the count, the sum, and the sum of squares of the values, from which the variance is computed
- numBins: the elements of metadata.numBins for the features this node uses
- featureOffsets: each feature's starting index into allStats, derived from each feature's bin count and statsSize. For example, with 3 features whose bin counts are 3, 2, 2 and statsSize 2, the features need 3*2=6, 2*2=4, 2*2=4 slots respectively, so featureOffsets is 0, 6, 10, 14, the left-to-right running totals
- allStatsSize: the total number of buckets needed
- allStats: the buckets that store the statistics
[Figure: layout of the allStats buckets for features f0, f1, f2]
f0, f1, f2 are 3 features: f0 has 3 feature values (really bin indices) 0/1/2, while f1 and f2 each have 2 (0/1). Every feature value gets statsSize buckets, so there are 14 in total and allStatsSize = 14. For example, to locate class c1 under value v1 of f1: f1's offset is featureOffsets(1) = 6, v1's binIndex is 1, statsSize is 2, and the label is 1, so the bucket index = 6 + 1*2 + 1 = 9, which is exactly the index of f1/v1's c1 bucket in the figure.
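The arithmetic of this example can be reproduced directly (a standalone sketch mirroring the fields above):

// Reproducing the f0/f1/f2 example: 3 features with 3, 2, 2 bins, 2 classes.
val statsSize = 2
val numBins = Array(3, 2, 2)
val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
// featureOffsets = Array(0, 6, 10, 14), so allStatsSize = featureOffsets.last = 14

// Bucket index of (feature f1, bin v1, class c1):
val i = featureOffsets(1) + 1 * statsSize + 1 // 6 + 2 + 1 = 9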

Let's walk through the key methods:

  /**
   * Update the stats for a given (feature, bin) for ordered features, using the given label.
   */
  def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
    // The first term is the feature's offset into allStats; binIndex is the offset of the
    // feature value within the feature, and each bin occupies statsSize buckets, so their
    // sum is the start of the buckets for this feature value.
    // EntropyAggregator.update then adds label.toInt to reach the bucket for this label.
    // This offset arithmetic shows that the values of an ordered feature should be
    // consecutive, without gaps, and must start at 0; any gap simply wastes the
    // corresponding space.
    val i = featureOffsets(featureIndex) + binIndex * statsSize
    impurityAggregator.update(allStats, i, label, instanceWeight)
  }

  /**
   * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
   * @param featureOffset  For ordered features, this is a pre-computed (node, feature) offset
   *                           from [[getFeatureOffset]].
   *                           For unordered features, this is a pre-computed
   *                           (node, feature, left/right child) offset from
   *                           [[getLeftRightFeatureOffsets]].
   */
  def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
    // Same offset arithmetic as above, except the feature offset is passed in
    // rather than recomputed here.
    impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
  }
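Assuming agg is a DTStatsAggregator built over the 3-feature layout above, recording one instance and reading the impurity back might look like this (a sketch; getFeatureOffset is the pre-computed feature offset referred to in the doc comment above):

// One instance with label 1.0 falling into bin 1 of feature 1:
// bucket index = featureOffsets(1) + 1 * statsSize + 1 = 9, as in the figure.
agg.update(featureIndex = 1, binIndex = 1, label = 1.0, instanceWeight = 1.0)

// Read back the entropy of that (feature, bin):
val calc = agg.getImpurityCalculator(agg.getFeatureOffset(1), binIndex = 1)
val h = calc.calculate()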

6.2. Training Initialization

// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()

val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

This constructs numTrees Nodes, initialized with emptyNode's defaults; they become the root node of each tree and take part in the subsequent training. Each node is paired with its treeIndex and enqueued into nodeQueue; later, every node awaiting a split is pushed onto this queue and split in turn, until all nodes hit a stopping condition, i.e. until the queue drains in the while loop that follows.

6.3. Selecting Nodes to Split

This logic lives in selectNodesToSplit: it pulls the nodes to split in this round off nodeQueue and computes each node's parameters.

  /**
   * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
   * This tracks the memory usage for aggregates and stops adding nodes when too much memory
   * will be needed; this allows an adaptive number of nodes since different nodes may require
   * different amounts of memory (if featureSubsetStrategy is not "all").
   *
   * @param nodeQueue  Queue of nodes to split.
   * @param maxMemoryUsage  Bound on size of aggregate statistics.
   * @return  (nodesForGroup, treeToNodeToIndexInfo).
   *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
   *
   *          treeToNodeToIndexInfo holds indices selected features for each node:
   *            treeIndex --> (global) node index --> (node index in group, feature indices).
   *          The (global) node index is the index in the tree; the node index in group is the
   *           index in [0, numNodesInGroup) of the node in this group.
   *          The feature indices are None if not subsampling features.
   */
  private[tree] def selectNodesToSplit(
      nodeQueue: mutable.Queue[(Int, Node)],
      maxMemoryUsage: Long,
      metadata: DecisionTreeMetadata,
      rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
    // Collect some nodes to split:
    //  nodesForGroup(treeIndex) = nodes to split
    val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
    val mutableTreeToNodeToIndexInfo =
      new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
    var memUsage: Long = 0L
    var numNodesInGroup = 0
    while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
      val (treeIndex, node) = nodeQueue.head
      // Sample the feature subset for this node with reservoir sampling (introduced in an earlier post).
      // Choose subset of features for node (if subsampling).
      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
        Some(SamplingUtils.reservoirSampleAndCount(Range(0,
          metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
      } else {
        None
      }
      // Check if enough memory remains to add this node to the group.
      val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
      if (memUsage + nodeMemUsage <= maxMemoryUsage) {
        nodeQueue.dequeue()
        mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
        mutableTreeToNodeToIndexInfo
          .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
          = new NodeIndexInfo(numNodesInGroup, featureSubset)
      }
      numNodesInGroup += 1
      memUsage += nodeMemUsage
    }
    // Convert mutable maps to immutable ones.
    val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
    val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
    (nodesForGroup, treeToNodeToIndexInfo)
  }

The code is fairly straightforward: bounded by the memory limit, it dequeues as many nodes as this round can process from nodeQueue and places them into nodesForGroup and treeToNodeToIndexInfo.
Whether the feature set is subsampled is decided by whether metadata's numFeatures equals numFeaturesPerNode; both are constructor arguments of metadata, determined from featureSubsetStrategy in buildMetadata, as covered in the previous post.
nodesForGroup is a Map[Int, Array[Node]] keyed by treeIndex; its value is the array of that tree's nodes to be split in this round.
treeToNodeToIndexInfo has type Map[Int, Map[Int, NodeIndexInfo]]: the outer key is treeIndex, the inner map's key is node.id (this id is the first argument passed when the Node was constructed; in the first round every node's id is 1), and its value is a NodeIndexInfo structure:

class NodeIndexInfo(
      val nodeIndexInGroup: Int,
      val featureSubset: Option[Array[Int]])

The first member is this node's index within the while loop of this round's node selection (its group index); the second is the feature subset.
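A quick sketch of how the two returned maps would be consumed (hypothetical variable names, assuming code inside the tree package since selectNodesToSplit is private[tree]):

val (nodesForGroup, treeToNodeToIndexInfo) =
  RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)

// Nodes of tree 0 to be split in this round:
val nodes: Array[Node] = nodesForGroup(0)
// Group index and feature subset of tree 0's root node (id = 1):
val info: NodeIndexInfo = treeToNodeToIndexInfo(0)(1)
val groupIndex: Int = info.nodeIndexInGroup
val features: Option[Array[Int]] = info.featureSubset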