Spark MLlib 之 aggregate和treeAggregate從原理到應用
阿新 • • 發佈:2018-07-09
數據量 hash oom 向上 gre require 圖片 iterator reac
在閱讀spark mllib源碼的時候,發現一個出鏡率很高的函數——aggregate和treeAggregate,比如matrix.columnSimilarities()中。為了好好理解這兩個方法的使用,於是整理了本篇內容。
由於treeAggregate是在aggregate基礎上的優化版本,因此先來看看aggregate是什麽.
更多內容參考我的大數據學習之路
aggregate
先直接看一下代碼例子:
import org.apache.spark.sql.SparkSession object AggregateTest { def main(args: Array[String]): Unit = { val spark = SparkSession.builder().master("local[*]").appName("tf-idf").getOrCreate() spark.sparkContext.setLogLevel("WARN") // 創建rdd,並分成6個分區 val rdd = spark.sparkContext.parallelize(1 to 12).repartition(6) // 輸出每個分區的內容 rdd.mapPartitionsWithIndex((index:Int,it:Iterator[Int])=>{ Array((s" $index : ${it.toList.mkString(",")}")).toIterator }).foreach(println) // 執行agg val res1 = rdd.aggregate(0)(seqOp, combOp) } // 分區內執行的方法,直接加和 def seqOp(s1:Int, s2:Int):Int = { println("seq: "+s1+":"+s2) s1 + s2 } // 在driver端匯總 def combOp(c1: Int, c2: Int): Int = { println("comb: "+c1+":"+c2) c1 + c2 } }
這段代碼的主要目的就是為了求和。考慮到spark分區並行計算的特性,在每個分區獨立加和,最後再匯總加和。
過程可以參考下面的圖片:
首先看一下map階段,即在每個分區內計算加和。初始情況如藍色方塊所示,內容為:
分區號:裏面的內容
如,0分區內的數據為6和8
當執行seqop時,會說先用初始值0開始遍歷累加,原理類似如下:
rdd.mapPartitions((it:Iterator)=>{
var sum = init_value // 默認為0
it.foreach(sum + _)
sum
})
因此屏幕上會出現下面的內容,由於分區之間是並行的,所以最後的結果是亂序的:
seq: 0:6
seq: 0:1
seq: 0:3
seq: 1:9
seq: 3:10
seq: 0:2
seq: 0:5
seq: 5:7
seq: 12:12
seq: 0:4
seq: 4:11
seq: 6:8
計算完成後,依次遍歷每個分區結果,進行累加:
comb: 0:10
comb: 10:13
comb: 23:2
comb: 25:24
comb: 49:15
comb: 64:14
aggregate的源碼也比較簡單:
def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope { var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance()) val cleanSeqOp = sc.clean(seqOp) val cleanCombOp = sc.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult) sc.runJob(this, aggregatePartition, mergeResult) jobResult }
treeAggregate
treeAggregate在aggregate的基礎上做了一些優化,因為aggregate是在每個分區計算完成後,把所有的數據拉倒driver端,進行統一的遍歷合並,這樣如果數據量很大,在driver端可能會OOM。
因此treeAggregate在中間多加了一層合並。
先來看看代碼,沒有任何的變化:
import org.apache.spark.sql.SparkSession
object TreeAggregateTest {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("tf-idf").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
val rdd = spark.sparkContext.parallelize(1 to 12).repartition(6)
rdd.mapPartitionsWithIndex((index:Int,it:Iterator[Int])=>{
Array(s" $index : ${it.toList.mkString(",")}").toIterator
}).foreach(println)
val res1 = rdd.treeAggregate(0)(seqOp, combOp)
println(res1)
}
def seqOp(s1:Int, s2:Int):Int = {
println("seq: "+s1+":"+s2)
s1 + s2
}
def combOp(c1: Int, c2: Int): Int = {
println("comb: "+c1+":"+c2)
c1 + c2
}
}
輸出的結果則發生了變化,首先分區內的操作不變:
3 : 3,10
2 : 2
0 : 6,8
1 : 1,9
4 : 4,11
5 : 5,7,12
seq: 0:3
seq: 0:6
seq: 3:10
seq: 6:8
seq: 0:2
seq: 0:1
seq: 1:9
seq: 0:4
seq: 4:11
seq: 0:5
seq: 5:7
seq: 12:12
...
在合並的時候發生了 變化:
comb: 10:13
comb: 23:24
comb: 14:2
comb: 16:15
comb: 47:31
配合下面的流程圖,可以更好的理解:
搭配treeAggregate的源碼來看一下:
def treeAggregate[U: ClassTag](zeroValue: U)(
seqOp: (U, T) => U,
combOp: (U, U) => U,
depth: Int = 2): U = withScope {
require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
if (partitions.length == 0) {
Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
} else {
// 這裏都沒什麽變化,在分區中遍歷數據累加
val cleanSeqOp = context.clean(seqOp)
val cleanCombOp = context.clean(combOp)
val aggregatePartition =
(it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
// 關鍵是這下面的內容 !!!!
// 首先獲得當前的分區數
var numPartitions = partiallyAggregated.partitions.length
// 計算合適的並行度,我這裏相當於6^(1/2),也就是2.4左右,ceill向上取整後變成3.
// max(3,2)得到最後的結果為3。即每個樹的分枝有3個葉子節點
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
// 遍歷分區,通過對scale取模進行合並計算
// 這裏判斷一下,當前的分區數是否還夠分。如果少於條件值 scale+(p/scale),就停止分區
while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
numPartitions /= scale
val curNumPartitions = numPartitions
// 重新定義分區id,並按照分區id重新分區,執行合並計算
partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
(i, iter) => iter.map((i % curNumPartitions, _))
}.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
}
// 最後統計結果
partiallyAggregated.reduce(cleanCombOp)
}
}
spark中的應用
// matrix求相似度
def columnSimilarities(threshold: Double): CoordinateMatrix = {
... columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma)
}
// 統計每一個向量的相關數據,裏面包含了min max 等等很多信息
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2))
updateNumRows(summary.count)
summary
}
了解了treeAggregate之後,後續就可以看matrix的並行求解相似度的源碼了!敬請期待吧...
參考
- spark-aggregate與treeAggregate的理解
Spark MLlib 之 aggregate和treeAggregate從原理到應用