Spark MLlib Source Code Analysis: Random Forest (Part 3)
6. Random Forest Training
6.1. Data Structures
6.1.1. Node
Each node in a tree is represented by a Node:
class Node @Since("1.2.0") (
@Since("1.0.0") val id: Int,
@Since("1.0.0") var predict: Predict,
@Since("1.2.0") var impurity: Double,
@Since("1.0.0") var isLeaf: Boolean,
@Since("1.0.0") var split: Option[Split],
@Since("1.0.0") var leftNode: Option[Node],
@Since("1.0.0") var rightNode: Option[Node],
@Since("1.0.0") var stats: Option[InformationGainStats])
emptyNode initializes only nodeIndex (the id); every other field gets a default value:
def emptyNode(nodeIndex: Int): Node =
new Node(nodeIndex, new Predict(Double.MinValue),
-1.0, false, None, None, None, None)
The ids of a node's children are computed from the node's own id:
/**
 * Return the index of the left child of this node.
*/
def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
/**
* Return the index of the right child of this node.
*/
def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
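The left child's id is the current id * 2, and the right child's id is id * 2 + 1. Since every root has id 1, the nodes of a tree are numbered level by level, as in a binary heap.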
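To make the numbering concrete, here is a small standalone Scala sketch. leftChildIndex and rightChildIndex copy the formulas above; parentIndex and level are hypothetical helpers added only for illustration (not quoted from the Spark source), showing that with root id 1 the parent is id >> 1 and the depth is floor(log2(id)).

```scala
object NodeIndexSketch {
  // Same formulas as Node.leftChildIndex / Node.rightChildIndex above
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1

  // Hypothetical helper: inverse of the two functions above
  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1

  // Hypothetical helper: depth of a node, root (id 1) is level 0,
  // computed as floor(log2(nodeIndex)) via the position of the highest set bit
  def level(nodeIndex: Int): Int = 31 - Integer.numberOfLeadingZeros(nodeIndex)
}
```

For example, node 3 has children 6 and 7, both on level 2, and parentIndex(7) brings you back to 3.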
6.1.2. Entropy
6.1.2.1. Entropy
Entropy is an object, and its most important member is the calculate function:
/**
* :: DeveloperApi ::
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
* @return information value, or 0 if totalCount = 0
*/
@Since("1.1.0")
@DeveloperApi
override def calculate(counts: Array[Double], totalCount: Double): Double = {
if (totalCount == 0) {
return 0
}
val numClasses = counts.length
var impurity = 0.0
var classIndex = 0
while (classIndex < numClasses) {
val classCount = counts(classIndex)
if (classCount != 0) {
val freq = classCount / totalCount
impurity -= freq * log2(freq)
}
classIndex += 1
}
impurity
}
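The entropy is defined as $H = -\sum_{i=1}^{K} p_i \log_2 p_i$, where $p_i = \text{counts}(i) / \text{totalCount}$ and $K$ is the number of classes.
So the counts argument holds the (weighted) number of occurrences of each class: calculate first turns each count into a probability, then accumulates $-p_i \log_2 p_i$, skipping classes whose count is 0.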
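A minimal standalone restatement of the same formula, assuming plain class counts as input. EntropySketch is a hypothetical name; it mirrors the logic of Entropy.calculate but is not the Spark code.

```scala
object EntropySketch {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  // H = -sum_i p_i * log2(p_i), with p_i = counts(i) / total.
  // Classes with count 0 are skipped (their contribution is defined as 0),
  // and an empty node (total == 0) has entropy 0, as in Entropy.calculate.
  def entropy(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0) 0.0
    else counts.filter(_ != 0).map { c =>
      val p = c / total
      -p * log2(p)
    }.sum
  }
}
```

For counts (3, 7) it gives about 0.881 bits, (5, 5) gives 1.0, and a pure node such as (10, 0) gives 0.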
6.1.2.2. EntropyAggregator
class EntropyAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses)
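Its only constructor parameter is the number of classes (passed to ImpurityAggregator as statsSize); the key function is update: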
/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
if (label >= statsSize) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
if (label < 0) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s" but requires label is non-negative.")
}
allStats(offset + label.toInt) += instanceWeight
}
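offset is the start of the stats for this (node, feature, bin); adding the label gives the position of that class in allStats, where the occurrence count (weighted by instanceWeight) is accumulated.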
/**
* Get an [[ImpurityCalculator]] for a (node, feature, bin).
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
}
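It slices out the part of allStats that belongs to this (feature, bin), i.e. this split. The slice has length statsSize, which here is the number of classes.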
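The flat layout can be sketched with a plain array. EntropyAggregatorSketch, update and slice here are hypothetical stand-ins that mimic EntropyAggregator.update and getCalculator under the assumption statsSize == numClasses; the real class additionally wraps the slice in an EntropyCalculator.

```scala
object EntropyAggregatorSketch {
  // Mimics EntropyAggregator.update: the counter for `label` lives at offset + label
  def update(allStats: Array[Double], offset: Int, numClasses: Int,
             label: Double, instanceWeight: Double): Unit = {
    require(label >= 0 && label < numClasses, s"label $label must be in [0, $numClasses)")
    allStats(offset + label.toInt) += instanceWeight
  }

  // Mimics getCalculator: copy the numClasses counters of one (feature, bin)
  def slice(allStats: Array[Double], offset: Int, numClasses: Int): Array[Double] =
    allStats.slice(offset, offset + numClasses)
}
```

With 2 classes and 2 bins, bin 1 starts at offset 2; two samples of label 1 and one of label 0 leave (1.0, 2.0) in that slice and bin 0 untouched.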
6.1.2.3. EntropyCalculator
/**
* Calculate the impurity from the stored sufficient statistics.
*/
def calculate(): Double = Entropy.calculate(stats, stats.sum)
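Putting the functions above together, entropy is computed as follows: EntropyAggregator's getCalculator slices the part of allStats that belongs to the split into an EntropyCalculator, whose calculate then calls Entropy.calculate to compute the entropy.
EntropyCalculator also overrides prob, which returns the probability of a label. For example, if label 0 was counted 3 times and label 1 was counted 7 times, the probability of label 0 is 0.3.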
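A hypothetical sketch of the same ratio on a plain array, reproducing the 3-versus-7 example. It is not the EntropyCalculator code; returning 0 when nothing was counted is an assumption made here to avoid dividing by zero.

```scala
object ProbSketch {
  // stats(i) = (weighted) count of class i for one slice of allStats
  def prob(stats: Array[Double], label: Double): Double = {
    val lbl = label.toInt
    require(lbl >= 0 && lbl < stats.length, s"label $label is out of range")
    val total = stats.sum
    // Assumption: an empty slice yields probability 0
    if (total == 0) 0.0 else stats(lbl) / total
  }
}
```

ProbSketch.prob(Array(3.0, 7.0), 0.0) gives 0.3 and label 1.0 gives 0.7.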
6.1.3. DTStatsAggregator
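A short digression on what has to be counted when a node is split, since it explains the design of DTStatsAggregator. Take information entropy as the example. When a node is split, we iterate over every split of every feature. Each split divides the node's samples into two parts, and to compute entropy we need the class distribution of the left and right parts separately, from which we get the probabilities and then the entropy. That is why statsSize in the aggregator equals numClasses, and why allStats holds all the statistics, which are in fact class distributions.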
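As a sketch of how the left/right class distributions are consumed, the following computes the information gain of one binary split. SplitGainSketch is a hypothetical helper written independently of the Spark classes; it uses the standard definition where each child's entropy is weighted by its share of the samples.

```scala
object SplitGainSketch {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  def entropy(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0) 0.0
    else counts.filter(_ > 0).map { c => val p = c / total; -p * log2(p) }.sum
  }

  // gain = H(parent) - wL * H(left) - wR * H(right),
  // where the parent distribution is the element-wise sum of left and right
  def infoGain(left: Array[Double], right: Array[Double]): Double = {
    val parent = left.zip(right).map { case (l, r) => l + r }
    val n = parent.sum
    if (n == 0) 0.0
    else entropy(parent) - (left.sum / n) * entropy(left) - (right.sum / n) * entropy(right)
  }
}
```

A split that separates (5, 5) into (5, 0) and (0, 5) has gain 1.0, while one that yields (2, 2) and (3, 3) has gain 0, because both children are as mixed as the parent.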
class DTStatsAggregator(
val metadata: DecisionTreeMetadata,
featureSubset: Option[Array[Int]]) extends Serializable {
/**
* [[ImpurityAggregator]] instance specifying the impurity type.
*/
val impurityAggregator: ImpurityAggregator = metadata.impurity match {
case Gini => new GiniAggregator(metadata.numClasses)
case Entropy => new EntropyAggregator(metadata.numClasses)
case Variance => new VarianceAggregator()
case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
}
/**
* Number of elements (Double values) used for the sufficient statistics of each bin.
*/
private val statsSize: Int = impurityAggregator.statsSize
/**
* Number of bins for each feature. This is indexed by the feature index.
*/
private val numBins: Array[Int] = {
if (featureSubset.isDefined) {
featureSubset.get.map(metadata.numBins(_))
} else {
metadata.numBins
}
}
/**
* Offset for each feature for calculating indices into the [[allStats]] array.
*/
private val featureOffsets: Array[Int] = {
numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
}
/**
* Total number of elements stored in this aggregator
*/
private val allStatsSize: Int = featureOffsets.last
/**
* Flat array of elements.
* Index for start of stats for a (feature, bin) is:
* index = featureOffsets(featureIndex) + binIndex * statsSize
* Note: For unordered features,
* the left child stats have binIndex in [0, numBins(featureIndex) / 2))
* and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
*/
private val allStats: Array[Double] = new Array[Double](allStatsSize)
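Each node has one DTStatsAggregator. Its constructor takes 2 arguments: metadata and the feature subset used by the node. The other class members are:
- impurityAggregator: Gini, Entropy and Variance are currently supported; below we use Entropy as the example, the others work similarly
- statsSize: the number of statistics each bin needs. For classification it equals numClasses, because every class is counted separately; for regression it is 3, holding the count, the sum of labels and the sum of squared labels, which are needed to compute the variance
- numBins: the numBins entries for the features the node uses
- featureOffsets: the start index of each feature in allStats, determined by each feature's number of bins and by statsSize. For example, with 3 features having 3, 2 and 2 bins and statsSize 2, the features need 3 * 2 = 6, 2 * 2 = 4 and 2 * 2 = 4 buckets, so featureOffsets is 0, 6, 10, 14, the running sum from left to right
- allStatsSize: the total number of buckets needed
- allStats: the buckets that store the statistics
f0, f1 and f2 are the 3 features. f0 has 3 feature values (actually binIndex) 0/1/2, f1 has 2 (0/1) and f2 has 2 (0/1). Each feature value has statsSize buckets, so there are 14 in total, i.e. allStatsSize = 14. For example, to find the index of c1 for value v1 of f1: f1's offset is featureOffsets(1) = 6, the binIndex of v1 is 1, statsSize is 2 and the label is 1, so the bucket index is 6 + 1 * 2 + 1 = 9, which is exactly the index of the f1v1 c1 bucket in the figure.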
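The offsets and the index from this example can be checked with a few lines of Scala. featureOffsets uses the same scanLeft expression as the constructor above; statIndex is a hypothetical helper that adds the bin offset and the label, the way DTStatsAggregator.update and EntropyAggregator.update do between them.

```scala
object StatsLayoutSketch {
  // Same expression as DTStatsAggregator.featureOffsets; the last entry is allStatsSize
  def featureOffsets(numBins: Array[Int], statsSize: Int): Array[Int] =
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)

  // Hypothetical helper: bucket of one class = feature offset + bin offset + label
  def statIndex(offsets: Array[Int], statsSize: Int,
                featureIndex: Int, binIndex: Int, label: Int): Int =
    offsets(featureIndex) + binIndex * statsSize + label
}
```

featureOffsets(Array(3, 2, 2), 2) yields 0, 6, 10, 14, and statIndex(offsets, 2, 1, 1, 1) yields 9 for the f1v1 c1 bucket.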
The key functions are explained below:
/**
* Update the stats for a given (feature, bin) for ordered features, using the given label.
*/
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
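//The first part is the feature offset
//binIndex is the offset of the value (bin) within the feature; each bin owns statsSize buckets, so the sum is the first bucket of this bin
//EntropyAggregator's update, for example, then adds label.toInt to reach the bucket of that label
//This offset calculation shows that the bin indices of an ordered feature should be contiguous, without gaps, and must start at 0
//If there are gaps, the corresponding space is simply wasted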
val i = featureOffsets(featureIndex) + binIndex * statsSize
impurityAggregator.update(allStats, i, label, instanceWeight)
}
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
* @param featureOffset For ordered features, this is a pre-computed (node, feature) offset
* from [[getFeatureOffset]].
* For unordered features, this is a pre-computed
* (node, feature, left/right child) offset from
* [[getLeftRightFeatureOffsets]].
*/
def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
//The offset is computed as above, except that the feature offset is passed in as an argument and does not need to be recomputed
impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
}
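Putting update and the per-bin lookup together, the sketch below is a hypothetical, heavily simplified aggregator (ordered features only, statsSize == numClasses, unit weights, no feature subset). MiniStatsAggregator, Agg, binStats and aggregate are illustrative names, not Spark APIs; the sketch shows how each binned sample increments exactly one bucket per feature.

```scala
object MiniStatsAggregator {
  final class Agg(numBins: Array[Int], numClasses: Int) {
    private val statsSize = numClasses
    private val featureOffsets = numBins.scanLeft(0)((t, n) => t + statsSize * n)
    val allStats = new Array[Double](featureOffsets.last)

    // Like DTStatsAggregator.update followed by EntropyAggregator.update
    def update(featureIndex: Int, binIndex: Int, label: Double, weight: Double): Unit = {
      val i = featureOffsets(featureIndex) + binIndex * statsSize
      allStats(i + label.toInt) += weight
    }

    // Class counts of one (feature, bin), i.e. what getImpurityCalculator would wrap
    def binStats(featureIndex: Int, binIndex: Int): Array[Double] = {
      val i = featureOffsets(featureIndex) + binIndex * statsSize
      allStats.slice(i, i + statsSize)
    }
  }

  // Each sample is (bin index for every feature, label)
  def aggregate(numBins: Array[Int], numClasses: Int,
                samples: Seq[(Array[Int], Double)]): Agg = {
    val agg = new Agg(numBins, numClasses)
    for ((bins, label) <- samples; f <- bins.indices) agg.update(f, bins(f), label, 1.0)
    agg
  }
}
```

With the bin counts 3, 2, 2 and 2 classes used earlier, three samples produce 9 increments spread over the 14 buckets.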
6.2. Training Initialization
// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()
val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
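This creates numTrees Nodes initialized with emptyNode; they become the root node of each tree and take part in the training that follows. Each node is paired with its treeIndex and enqueued in nodeQueue. Every node that later needs to be split is also added to this queue and split in turn, until all nodes hit a stopping condition, i.e. until the queue in the subsequent while loop is empty.
6.3. Selecting Nodes to Split
This logic lives in selectNodesToSplit, which takes the nodes to be split in this round off nodeQueue and computes their per-node information.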
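The FIFO behavior can be simulated without Spark. This hypothetical sketch stores (treeIndex, nodeId) pairs instead of Nodes and assumes every node keeps splitting until its children would reach maxDepth (real training also stops earlier when a node becomes a leaf). It records the order in which nodes leave the queue; note that the roots of all trees come out before any child, so one round can mix nodes from different trees.

```scala
import scala.collection.mutable

object NodeQueueSketch {
  // Hypothetical: order in which (treeIndex, nodeId) pairs are dequeued,
  // assuming maxDepth >= 1 and that a child is enqueued only if it may split itself
  def dequeueOrder(numTrees: Int, maxDepth: Int): Seq[(Int, Int)] = {
    val nodeQueue = mutable.Queue[(Int, Int)]()
    // Like topNodes: one root with id 1 per tree, enqueued in tree order
    (0 until numTrees).foreach(t => nodeQueue.enqueue((t, 1)))
    val order = mutable.ArrayBuffer[(Int, Int)]()
    while (nodeQueue.nonEmpty) {
      val (treeIndex, id) = nodeQueue.dequeue()
      order += ((treeIndex, id))
      val depth = 31 - Integer.numberOfLeadingZeros(id) // root is depth 0
      if (depth + 1 < maxDepth) {
        nodeQueue.enqueue((treeIndex, id << 1))       // left child
        nodeQueue.enqueue((treeIndex, (id << 1) + 1)) // right child
      }
    }
    order.toSeq
  }
}
```

With 2 trees and maxDepth 2 the order is (0,1), (1,1), (0,2), (0,3), (1,2), (1,3).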
/**
* Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
* This tracks the memory usage for aggregates and stops adding nodes when too much memory
* will be needed; this allows an adaptive number of nodes since different nodes may require
* different amounts of memory (if featureSubsetStrategy is not "all").
*
* @param nodeQueue Queue of nodes to split.
* @param maxMemoryUsage Bound on size of aggregate statistics.
* @return (nodesForGroup, treeToNodeToIndexInfo).
* nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
*
* treeToNodeToIndexInfo holds indices selected features for each node:
* treeIndex --> (global) node index --> (node index in group, feature indices).
* The (global) node index is the index in the tree; the node index in group is the
* index in [0, numNodesInGroup) of the node in this group.
* The feature indices are None if not subsampling features.
*/
private[tree] def selectNodesToSplit(
nodeQueue: mutable.Queue[(Int, Node)],
maxMemoryUsage: Long,
metadata: DecisionTreeMetadata,
rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
// Collect some nodes to split:
// nodesForGroup(treeIndex) = nodes to split
val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
val mutableTreeToNodeToIndexInfo =
new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
var memUsage: Long = 0L
var numNodesInGroup = 0
while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
val (treeIndex, node) = nodeQueue.head
//Subsample the features used by this node with reservoir sampling (introduced in an earlier article)
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
Some(SamplingUtils.reservoirSampleAndCount(Range(0,
metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
} else {
None
}
// Check if enough memory remains to add this node to the group.
val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
if (memUsage + nodeMemUsage <= maxMemoryUsage) {
nodeQueue.dequeue()
mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
mutableTreeToNodeToIndexInfo
.getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
= new NodeIndexInfo(numNodesInGroup, featureSubset)
}
numNodesInGroup += 1
memUsage += nodeMemUsage
}
// Convert mutable maps to immutable ones.
val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
(nodesForGroup, treeToNodeToIndexInfo)
}
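The code is simple: within the memory bound, it takes the nodes that can be handled in this round off nodeQueue and puts them into nodesForGroup and treeToNodeToIndexInfo. If the node at the head of the queue does not fit, it stays in the queue, but memUsage is still increased by its nodeMemUsage; this pushes memUsage past maxMemoryUsage and ends the while loop.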
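The greedy, memory-bounded grouping can be sketched on plain tuples. In this hypothetical simplification each queue entry is (treeIndex, nodeId, memBytes) with the memory cost precomputed, and the second map stores only the index in the group instead of a NodeIndexInfo; the loop itself follows the structure of selectNodesToSplit above.

```scala
import scala.collection.mutable

object SelectNodesSketch {
  def select(
      queue: mutable.Queue[(Int, Int, Long)],
      maxMemoryUsage: Long): (Map[Int, Seq[Int]], Map[Int, Map[Int, Int]]) = {
    val nodesForGroup = mutable.HashMap[Int, mutable.ArrayBuffer[Int]]()
    val indexInGroup = mutable.HashMap[Int, mutable.HashMap[Int, Int]]()
    var memUsage = 0L
    var numNodesInGroup = 0
    while (queue.nonEmpty && memUsage < maxMemoryUsage) {
      val (treeIndex, nodeId, nodeMem) = queue.head
      if (memUsage + nodeMem <= maxMemoryUsage) {
        queue.dequeue()
        nodesForGroup.getOrElseUpdate(treeIndex, mutable.ArrayBuffer[Int]()) += nodeId
        indexInGroup.getOrElseUpdate(treeIndex, mutable.HashMap[Int, Int]())(nodeId) = numNodesInGroup
      }
      // As in the original: the counters advance even if the node did not fit,
      // which pushes memUsage past the limit and ends the loop
      numNodesInGroup += 1
      memUsage += nodeMem
    }
    (nodesForGroup.map { case (k, v) => k -> v.toSeq }.toMap,
     indexInGroup.map { case (k, v) => k -> v.toMap }.toMap)
  }
}
```

With three 40-byte nodes and a 100-byte budget, the two roots are selected (group indices 0 and 1) and node 2 of tree 0 stays at the head of the queue for the next round.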
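Whether the feature set is subsampled depends on whether metadata's numFeatures differs from numFeaturesPerNode (metadata.subsamplingFeatures). Both values are arguments of the metadata and are set in buildMetadata according to featureSubsetStrategy; see the previous article.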
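For readers who skipped the earlier article, here is a sketch of reservoir sampling (Algorithm R) applied to feature indices. It is a hypothetical stand-in for SamplingUtils.reservoirSampleAndCount, not its actual code: it keeps the first k items and then lets the n-th item replace a random slot with probability k / n, so every index ends up in the sample with equal probability.

```scala
import scala.util.Random

object ReservoirSketch {
  // Hypothetical Algorithm R sketch; returns at most k items from `input`
  def reservoirSample(input: Iterator[Int], k: Int, seed: Long): Array[Int] = {
    val rng = new Random(seed)
    val reservoir = new Array[Int](k)
    var i = 0
    // Fill the reservoir with the first k items
    while (i < k && input.hasNext) { reservoir(i) = input.next(); i += 1 }
    if (i < k) reservoir.take(i) // fewer than k items: keep them all
    else {
      var seen = k.toLong
      while (input.hasNext) {
        val item = input.next()
        seen += 1
        // Item number `seen` replaces a random slot with probability k / seen
        val j = (rng.nextDouble() * seen).toLong
        if (j < k) reservoir(j.toInt) = item
      }
      reservoir
    }
  }
}
```

For example, reservoirSample((0 until 10).iterator, 3, 42L) returns 3 distinct feature indices from 0 to 9, and the same seed always gives the same subset.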
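nodesForGroup is a Map[Int, Array[Node]] whose key is treeIndex and whose value is an array holding the nodes of that tree to be split in this round.
treeToNodeToIndexInfo has type Map[Int, Map[Int, NodeIndexInfo]]. Its key is treeIndex; in the inner Map the key is node.id, which comes from the first argument passed when the Node was created (in the first round every node's id is 1), and the value is a NodeIndexInfo: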
class NodeIndexInfo(
val nodeIndexInGroup: Int,
val featureSubset: Option[Array[Int]])
The first member is the node's index within this round's node-selection while loop, called the groupIndex; the second member is the feature subset.