1. 程式人生 > >spark mllib原始碼分析之隨機森林(Random Forest)(二)

spark mllib原始碼分析之隨機森林(Random Forest)(二)

4. 特徵處理

這部分主要在DecisionTree.scala的findSplitsBins函式,將所有特徵封裝成Split,然後裝箱Bin。首先對split和bin的結構進行說明

4.1. 資料結構

4.1.1. Split

class Split(
    @Since("1.0.0") feature: Int,
    @Since("1.0.0") threshold: Double,
    @Since("1.0.0") featureType: FeatureType,
    @Since("1.0.0") categories: List[Double])
  • feature:特徵id
  • threshold:閾值
  • featureType:連續特徵(Continuous)/離散特徵(Categorical)
  • categories:離散特徵值陣列,離散特徵使用。放著此split中所有特徵值

4.1.2. Bin

class Bin(
    lowSplit: Split, 
    highSplit: Split, 
    featureType: FeatureType, 
    category: Double)
  • lowSplit/highSplit:上下界
  • featureType:連續特徵(Continuous)/離散特徵(Categorical)
  • category:離散特徵的特徵值

4.2. 連續特徵處理

4.2.1. 抽樣

val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
val sampledInput = if (continuousFeatures.nonEmpty) {
      // Calculate the number of samples for approximate quantile calculation.
      val requiredSamples = math.max(metadata.maxBins * metadata.maxBins
, 10000) val fraction = if (requiredSamples < metadata.numExamples) { requiredSamples.toDouble / metadata.numExamples } else { 1.0 } logDebug("fraction of data used for calculating quantiles = " + fraction) input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()) } else { input.sparkContext.emptyRDD[LabeledPoint] }

首先篩選出連續特徵集,然後計算抽樣數量,抽樣比例,然後無放回樣本抽樣;如果沒有連續特徵,則為空RDD

4.2.2. 計算Split

metadata.quantileStrategy match {
      case Sort =>
        findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
      case MinMax =>
        throw new UnsupportedOperationException("minmax not supported yet.")
      case ApproxHist =>
        throw new UnsupportedOperationException("approximate histogram not supported yet.")
    }

分位點策略,這裡只實現了Sort這一種,前文有說明,下面的計算在findSplitsBinsBySorting函式中,入參是抽樣樣本集,metadata和連續特徵集(裡面是特徵id,從0開始,見LabelPoint的構造)

val continuousSplits = {
    // reduce the parallelism for split computations when there are less
    // continuous features than input partitions. this prevents tasks from
    // being spun up that will definitely do no work.
    val numPartitions = math.min(continuousFeatures.length,input.partitions.length)
    input.flatMap(point => continuousFeatures.map(idx =>  (idx,point.features(idx))))
         .groupByKey(numPartitions)
         .map { case (k, v) => findSplits(k, v) }
         .collectAsMap()
    }

特徵id為key,value是樣本對應的該特徵下的所有特徵值,傳給findSplits函式,其中又呼叫了findSplitsForContinuousFeature函式獲得連續特徵的Split,入參為樣本,metadata和特徵id

def findSplitsForContinuousFeature(
      featureSamples: Array[Double], 
      metadata: DecisionTreeMetadata,
      featureIndex: Int): Array[Double] = {
    require(metadata.isContinuous(featureIndex),
      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")

    val splits = {
    //連續特徵的split是numBins-1
      val numSplits = metadata.numSplits(featureIndex)
    //統計所有特徵值其出現的次數
      // get count for each distinct value
      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
        m + ((x, m.getOrElse(x, 0) + 1))
      }
      //按特徵值排序
      // sort distinct values
      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray

      // if possible splits is not enough or just enough, just return all possible splits
      val possibleSplits = valueCounts.length
      if (possibleSplits <= numSplits) {
        valueCounts.map(_._1)
      } else {
      //等頻離散化
        // stride between splits
        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
        logDebug("stride = " + stride)

        // iterate `valueCount` to find splits
        val splitsBuilder = Array.newBuilder[Double]
        var index = 1
        // currentCount: sum of counts of values that have been visited
        var currentCount = valueCounts(0)._2
        // targetCount: target value for `currentCount`.
        // If `currentCount` is closest value to `targetCount`,
        // then current value is a split threshold.
        // After finding a split threshold, `targetCount` is added by stride.
        var targetCount = stride
        while (index < valueCounts.length) {
          val previousCount = currentCount
          currentCount += valueCounts(index)._2
          val previousGap = math.abs(previousCount - targetCount)
          val currentGap = math.abs(currentCount - targetCount)
          // If adding count of current value to currentCount
          // makes the gap between currentCount and targetCount smaller,
          // previous value is a split threshold.
          //每次步進targetCount個樣本,取上一個特徵值與下一個特徵值gap較小的
          if (previousGap < currentGap) {
            splitsBuilder += valueCounts(index - 1)._1
            targetCount += stride
          }
          index += 1
        }

        splitsBuilder.result()
      }
    }

    // TODO: Do not fail; just ignore the useless feature.
    assert(splits.length > 0,
      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
        "  Please remove this feature and then try again.")

    // the split metadata must be updated on the driver

    splits
  }

在構造split的過程中,如果統計到的值的個數possibleSplits 還不如你設定的numSplits多,那麼所有的值都作為分割點;否則,用等頻分隔法,首先計算分隔步長stride,然後再迴圈中每次累加到targetCount中,作為理想分割點,但是理想分割點可能會包含的特徵值過多,想取一個裡理想分割點儘量近的特徵值,例如,理想分割點是100,落在特徵值fc裡,但是當前特徵值裡面有30個樣本,而前一個特徵值fp只有5個樣本,因此我們如果取fc作為split,則當前區間實際多25個樣本,如果取fp,則少5個樣本,顯然取fp更為合理。
具體到程式碼實現,在if判斷裡步進stride個樣本,累加在targetCount中。while迴圈逐次把每個特徵值的個數加到currentCount裡,計算前一次previousCount和這次currentCount到targetCount的距離,有3種情況,一種是pre和cur都在target左邊,肯定是cur小,繼續迴圈,進入第二種情況;第二種一左一右,如果pre小,肯定是pre是最好的分割點,如果cur還是小,繼續迴圈步進,進入第三種情況;第三種就是都在右邊,顯然是pre小。因此if的判斷條件pre<cur,只要滿足肯定就是split。整體下來的效果就能找到離target最近的一個特徵值。
findSplits函式使用本函式得到的離散化點作為threshold,構造Split

val splits = {
    val featureSplits = findSplitsForContinuousFeature(
          featureSamples.toArray,
          metadata,
          featureIndex)
    logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")

    featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil))
}

這樣就得到了連續特徵所有的Split
4.2.3. 計算bin
得到splits後,即可類似滑窗得到bin的上下界,構造bins

val bins = {
    val lowSplit = new DummyLowSplit(featureIndex, Continuous)
    val highSplit = new DummyHighSplit(featureIndex, Continuous)

    // tack the dummy splits on either side of the computed splits
    val allSplits = lowSplit +: splits.toSeq :+ highSplit

    // slide across the split points pairwise to allocate the bins
    allSplits.sliding(2).map {
         case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue)
    }.toArray
}

在計算splits的時候,個數是bin的個數減1,這裡加上第一個DummyLowSplit(threshold是Double.MinValue),和最後一個DummyHighSplit(threshold是Double.MaxValue)構造的bin,恰好個數是numBins中的個數

4.3. 離散特徵

bin的主要作用其實就是用來做連續特徵離散化,離散特徵是用不著的。
對有序離散特徵而言,其split直接用特徵值表徵,因此這裡的splits和bins都是空的Array。
對於無序離散特徵而言,其split是特徵值的組合,不是簡單的上下界比較關係,bin是空Array,而split需要計算。

4.3.1. split

// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
val featureArity = metadata.featureArity(i)
val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
    val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
    new Split(i, Double.MinValue, Categorical, categories)
}

featureArity來自引數categoricalFeaturesInfo中設定的離散特徵的特徵值數。
metadata.numSplits是吧numBins中的數量/2,相當於返回了2^(M-1)-1,M是特徵值數。
呼叫extractMultiClassCategories函式,入參是1到2^(M-1)和特徵數M。

/**
   * Nested method to extract list of eligible categories given an index. It extracts the
   * position of ones in a binary representation of the input. If binary
   * representation of an number is 01101 (13), the output list should (3.0, 2.0,
   * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones.
   */
def extractMultiClassCategories(
     input: Int,
     maxFeatureValue: Int): List[Double] = {
    var categories = List[Double]()
    var j = 0
    var bitShiftedInput = input
    while (j < maxFeatureValue) {
      if (bitShiftedInput % 2 != 0) {
        // updating the list of categories.
        categories = j.toDouble :: categories
      }
      // Right shift by one
      bitShiftedInput = bitShiftedInput >> 1
      j += 1
    }
    categories
}

如註釋所述,這個函式返回給定的input的二進位制表示中1的index,這裡實際返回的是特徵的組合,之前文章介紹過的《組合數》

5. 樣本處理

將輸入樣本LabelPoint與上述特徵進一步封裝,方便後面進行分割槽統計。

5.1. TreePoint

構造TreePoint的過程,是一系列函式的呼叫鏈,我們逐層分析。

val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)

RandomForest.scala中將輸入轉化成TreePoint的rdd,呼叫convertToTreeRDD函式

def convertToTreeRDD(
    input: RDD[LabeledPoint],
    bins: Array[Array[Bin]],
    metadata: DecisionTreeMetadata): RDD[TreePoint] = {
    // Construct arrays for featureArity for efficiency in the inner loop.
    val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
    var featureIndex = 0
    while (featureIndex < metadata.numFeatures) {
      featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
      featureIndex += 1
    }
    input.map { x =>
      TreePoint.labeledPointToTreePoint(x, bins, featureArity)
    }
  }

convertToTreeRDD函式的入參input是所有樣本,bins是二維陣列,第一維是特徵,第二維是特徵的Bin陣列。函式首先計算每個特徵的特徵數量,放在featureArity中,如果是連續特徵,設為0。對每個樣本呼叫labeledPointToTreePoint函式,構造TreePoint。

private def labeledPointToTreePoint(
      labeledPoint: LabeledPoint,
      bins: Array[Array[Bin]],
      featureArity: Array[Int]): TreePoint = {
    val numFeatures = labeledPoint.features.size
    val arr = new Array[Int](numFeatures)
    var featureIndex = 0
    while (featureIndex < numFeatures) {
      arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
        bins)
      featureIndex += 1
    }
    new TreePoint(labeledPoint.label, arr)
  }

labeledPointToTreePoint計算每個樣本的所有特徵對應的特徵值屬於哪個bin,放在在arr陣列中;如果是連續特徵,存放的實際是binIndex,或者說是第幾個bin;如果是離散特徵,直接featureValue.toInt,這其實暗示著,對有序離散值,其編碼只能是[0,featureArity - 1],閉區間,其後的部分邏輯也依賴於這個假設。這部分是在findBin函式中完成的,這裡不再贅述。
我們在這裡把TreePoint的成員再羅列一下,方便查閱

class TreePoint(val label: Double, val binnedFeatures: Array[Int])

這裡是把每個樣本從LabelPoint轉換成TreePoint,label就是樣本label,binnedFeatures就是上述的arr陣列。

5.2. BaggedPoint

同理構造BaggedPoint的過程,也是一系列函式的呼叫鏈,我們逐層分析。

val withReplacement = if (numTrees > 1) true else false
val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
          strategy.subsamplingRate, numTrees,
          withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)

這裡同時對樣本進行了抽樣,如果樹個數大於1,就有放回抽樣,否則無放回抽樣,呼叫convertToTreeRDD函式將TreePoint轉化成BaggedPoint的rdd

/**
   * Convert an input dataset into its BaggedPoint representation,
   * choosing subsamplingRate counts for each instance.
   * Each subsamplingRate has the same number of instances as the original dataset,
   * and is created by subsampling without replacement.
   * @param input Input dataset.
   * @param subsamplingRate Fraction of the training data used for learning decision tree.
   * @param numSubsamples Number of subsamples of this RDD to take.
   * @param withReplacement Sampling with/without replacement.
   * @param seed Random seed.
   * @return BaggedPoint dataset representation.
   */
  def convertToBaggedRDD[Datum] (
      input: RDD[Datum],
      subsamplingRate: Double,
      numSubsamples: Int,
      withReplacement: Boolean,
      seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
    if (withReplacement) {
      convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
    } else {
      if (numSubsamples == 1 && subsamplingRate == 1.0) {
        convertToBaggedRDDWithoutSampling(input)
      } else {
        convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
      }
    }
  }

根據有放回還是無放回,或者不抽樣分別呼叫相應函式。無放回抽樣

def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
      input: RDD[Datum],
      subsamplingRate: Double,
      numSubsamples: Int,
      seed: Long): RDD[BaggedPoint[Datum]] = {
    //對每個partition獨立抽樣
    input.mapPartitionsWithIndex { (partitionIndex, instances) =>
      // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
      val rng = new XORShiftRandom
      rng.setSeed(seed + partitionIndex + 1)
      instances.map { instance =>
      //對每條樣本進行numSubsamples(實際是樹的個數)次抽樣,
      //一次將本條樣本在所有樹中是否會被抽取都獲得,犧牲空間減少訪問資料次數
        val subsampleWeights = new Array[Double](numSubsamples)
        var subsampleIndex = 0
        while (subsampleIndex < numSubsamples) {
          val x = rng.nextDouble()
          //無放回抽樣,只需要決定本樣本是否被抽取,被抽取就是1,沒有就是0
          subsampleWeights(subsampleIndex) = {
            if (x < subsamplingRate) 1.0 else 0.0
          }
          subsampleIndex += 1
        }
        new BaggedPoint(instance, subsampleWeights)
      }
    }
  }

有放回抽樣

def convertToBaggedRDDSamplingWithReplacement[Datum] (
      input: RDD[Datum],
      subsample: Double,
      numSubsamples: Int,
      seed: Long): RDD[BaggedPoint[Datum]] = {
    input.mapPartitionsWithIndex { (partitionIndex, instances) =>
      // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
      val poisson = new PoissonDistribution(subsample)
      poisson.reseedRandomGenerator(seed + partitionIndex + 1)
      instances.map { instance =>
        val subsampleWeights = new Array[Double](numSubsamples)
        var subsampleIndex = 0
        while (subsampleIndex < numSubsamples) {
        //與無放回抽樣對比,這裡用泊松抽樣返回的是樣本被抽取的次數,
        //可能大於1,而無放回是0/1,也可認為是被抽取的次數
          subsampleWeights(subsampleIndex) = poisson.sample()
          subsampleIndex += 1
        }
        new BaggedPoint(instance, subsampleWeights)
      }
    }
  }

不抽樣,或者說抽樣率為1

def convertToBaggedRDDWithoutSampling[Datum] (
      input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
    input.map(datum => new BaggedPoint(datum, Array(1.0)))
  }

這裡再囉嗦的羅列下BaggedPoint

class BaggedPoint[Datum](
    val datum: Datum, 
    val subsampleWeights: Array[Double])

datum是TreePoint,subsampleWeights是陣列,維數等於numberTrees,每個值是樣本在每棵樹中被抽取的次數

至此,Random Forest的初始化工作已經完成

timer.stop("init")

相關推薦

spark mllib原始碼分析隨機森林(Random Forest)

4. 特徵處理 這部分主要在DecisionTree.scala的findSplitsBins函式,將所有特徵封裝成Split,然後裝箱Bin。首先對split和bin的結構進行說明 4.1. 資料結構 4.1.1. Split cl

spark mllib原始碼分析隨機森林(Random Forest)

6. 隨機森林訓練 6.1. 資料結構 6.1.1. Node 樹中的每個節點是一個Node結構 class Node @Since("1.2.0") ( @Since("1.0.0") val id: Int, @S

spark mllib原始碼分析分類邏輯迴歸evaluation

在邏輯迴歸分類中,我們評價分類器好壞的主要指標有精準率(precision),召回率(recall),F-measure,AUC等,其中最常用的是AUC,它可以綜合評價分類器效能,其他的指標主要偏重一些方面。我們介紹下spark中實現的這些評價指標,便於使用sp

spark mllib原始碼分析DecisionTree與GBDT

我們在前面的文章講過,在spark的實現中,樹模型的依賴鏈是GBDT-> Decision Tree-> Random Forest,前面介紹了最基礎的Random Forest的實現,在此基礎上我們介紹Decision Tree和GBDT的實現

spark mllib原始碼分析L-BFGS

1. 使用 spark給出的example中涉及到LBFGS有兩個,分別是LBFGSExample.scala和LogisticRegressionWithLBFGSExample.scala,第一個是直接使用LBFGS直接訓練,需要指定一系列優化引數,優

spark mllib原始碼分析邏輯迴歸彈性網路ElasticNet

spark在ml包中將邏輯迴歸封裝了下,同時在演算法中引入了L1和L2正則化,通過elasticNetParam來調節兩種正則化的係數,同時根據選擇的正則化,決定使用L-BFGS還是OWLQN優化,是謂Elastic Net。 1. 輔助類 我們首先介紹

決策樹模型組合隨機森林與GBDT

get 9.png 生成 代碼 margin ast decision 損失函數 固定 版權聲明: 本文由LeftNotEasy發布於http://leftnoteasy.cnblogs.com, 本文可以被全部的轉載或者部分使用,但請註明出處,如果有問題,請

Memcached原始碼分析增刪改查操作5

文章列表: 《Memcached原始碼分析 - Memcached原始碼分析之總結篇(8)》 前言 在看Memcached的增刪改查操作前,我們先來看一下process_command方法。Memcached解析命令之後,就通過process_comman

android原始碼分析View的事件分發

1、View的繼承關係圖 View的繼承關係圖如下: 其中最重要的子類為ViewGroup,View是所有UI元件的基類,而ViewGroup是容納這些元件的容器,同時它也是繼承於View類。而UI元件的繼承關係如上圖,比較常用的元件類用紅色字型標出

從壹開始微服務 [ DDD ] 十一 ║ 基於原始碼分析,命令分發的過程

緣起 哈嘍小夥伴週三好,老張又來啦,DDD領域驅動設計的第二個D也快說完了,下一個系列我也在考慮之中,是 Id4 還是 Dockers 還沒有想好,甚至昨天我還想,下一步是不是可以寫一個簡單的Angular 入門教程,本來是想來個前後端分離的教學視訊的,簡單試了試,發現自己的聲音不好聽,真心不好聽那種,就作

spring原始碼學習路---IOC實現原理

上一章我們已經初步認識了BeanFactory和BeanDefinition,一個是IOC的核心工廠介面,一個是IOC的bean定義介面,上章提到說我們無法讓BeanFactory持有一個Map package org.springframework.beans.factory.supp

【spring原始碼分析】IOC容器初始化

前言:在【spring原始碼分析】IOC容器初始化(一)中已經分析了匯入bean階段,本篇接著分析bean解析階段。 1.解析bean程式呼叫鏈 同樣,先給出解析bean的程式呼叫鏈: 根據程式呼叫鏈,整理出在解析bean過程中主要涉及的類和相關方法。 2.解析bean原始碼分

Java原始碼分析——java.util工具包解析——HashSet、TreeSet、LinkedHashSet類解析

    Set,即集合,與數學上的定義一樣,集合具有三個特點: 無序性:一個集合中,每個元素的地位都是相同的,元素之間是無序的。 互異性:一個集合中,任何兩個元素都認為是不相同的,即每個元素只能出現一次。 確定性:給定一個集

Spring原始碼解讀Spring MVC HandlerMapping元件

一、HandlerMapping HandlerMapping作用是根據request找到相應的處理器Handler和Interceptors,並將Handler和Interceptors封裝成HandlerExecutionChain 物件返回。Handler

Android架構分析Android訊息處理機制

作者:劉昊昱  Android版本:4.4.2 在上一篇文章中我們看了一個使用Handler處理Message訊息的例子,本文我們來分析一下其背後隱藏的Android訊息處理機制。 我們可能比較熟悉Windows作業系統的訊息處理模型: while(GetMessage

springMVC原始碼分析--HandlerInterceptor攔截器呼叫過程

在上一篇部落格springMVC原始碼分析--HandlerInterceptor攔截器(一)中我們介紹了HandlerInterceptor攔截器相關的內容,瞭解到了HandlerInterceptor提供的三個介面方法:(1)preHandle: 在執行controlle

Spark core原始碼分析spark叢集的啟動

2.2 Worker的啟動 org.apache.spark.deploy.worker 1 從Worker的伴生物件的main方法進入 在main方法中首先是得到一個SparkConf例項conf,然後將conf和啟動Worker傳入的引數封裝得到Wor

Spark SQL 原始碼分析Physical Plan 到 RDD的具體實現

  我們都知道一段sql,真正的執行是當你呼叫它的collect()方法才會執行Spark Job,最後計算得到RDD。 lazy val toRdd: RDD[Row] = executedPlan.execute()  Spark Plan基本包含4種操作型別,即Bas

Spark MLlib原始碼分析—Word2Vec原始碼詳解

以下程式碼是我依據SparkMLlib(版本1.6)中Word2Vec原始碼改寫而來,基本算是照搬。此版Word2Vec是基於Hierarchical Softmax的Skip-gram模型的實現。 在決定讀懂原始碼前,博主建議讀者先看一下《Word2Vec_

Spark MLlib原始碼解讀樸素貝葉斯分類器,NaiveBayes

Spark MLlib 樸素貝葉斯NaiveBayes 原始碼分析 基本原理介紹 首先是基本的條件概率求解的公式。 P(A|B)=P(AB)P(B) 在現實生活中,我們經常會碰到已知一個條件概率,求得兩個時間交換後的概率的問題。也就是在已知P(A