1. 程式人生 > >Aprior並行化演算法在Spark上的實現

Aprior並行化演算法在Spark上的實現

本文為大家分享的Spark實戰案例是K-頻繁項集挖掘——Apriori並行化演算法的實現。關聯資料探勘、頻繁項集挖掘的常用演算法有Apriori,Fp-growth以及eclat演算法。這裡我使用Apriori演算法進行頻繁項集挖掘。Apriori演算法於2006年12月被國際權威的學術組織ICDM評為資料探勘領域的十大經典演算法。不熟悉的同學可以關注我的文章,我會詳細講解其原理及實現。

首先給出需求說明:在Chess標準資料集上進行1到8頻繁項集的挖掘,其中支援度support=0.85。每個檔案的輸出格式為項集:頻率,如a,b,c:0.85。

我們在寫Spark程式的時候一定要注意寫出的程式是並行化的,而不是隻在client上執行的單機程式。否則你的演算法效率將讓你跌破眼鏡而你還在鬱悶為什麼Spark這麼慢甚至比不上Hadoop-MR。此外還需要對演算法做相關優化。在這裡主要和大家交流一下演算法思路和相關優化。

對於Apriori演算法的實現見下文原始碼。在Spark上實現這個演算法的時候主要分為兩個階段。第一階段是一個整體的遍歷求出每個項集的階段,第二階段主要是針對第i個項集求出第i+1項集的候選集的階段。

對於這個演算法可以做如下優化:
1. 觀察!這點很重要,經過觀察可以發現有大量重複的資料,所謂方向不對努力白費也是這個道理,首先需要壓縮重複的資料。不然會做許多無用功。
2. 設計演算法的時候一定要注意是並行化的,大家可能很疑惑,Spark不就是並行化的麼?可是你一不小心可能就寫成只在client端執行的演算法了。
3. 因為資料量比較大,切記多使用資料持久化以及BroadCast廣播變數對中間資料進行相應處理。
4. 資料結構的優化,BitSet是一種優秀的資料結構他只需一位就可以儲存以個整形數,對於所給出的資料都是整數的情況特別適用。
下面給出演算法實現原始碼:

import scala.util.control.Breaks._
import scala.collection.mutable.ArrayBuffer
import java.util.BitSet
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark._


object FrequentItemset {
  def main(args: Array[String]) {
    if (args.length != 2) {
      println("USage:<Datapath> <Output>"
) } //initial SparkContext val sc = new SparkContext() val SUPPORT_NUM = 15278611 //Transactions total is num=17974836, SUPPORT_NUM = num*0.85 val TRANSACITON_NUM = 17974836.0 val K = 8 //All transactions after removing transaction ID, and here we combine the same transactions. val transactions = sc.textFile(args(0)).map(line => line.substring(line.indexOf(" ") + 1).trim).map((_, 1)).reduceByKey(_ + _).map(line => { val bitSet = new BitSet() val ss = line._1.split(" ") for (i <- 0 until ss.length) { bitSet.set(ss(i).toInt, true) } (bitSet, line._2) }).cache() //To get 1 frequent itemset, here, fi represents frequent itemset var fi = transactions.flatMap { line => val tmp = new ArrayBuffer[(String, Int)] for (i <- 0 until line._1.size()) { if (line._1.get(i)) tmp += ((i.toString, line._2)) } tmp }.reduceByKey(_ + _).filter(line1 => line1._2 >= SUPPORT_NUM).cache() val result = fi.map(line => line._1 + ":" + line._2 / TRANSACITON_NUM) result.saveAsTextFile(args(1) + "/result-1") for (i <- 2 to K) { val candiateFI = getCandiateFI(fi.map(_._1).collect(), i) val bccFI = sc.broadcast(candiateFI) //To get the final frequent itemset fi = transactions.flatMap { line => val tmp = new ArrayBuffer[(String, Int)]() //To check if each itemset of candiateFI in transactions bccFI.value.foreach { itemset => val itemArray = itemset.split(",") var count = 0 for (item <- itemArray) if (line._1.get(item.toInt)) count += 1 if (count == itemArray.size) tmp += ((itemset, line._2)) } tmp }.reduceByKey(_ + _).filter(_._2 >= SUPPORT_NUM).cache() val result = fi.map(line => line._1 + ":" + line._2 / TRANSACITON_NUM) result.saveAsTextFile(args(1) + "/result-" + i) bccFI.unpersist() } } //To get the candiate k frequent itemset from k-1 frequent itemset def getCandiateFI(f: Array[String], tag: Int) = { val separator = "," val arrayBuffer = ArrayBuffer[String]() for(i <- 0 until f.length;j <- i + 1 until f.length){ var tmp = "" if(2 == tag) tmp = (f(i) + "," + f(j)).split(",").sortWith((a,b) => a.toInt <= b.toInt).reduce(_+","+_) else { if (f(i).substring(0, f(i).lastIndexOf(',')).equals(f(j).substring(0, f(j).lastIndexOf(',')))) { tmp = (f(i) + f(j).substring(f(j).lastIndexOf(','))).split(",").sortWith((a, b) => a.toInt <= b.toInt).reduce(_ + "," + _) } } var hasInfrequentSubItem = false //To filter the item which has infrequent subitem if (!tmp.equals("")) { val arrayTmp = tmp.split(separator) breakable { for (i <- 0 until arrayTmp.size) { var subItem = "" for (j <- 0 until arrayTmp.size) { if (j != i) subItem += arrayTmp(j) + separator } //To remove the separator "," in the end of the item subItem = subItem.substring(0, subItem.lastIndexOf(separator)) if (!f.contains(subItem)) { hasInfrequentSubItem = true break } } } //breakable } else hasInfrequentSubItem = true //If itemset has no sub inftequent itemset, then put it into candiateFI if (!hasInfrequentSubItem) arrayBuffer += (tmp) } //for arrayBuffer.toArray } }

在這裡提一下我的實驗結果以便大家參考,對於2G,1800W條記錄的資料,在80G記憶體,10個虛擬節點的叢集上18秒就算完了1-8頻繁項集的挖掘。應該還算不錯。

先寫到這裡,歡迎大家提出相關的建議或意見。
(by希慕,轉載請註明出處)