1. 程式人生 > >spark原始碼解讀2之水塘抽樣演算法(Reservoir Sampling)

spark原始碼解讀2之水塘抽樣演算法(Reservoir Sampling)

spark原始碼解讀系列環境:spark-1.5.2、hadoop-2.6.0、scala-2.10.4

1.理解

  問題定義可以簡化如下:在不知道檔案總行數的情況下,如何從檔案中隨機的抽取一行?

  首先想到的是我們做過類似的題目嗎?當然,在知道檔案行數的情況下,我們可以很容易的用C執行庫的rand函式隨機的獲得一個行數,從而隨機的取出一行,但是,當前的情況是不知道行數,這樣如何求呢?我們需要一個概念來幫助我們做出猜想,來使得對每一行取出的概率相等,也即隨機。這個概念即蓄水池抽樣(Reservoir Sampling)。

水塘抽樣演算法(Reservoir Sampling)思想:

  在序列流中取一個數,如何確保隨機性,即取出某個資料的概率為:1/(已讀取資料個數)

  假設已經讀取n個數,現在保留的數是Ax,取到Ax的概率為(1/n)。

  對於第n+1個數An+1,以1/(n+1)的概率取An+1,否則仍然取Ax。依次類推,可以保證取到資料的隨機性。

  數學歸納法證明如下:

    當n=1時,顯然,取A1。取A1的概率為1/1。

       假設當n=k時,取到的資料Ax。取Ax的概率為1/k。

       當n=k+1時,以1/(k+1)的概率取An+1,否則仍然取Ax。

    (1)如果取Ak+1,則概率為1/(k+1);

    (2)如果仍然取Ax,則概率為(1/k)*(k/(k+1))=1/(k+1)

  所以,對於之後的第n+1個數An+1,以1/(n+1)的概率取An+1,否則仍然取Ax。依次類推,可以保證取到資料的隨機性。

在序列流中取k個數,如何確保隨機性,即取出某個資料的概率為:k/(已讀取資料個數)

  建立一個數組,將序列流裡的前k個數,儲存在陣列中。(也就是所謂的”蓄水池”)

  對於第n個數An,以k/n的概率取An並以1/k的概率隨機替換“蓄水池”中的某個元素;否則“蓄水池”陣列不變。依次類推,可以保證取到資料的隨機性。

  數學歸納法證明如下:

    當n=k是,顯然“蓄水池”中任何一個數都滿足,保留這個數的概率為k/k。

    假設當n=m(m>k)時,“蓄水池”中任何一個數都滿足,保留這個數的概率為k/m。
    當n=m+1時,以k/(m+1)的概率取An,並以1/k的概率,隨機替換“蓄水池”中的某個元素,否則“蓄水池”陣列不變。則陣列中保留下來的數的概率為:

  所以,對於第n個數An,以k/n的概率取An並以1/k的概率隨機替換“蓄水池”中的某個元素;否則“蓄水池”陣列不變。依次類推,可以保證取到資料的隨機性。

Spark中的水塘抽樣演算法(Reservoir Sampling)

  spark的Partitioner子類RangePartitioner中有用到Reservoir Sampling抽樣演算法(org.apache.spark.RangePartitioner#sketch).

spark的util中有reservoirSampleAndCount方法(org.apache.spark.util.random.SamplingUtils#reservoirSampleAndCount)

原始碼為:

 /**
   * Reservoir sampling implementation that also returns the input size.
   *
   * @param input input size
   * @param k reservoir size
   * @param seed random seed
   * @return (samples, input size)
   */
  def reservoirSampleAndCount[T: ClassTag](
      input: Iterator[T],
      k: Int,
      seed: Long = Random.nextLong())
    : (Array[T], Int) = {
    val reservoir = new Array[T](k)
    // Put the first k elements in the reservoir.
    var i = 0
    while (i < k && input.hasNext) {
      val item = input.next()
      reservoir(i) = item
      i += 1
    }

    // If we have consumed all the elements, return them. Otherwise do the replacement.
    if (i < k) {
      // If input size < k, trim the array to return only an array of input size.
      val trimReservoir = new Array[T](i)
      System.arraycopy(reservoir, 0, trimReservoir, 0, i)
      (trimReservoir, i)
    } else {
      // If input size > k, continue the sampling process.
      val rand = new XORShiftRandom(seed)
      while (input.hasNext) {
        val item = input.next()
        val replacementIndex = rand.nextInt(i)
        if (replacementIndex < k) {
          reservoir(replacementIndex) = item
        }
        i += 1
      }
      (reservoir, i)
    }
  }

程式碼實現思路比較簡單,新建一個k大小的陣列reservoir,如果元資料中資料少於k,直接返回原資料陣列和原資料個數。如果大於,則對接下來的元素進行比較,隨機生成一個數i,如果這個數小於k,則替換陣列reservoir中第i個數,直至沒有元素,則返回reservoir的copy陣列。

2.程式碼:

測試org.apache.spark.util.random.SamplingUtils$#reservoirSampleAndCount方法:

package org.apache.spark.sourceCode.partitionerLearning

import org.apache.spark.util.SparkLearningFunSuite
import org.apache.spark.util.random.SamplingUtils

import scala.util.Random

/**
  * Created by xubo on 2016/10/9.
  */
class reservoirSampleAndCountSuite extends SparkLearningFunSuite {
  test("reservoirSampleAndCount") {
    val input = Seq.fill(100)(Random.nextInt())
    val (sample1, count1) = SamplingUtils.reservoirSampleAndCount(input.iterator, 150)
    assert(count1 === 100)
    assert(input === sample1.toSeq)

    // input size == k
    val (sample2, count2) = SamplingUtils.reservoirSampleAndCount(input.iterator, 100)
    assert(count2 === 100)
    assert(input === sample2.toSeq)

    // input size > k
    val (sample3, count3) = SamplingUtils.reservoirSampleAndCount(input.iterator, 10)
    assert(count3 === 100)
    assert(sample3.length === 10)
    println(input)
    sample3.foreach{each=>print(each+" ")}
  }

}

3.結果:

List(1287104639, 547232730, -595310393, -1264894486, 427750044, -776002403, 32230947, -1390386390, 484259687, 774711013, -1989325813, -957970416, 945685455, -1322730587, -1919655222, 1642426087, -489524599, -1070401860, -1454008456, -1882431453, -1843815884, -1987533758, -854529853, 879991257, -864077378, 478381860, 111307761, 1504756336, -1892792571, -1413976846, -848218587, -101494119, 1592476609, 247606007, 1269634528, 568928892, 488930464, -2145986422, 1643110602, 280675891, -878405966, 1799740067, 981424562, -1552824965, -1760162041, -288189264, -373755181, -2112636248, -2108911467, -1815555415, 302051417, 254178521, -1137490849, 426066017, -819810525, 1408383341, 1183678420, 234717727, 1470632905, 271163573, -22448780, 486064749, 378168799, -1444541974, 419089337, 1700972847, 1291787131, 644012641, -1618133452, 313585654, 658987252, 869334013, -811750155, -1561229418, 814819564, -197177628, 1051344432, 1746109173, 358985873, -265551244, 1362130460, -1635943643, 168813599, -669120136, -1084593890, -150445899, 387678120, 1994726806, 71986215, 1323527748, 700729367, 219285004, -1513691303, -97767338, 2099894386, -652208741, 704524016, 123647594, -1281589410, -1713105982)
-197177628 -22448780 478381860 -1137490849 219285004 168813599 1269634528 -1454008456 658987252 378168799 

參考

【1】http://spark.apache.org/
【2】http://spark.apache.org/docs/1.5.2/programming-guide.html
【3】https://github.com/xubo245/SparkLearning
【4】book:《深入理解spark核心思想與原始碼分析》
【5】book:《spark核心原始碼分析和開發實戰》
【6】http://www.cnblogs.com/xudong-bupt/p/4053652.html
【7】http://www.cnblogs.com/HappyAngel/archive/2011/02/07/1949762.html
【8】https://www.iteblog.com/archives/1525