1. 程式人生 > >spark mllib原始碼分析之DecisionTree與GBDT

spark mllib原始碼分析之DecisionTree與GBDT

我們在前面的文章講過,在spark的實現中,樹模型的依賴鏈是GBDT-> Decision Tree-> Random Forest,前面介紹了最基礎的Random Forest的實現,在此基礎上我們介紹Decision Tree和GBDT的實現。

1. Decision Tree

1.1. DT的使用

官方給出的demo

// Train a DecisionTree model.
    //  Empty categoricalFeaturesInfo indicates all features are continuous.
    val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]() val impurity = "gini" val maxDepth = 5 val maxBins = 32 val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins)

其入參除了不需要指定樹個數,其他引數與隨機森林類似,不再贅述

1.2 實現

主要的邏輯在DecisionTree.scala的run函式中

  /**
   * Method to train a decision tree model over an RDD
   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
   * @return DecisionTreeModel that can be used for prediction
   */
  @Since("1.2.0")
  def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
    // Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) val rfModel = rf.run(input) rfModel.trees(0) }

其實就是Random Forest 1棵樹的情形,同時特徵不再抽樣。

2. Gradient Boosting Decision Tree

2.1. 演算法簡介

簡稱GBDT,中文譯作梯度提升決策樹,估計沒幾個人聽過。這裡貼幾張之前介紹GBDT的PPT,簡單回顧起演算法原理,其中內容來自wikipedia和”From RankNet to LambdaRank to LambdaMAR An Overview”這篇文章。

2.1.1. 演算法原理

這裡寫圖片描述
在這個演算法裡面,並沒有限定使用決策樹,如果使用決策樹,對應裡面的h應該是樹結構,我們以決策樹說明
1. 使用原始樣本直接訓練一棵樹
迴圈訓練
2. 計算偽殘差,實際是梯度
3. 將2中的偽殘差作為樣本的label去訓練決策樹
4. 這裡是用最優化方法計算葉子節點的輸出,而spark中直接使用的均值
5. 計算當輪模型的輸出,方法是上一輪的輸出加上本輪的預測值
6. 迴圈結束後,輸出模型

2.1.2. 以二分類為例

這裡寫圖片描述
這裡寫圖片描述
這裡寫圖片描述

2.2. GBDT使用

官方demo

// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 5
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

首先初始化訓練引數boostingStrategy,然後設定其迭代次數,分類樹,樹的最大深度,離散特徵及其特徵值數,我們看下預設的引數都有哪些

/**
   * Returns default configuration for the boosting algorithm
   * @param algo Learning goal.  Supported:
   *             [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
   *             [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
   * @return Configuration for boosting algorithm
   */
  @Since("1.3.0")
  def defaultParams(algo: Algo): BoostingStrategy = {
    val treeStrategy = Strategy.defaultStrategy(algo)
    treeStrategy.maxDepth = 3
    algo match {
      case Algo.Classification =>
        treeStrategy.numClasses = 2
        new BoostingStrategy(treeStrategy, LogLoss)
      case Algo.Regression =>
        new BoostingStrategy(treeStrategy, SquaredError)
      case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by boosting.")
    }
  }

預設樹的最大深度為3,如果是分類,為二分類,使用LogLoss;如果是迴歸,使用SquareError,均方誤差。然後使用Strategy的預設引數

  /**
   * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
   * @param algo Algo.Classification or Algo.Regression
   */
  @Since("1.3.0")
  def defaultStrategy(algo: Algo): Strategy = algo match {
    case Algo.Classification =>
      new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
        numClasses = 2)
    case Algo.Regression =>
      new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
        numClasses = 0)
  }

Strategy的預設引數也比較簡單,其意義參見之前的文章。

2.3. GBDT實現

其實現開始於GradientBoostedTrees.scala的run函式

  /**
   * Method to train a gradient boosting model
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @return a gradient boosted trees model that can be used for prediction
   */
  @Since("1.2.0")
  def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
    val algo = boostingStrategy.treeStrategy.algo
    algo match {
      case Regression =>
        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
      case Classification =>
        // Map labels to -1, +1 so binary classification can be treated as regression.
        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
        GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
      case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
    }
  }

從其註釋可以看到,spark GBDT只實現了二分類,並且二分類的class必須是0/1,其把0/1轉化成-1/+1的label,然後按回歸處理。

2.3.2. 資料結構

2.3.2.1. LogLoss

在第二頁PPT中我們給出了loss,spark使用的loss是σ=1,log前增加了係數2的情況

L(y,FN)=2log(1+e2yFN)
對應梯度變成 g=4y/(1+e2yFm1(x))
其中m-1指的是在第m次迭代中,使用的是m-1次的預測值。注意到我們的PPT的第四頁的γ,其實是葉子節點的預測值,是通過最優化得到的,而spark這裡使用的是Random Forest的程式碼,其impurity選擇的是variance,因此預測值是均值。
  @Since("1.2.0")
  override def gradient(prediction: Double, label: Double): Double = {
    - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
  }

  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
  //loss
    val margin = 2.0 * label * prediction
    // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
    2.0 * MLUtils.log1pExp(-margin)
  }

SquaredError比較簡單,這裡不再囉嗦了。

2.3.1. init

將傳入的引數轉成訓練時的引數,cache predError和validatePredError,並且按treeStrategy.getCheckpointInterval(default 10)建立checkpoint。這裡程式碼比較簡單,不再贅述。

2.3.2. build the first tree

參照演算法原理的第一步,訓練了第一棵樹,並且將weight設為1,,然後計算錯誤率。呼叫了computeInitialPredictionAndError函式

  /**
   * :: DeveloperApi ::
   * Compute the initial predictions and errors for a dataset for the first
   * iteration of gradient boosting.
   * @param data: training data.
   * @param initTreeWeight: learning rate assigned to the first tree.
   * @param initTree: first DecisionTreeModel.
   * @param loss: evaluation metric.
   * @return a RDD with each element being a zip of the prediction and error
   *         corresponding to every sample.
   */
  @Since("1.4.0")
  @DeveloperApi
  def computeInitialPredictionAndError(
      data: RDD[LabeledPoint],
      initTreeWeight: Double,
      initTree: DecisionTreeModel,
      loss: Loss): RDD[(Double, Double)] = {
    data.map { lp =>
      val pred = initTreeWeight * initTree.predict(lp.features)
      val error = loss.computeError(pred, lp.label)
      (pred, error)
    }
  }

其中預測值直接使用DT的predict來預測,error使用loss的computeError函式,我們上面有介紹。

2.3.3. 迴圈訓練

2.3.3.1. 樣本處理

對應演算法的第2步,計算梯度,並且作為label更新樣本

val data = predError.zip(input).map { case ((pred, _), point) =>
        LabeledPoint(-loss.gradient(pred, point.label), point.features)
      }
2.3.3.2. 訓練樹

對應演算法的第3和第4步,用第2步的樣本作為輸入,訓練決策樹

val model = new DecisionTree(treeStrategy).run(data)
timer.stop(s"building tree $m")
// Update partial model
baseLearners(m) = model
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
//       Technically, the weight should be optimized for the particular loss.
//       However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
2.3.3.3. 計算模型輸出

實際呼叫updatePredictionError函式,入參是原始的樣本,上一輪的錯誤率(實際包含上一輪的模型輸出),本來的決策樹,學習率和loss計算物件。

  /**
   * :: DeveloperApi ::
   * Update a zipped predictionError RDD
   * (as obtained with computeInitialPredictionAndError)
   * @param data: training data.
   * @param predictionAndError: predictionError RDD
   * @param treeWeight: Learning rate.
   * @param tree: Tree using which the prediction and error should be updated.
   * @param loss: evaluation metric.
   * @return a RDD with each element being a zip of the prediction and error
   *         corresponding to each sample.
   */
  @Since("1.4.0")
  @DeveloperApi
  def updatePredictionError(
    data: RDD[LabeledPoint],
    predictionAndError: RDD[(Double, Double)],
    treeWeight: Double,
    tree: DecisionTreeModel,
    loss: Loss): RDD[(Double, Double)] = {

    val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
      iter.map { case (lp, (pred, error)) =>
      //計算本輪模型的預測值
        val newPred = pred + tree.predict(lp.features) * treeWeight
        //計算本輪誤差
        val newError = loss.computeError(newPred, lp.label)
        //newPred是累計,包含至本輪的模型輸出
        (newPred, newError)
      }
    }
    newPredError
  }

程式碼中使用到的函式我們之前都有介紹。

2.3.3.3. validation(early stop)

類似計算錯誤率,只是樣本使用validationInput,看平均誤差是否減少,如果不能使誤差減小就結束訓練,相當於出現過擬合了;如果能,就繼續訓練,並且記錄最好的模型的index。這裡一次誤差變大就結束訓練比較武斷,最好應該有一定的閾值,避免單次訓練的波動。程式碼比較簡單,就不放了。

2.3.3.4. 訓練收尾

訓練完成後,根據記錄的最優模型的index,構造GradientBoostedTreesModel。

3.結語

從上面的分析可以看到,由於spark在Random Forest特徵方面的限制,以及GBDT實現中直接使用均值作為葉子節點的輸出值,early stop等,spark在樹模型上的精度可能會差一點,實際使用的話,最好與其他實現比較後決定是否使用。

相關推薦

spark mllib原始碼分析DecisionTreeGBDT

我們在前面的文章講過,在spark的實現中,樹模型的依賴鏈是GBDT-> Decision Tree-> Random Forest,前面介紹了最基礎的Random Forest的實現,在此基礎上我們介紹Decision Tree和GBDT的實現

spark mllib原始碼分析二分類邏輯迴歸evaluation

在邏輯迴歸分類中,我們評價分類器好壞的主要指標有精準率(precision),召回率(recall),F-measure,AUC等,其中最常用的是AUC,它可以綜合評價分類器效能,其他的指標主要偏重一些方面。我們介紹下spark中實現的這些評價指標,便於使用sp

spark mllib原始碼分析隨機森林(Random Forest)(二)

4. 特徵處理 這部分主要在DecisionTree.scala的findSplitsBins函式,將所有特徵封裝成Split,然後裝箱Bin。首先對split和bin的結構進行說明 4.1. 資料結構 4.1.1. Split cl

spark mllib原始碼分析L-BFGS(一)

1. 使用 spark給出的example中涉及到LBFGS有兩個,分別是LBFGSExample.scala和LogisticRegressionWithLBFGSExample.scala,第一個是直接使用LBFGS直接訓練,需要指定一系列優化引數,優

spark mllib原始碼分析隨機森林(Random Forest)(三)

6. 隨機森林訓練 6.1. 資料結構 6.1.1. Node 樹中的每個節點是一個Node結構 class Node @Since("1.2.0") ( @Since("1.0.0") val id: Int, @S

spark mllib原始碼分析邏輯迴歸彈性網路ElasticNet(一)

spark在ml包中將邏輯迴歸封裝了下,同時在演算法中引入了L1和L2正則化,通過elasticNetParam來調節兩種正則化的係數,同時根據選擇的正則化,決定使用L-BFGS還是OWLQN優化,是謂Elastic Net。 1. 輔助類 我們首先介紹

Spark core原始碼分析spark叢集的啟動(二)

2.2 Worker的啟動 org.apache.spark.deploy.worker 1 從Worker的伴生物件的main方法進入 在main方法中首先是得到一個SparkConf例項conf,然後將conf和啟動Worker傳入的引數封裝得到Wor

netty原始碼分析-SimpleChannelInboundHandlerChannelInboundHandlerAdapter詳解(6)

每一個Handler都一定會處理出站或者入站(也可能兩者都處理)資料,例如對於入站的Handler可能會繼承SimpleChannelInboundHandler或者ChannelInboundHandlerAdapter,而SimpleChannelIn

Spark SQL 原始碼分析Physical Plan 到 RDD的具體實現

  我們都知道一段sql,真正的執行是當你呼叫它的collect()方法才會執行Spark Job,最後計算得到RDD。 lazy val toRdd: RDD[Row] = executedPlan.execute()  Spark Plan基本包含4種操作型別,即Bas

Spark MLlib原始碼分析—Word2Vec原始碼詳解

以下程式碼是我依據SparkMLlib(版本1.6)中Word2Vec原始碼改寫而來,基本算是照搬。此版Word2Vec是基於Hierarchical Softmax的Skip-gram模型的實現。 在決定讀懂原始碼前,博主建議讀者先看一下《Word2Vec_

Prometheus 實戰於原始碼分析API聯邦

在進行原始碼講解關於prometheus還有一些配置和使用,需要解釋一下。首先是API的使用,prometheus提供了一套HTTP的介面 curl http://localhost:9090/api/v1/query?query=go_goroutine

Spark MLlib原始碼解讀樸素貝葉斯分類器,NaiveBayes

Spark MLlib 樸素貝葉斯NaiveBayes 原始碼分析 基本原理介紹 首先是基本的條件概率求解的公式。 P(A|B)=P(AB)P(B) 在現實生活中,我們經常會碰到已知一個條件概率,求得兩個時間交換後的概率的問題。也就是在已知P(A

Mybatis原始碼分析SpringMybatis整合MapperScannerConfigurer處理過程原始碼分析

        前面文章分析了這麼多關於Mybatis原始碼解析,但是我們最終使用的卻不是以前面文章的方式,編寫自己mybatis_config.xml,而是最終將配置融合在spring的配置檔案中。有了前面幾篇部落格的分析,相信這裡會容易理解些關於Mybatis的初始化及

Realm原始碼分析copyToRealmcopyToRealmOrUpdate

createObject 在Realm原始碼分析之Writes中已經詳細追蹤過createObject的執行流程,此處不再贅述。 createObject有如下的兩個過載方法,區別是如果Model沒有指明主鍵使用前者,否則使用後者: createObjec

Spark MLlib原始碼分析—TFIDF原始碼詳解

以下程式碼是我依據SparkMLlib(版本1.6) 1、HashingTF 是使用雜湊表來儲存分詞,並計算分詞頻數(TF),生成HashMap表。在Map中,K為分詞對應索引號,V為分詞的頻數。在宣告HashingTF 時,需要設定numFeatures,該

netty原始碼分析-EventLoop執行緒模型(1)

執行緒模型確定來程式碼的執行方式,我們總是必須規避併發執行可能會帶來的副作用,所以理解netty所採用的併發模型的影響很重要。netty使用了被稱為事件迴圈的EventLoop來執行任務來處理在連線的生命週期內發生的事件 執行緒模型 對於Even

Mybatis深入原始碼分析Mapper介面繫結原理原始碼分析

緊接上篇文章:Mybatis深入原始碼分析之SqlSessionFactoryBuilder原始碼分析,這裡再來分析下,Mappe

Spark原始碼分析Spark Shell(上)

https://www.cnblogs.com/xing901022/p/6412619.html 文中分析的spark版本為apache的spark-2.1.0-bin-hadoop2.7。 bin目錄結構: -rwxr-xr-x. 1 bigdata bigdata 1089 Dec

symfony原始碼分析容器的生成使用

symfony 的容器是有一個編譯過程的,框架初始化的時候會執行Symfony\Component\HttpKernel\Kernel::initializationContainer ,這個方法會對程式碼進行檢查,看是否需要生成新的容器程式碼。如果需要 Symfony 會將各個類的依賴關係通過

Spark——Streaming原始碼解析資料的產生匯入

此文是從思維導圖中匯出稍作調整後生成的,思維腦圖對程式碼瀏覽支援不是很好,為了更好閱讀體驗,文中涉及到的原始碼都是刪除掉不必要的程式碼後的虛擬碼,如需獲取更好閱讀體驗可下載腦圖配合閱讀: 此博文共分為四個部分: DAG定義 Job動態生成 資料的產生與匯入 容錯 資料的產生與匯入主要分為以下五個部分