1. 程式人生 > >深入理解Spark ML:多項式樸素貝葉斯原理與原始碼分析

深入理解Spark ML:多項式樸素貝葉斯原理與原始碼分析

貝葉斯估計

如果一個給定的類和特徵值在訓練集中沒有一起出現過,那麼基於頻率的估計下該概率將為0。這將是一個問題。因為與其他概率相乘時將會把其他概率的資訊統統去除。所以常常要求要對每個小類樣本的概率估計進行修正,以保證不會出現有為0的概率出現。常用到的平滑就是加1平滑(也稱拉普拉斯平滑):

P(Xj=ajl|Y=ck)=Ni=1I(x(j)i=ajl,yi=ck)+lambdaNi=1I(yi=ck)+Sjlambda

lambda>=0,等價於在隨機變數各個取值的頻數上賦予一個正數lambda>0。Sj是特徵Xj取值的類別數,因此使用上式依然有:

Sjl=1P(Xj=ajl|

Y=ck)=1

同樣的:

P(Y=ck)=Ni=1I(yi=ck)+lambdaN+Klambda

N為資料條數,K為label類別數。

多項式樸素貝葉斯

多項式樸素貝葉斯和上述貝葉斯模型不同的是,上述貝葉斯模型對於某特徵的不同取值代表著不同的類別,而多項式樸素貝葉斯對於某特徵的不同取值代表著該特徵決定該label類別的重要程度。

比如一個文字中,單詞Chinese出現的頻數,1次還是10次,並不代表著Chinese單詞這個特徵的類別,而代表著Chinese單詞這個特徵的決定該文字label類別的重要程度。

log(p(yi))=log(Ni=1I(yi=ck)+lambd

a)log(N+Klambda)

log(P(aj|yi))=log(Ni=1aj,yi=ck+lambda)log(Ni=1nj=1aj,yi=ck+nlambda)

n為特徵維度數

我們來舉個例子:

這裡寫圖片描述

我們設lambda為1,共有6個不同的單詞,則特徵維度數為6。

這裡寫圖片描述

這裡寫圖片描述

這裡寫圖片描述

這裡寫圖片描述

所以,我們將d5 分類到 yes

API 使用

下面是Spark 樸素貝葉斯的使用例子:

import org.apache.spark.ml.classification.NaiveBayes

// 載入資料
val data = spark.read.format("libsvm"
).load("data/mllib/sample_libsvm_data.txt") // 切分資料集與訓練集 val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L) // 訓練樸素貝葉斯模型 val model = new NaiveBayes() .fit(trainingData) // 預測 val predictions = model.transform(testData) predictions.show()

原始碼分析

接下來我們來分析下原始碼~

NaiveBayes

train

NaiveBayes().fit呼叫NaiveBayes的父類Predictor中的fit,將labelweight轉為Double,儲存labelweight原資訊,最後呼叫NaiveBayestrain

  override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
    trainWithLabelCheck(dataset, positiveLabel = true)
  }

trainWithLabelCheck:

ml假設輸入labels範圍在[0, numClasses). 但是這個實現也被mllib NaiveBayes呼叫,它允許其他型別的輸入labels如{-1, +1}. positiveLabel 用於確定label是否需要被檢查。

 private[spark] def trainWithLabelCheck(
      dataset: Dataset[_],
      positiveLabel: Boolean): NaiveBayesModel = {
      //檢測label
    if (positiveLabel && isDefined(thresholds)) {
      val numClasses = getNumClasses(dataset)
      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
        ".train() called with non-matching numClasses and thresholds.length." +
        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
    }
    //模型型別 多項式樸素貝葉斯是  Multinomial
    val modelTypeValue = $(modelType)
    val requireValues: Vector => Unit = {
      modelTypeValue match {
        case Multinomial =>
          // 確認所有的值非負
          // values.forall(_ >= 0.0)
          requireNonnegativeValues
        ......
      }
    }
    // Instrumentation 是 一個小封裝,用來定義為一個estimator定義一個training session和該session中有學用的資訊的log方法
    val instr = Instrumentation.create(this, dataset)
    instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
      probabilityCol, modelType, smoothing, thresholds)
    // 得到特徵維度數,即公式中的 n
    val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
    instr.logNumFeatures(numFeatures)
    // 得到記錄的權重 為設定 預設為 1.0
    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))

    // 聚合
    val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
      .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
      // 根據key labelCol 進行聚合
      // value 的初始值為 0.0,Vectors.zeros(numFeatures).toDense
      }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))(
      // 合併在同一