
Spark MLlib Source Code Analysis: The Word2Vec Implementation Explained

The code below is adapted from the Word2Vec source in Spark MLlib (version 1.6) and is essentially a straight copy. This version of Word2Vec implements the Skip-gram model with Hierarchical Softmax.
Before working through the source, I recommend first reading 《Word2Vec_中的數學原理詳解》 (a detailed explanation of the mathematics behind Word2Vec), or the summary I wrote based on that document:
http://blog.csdn.net/liuyuemaicha/article/details/52611219
P.S. The code is commented in detail; start reading from the fit method of the CWord2Vec class.
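
Before diving into the listing, here is a minimal driver sketch showing how the class below would typically be used: pass fit an RDD whose elements are tokenized sentences (any Iterable[String]), then query the returned Word2VecModel. The input path corpus.txt, the whitespace tokenization, and the query word "spark" are placeholders of mine, not part of the original code.

import org.apache.spark.{SparkConf, SparkContext}

object CWord2VecDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("CWord2VecDemo").setMaster("local[2]"))
    // hypothetical corpus: one sentence per line, already segmented, words separated by spaces
    val sentences = sc.textFile("corpus.txt").map(_.split(" ").toSeq)
    val model = new CWord2Vec().fit(sentences)   // builds the vocabulary and Huffman tree, then trains
    // 10 nearest words by cosine similarity; transform throws if "spark" is not in the vocabulary
    model.findSynonyms("spark", 10).foreach(println)
    sc.stop()
  }
}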


import java.nio.ByteBuffer
import java.util.{Random => JavaRandom}

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}

import scala.collection.mutable
import scala.util.hashing.MurmurHash3

/**
 * Entry in vocabulary
 */
private case class VocabWord(
  var word: String,        // the word itself
  var cn: Int,             // word count
  var point: Array[Int],   // path from the root: the inner nodes passed through
  var code: Array[Int],    // Huffman code of the word
  var codeLen: Int         // number of nodes on the path to this leaf
)

class CWord2Vec extends Serializable {

  private val random = new JavaRandom()
  private var seed = new JavaRandom().nextLong()
  private var vectorSize = 100          // dimensionality of the word vectors
  private var learningRate = 0.025      // initial learning rate
  private var numPartitions = 1
  private var numIterations = 60        // number of training iterations over the corpus
  private var minCount = 5              // words occurring fewer than minCount times are dropped
  private var maxSentenceLength = 1000  // sentences are split into chunks of at most maxSentenceLength words

  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6
  private val MAX_CODE_LENGTH = 40

  /** context words from [-window, window] */
  private var window = 5

  private var trainWordsCount = 0L
  private var vocabSize = 0
  private var vocab: Array[VocabWord] = null
  private var vocabHash = mutable.HashMap.empty[String, Int]

  /* Build the vocabulary */
  private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {
    val words = dataset.flatMap(x => x)
    vocab = words.map(w => (w, 1))
      .reduceByKey(_ + _)                 // count each word
      .filter(_._2 >= minCount)           // drop words with frequency below minCount
      .map(x => VocabWord(
        x._1,
        x._2,
        new Array[Int](MAX_CODE_LENGTH),
        new Array[Int](MAX_CODE_LENGTH),
        0))
      .collect()
      .sortWith((a, b) => a.cn > b.cn)    // sort by frequency in descending order

    vocabSize = vocab.length              // number of words in the vocabulary
    require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
      "the setting of minCount, which could be large enough to remove all your words in sentences.")

    var a = 0
    while (a < vocabSize) {
      vocabHash += vocab(a).word -> a     // build a hash map (word -> index) for fast lookup
      trainWordsCount += vocab(a).cn      // total number of word occurrences in the corpus
      a += 1
    }
    // logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
  }

  /* Create Huffman Tree */
  private def createBinaryTree(): Unit = {
    val count = new Array[Long](vocabSize * 2 + 1)      // weights of all nodes in the tree
    val binary = new Array[Int](vocabSize * 2 + 1)      // Huffman code bit of each node: left 1, right 0
    val parentNode = new Array[Int](vocabSize * 2 + 1)  // parent of each node
    val code = new Array[Int](MAX_CODE_LENGTH)          // scratch buffer: Huffman code of the current leaf
    val point = new Array[Int](MAX_CODE_LENGTH)         // scratch buffer: path of the current leaf (nodes passed through)
    var a = 0
    while (a < vocabSize) {
      count(a) = vocab(a).cn     // leaves 0..vocabSize-1, weighted by word frequency
      a += 1
    }
    while (a < 2 * vocabSize) {
      count(a) = 1e9.toInt       // inner nodes, initialized to a huge weight (10^9)
      a += 1
    }
    var pos1 = vocabSize - 1
    var pos2 = vocabSize
    var min1i = 0
    var min2i = 0
    a = 0
    while (a < vocabSize - 1) {  // build the Huffman tree by repeatedly merging the two lightest nodes
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min1i = pos1
          pos1 -= 1
        } else {
          min1i = pos2
          pos2 += 1
        }
      } else {
        min1i = pos2
        pos2 += 1
      }
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min2i = pos1
          pos1 -= 1
        } else {
          min2i = pos2
          pos2 += 1
        }
      } else {
        min2i = pos2
        pos2 += 1
      }
      count(vocabSize + a) = count(min1i) + count(min2i)
      parentNode(min1i) = vocabSize + a
      parentNode(min2i) = vocabSize + a
      binary(min2i) = 1
      a += 1
    }
    // Now assign binary code to each vocabulary word
    var i = 0
    a = 0
    while (a < vocabSize) {
      var b = a
      i = 0
      while (b != vocabSize * 2 - 2) {  // vocabSize * 2 - 2 is the root node
        code(i) = binary(b)             // Huffman bit (0 or 1) of node b
        point(i) = b                    // record node b on the path
        i += 1
        b = parentNode(b)
      }
      vocab(a).codeLen = i              // number of nodes on the path to leaf a
      vocab(a).point(0) = vocabSize - 2
      b = 0
      while (b < i) {
        vocab(a).code(i - b - 1) = code(b)            // store the Huffman code
        vocab(a).point(i - b) = point(b) - vocabSize  // store the inner nodes along the path
        b += 1
      }
      a += 1
    }
  }

  // Build the sigmoid lookup table
  private def createExpTable(): Array[Float] = {
    // precompute sigmoid values at EXP_TABLE_SIZE points spanning (-MAX_EXP, MAX_EXP)
    val expTable = new Array[Float](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      expTable(i) = (tmp / (tmp + 1.0)).toFloat
      i += 1
    }
    expTable
  }

  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
    learnVocab(dataset)   // build the vocabulary
    createBinaryTree()    // build the Huffman tree

    val sc = dataset.context

    val expTable = sc.broadcast(createExpTable())
    val bcVocab = sc.broadcast(vocab)
    val bcVocabHash = sc.broadcast(vocabHash)

    val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
      // Each sentence will map to 0 or more Array[Int]
      sentenceIter.flatMap { sentence =>
        // map each word to its index in the vocabulary
        val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
        // sentences longer than maxSentenceLength are split into several chunks
        wordIndexes.grouped(maxSentenceLength).map(_.toArray)
      }
    }

    val newSentences = sentences.repartition(numPartitions).cache()
    val initRandom = new XORShiftRandom(seed)

    if (vocabSize.toLong * vectorSize >= Int.MaxValue) {
      throw new RuntimeException("vocabSize.toLong * vectorSize >= Int.MaxValue, " +
        "Int.MaxValue: " + Int.MaxValue)
    }

    // leaf nodes: word vectors, randomly initialized
    val syn0Global =
      Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
    // inner nodes: parameter vectors, initialized to zero
    val syn1Global = new Array[Float](vocabSize * vectorSize)
    var alpha = learningRate  // current learning rate

    for (k <- 1 to numIterations) {  // iterate over the whole corpus, numIterations passes in total
      val bcSyn0Global = sc.broadcast(syn0Global)
      val bcSyn1Global = sc.broadcast(syn1Global)
      // process each partition: idx is the partition index, iter iterates over its sentences
      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
        val syn0Modify = new Array[Int](vocabSize)
        val syn1Modify = new Array[Int](vocabSize)
        val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) {
          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
            var lwc = lastWordCount
            var wc = wordCount
            if (wordCount - lastWordCount > 10000) {
              lwc = wordCount
              // TODO: discount by iteration?
              alpha =
                learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
              if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
              // logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
            }
            wc += sentence.length
            var pos = 0
            while (pos < sentence.length) {
              val word = sentence(pos)  // the pos-th word of this sentence
              // draw a random shrink b in [0, window); the center word w has at most `window`
              // context words on each side, so |Context(w)| is at most 2 * window
              val b = random.nextInt(window)
              // Train Skip-gram
              var a = b
              while (a < window * 2 + 1 - b) {
                // skip-gram loop centered at pos: update the vectors of the words in Context(w)
                if (a != window) {
                  val c = pos - window + a  // c indexes one word of Context(w) around pos
                  if (c >= 0 && c < sentence.length) {
                    val lastWord = sentence(c)  // the context word at c, a leaf of the Huffman tree
                    val l1 = lastWord * vectorSize
                    // accumulates the update to be applied to lastWord's vector (syn0)
                    val neu1e = new Array[Float](vectorSize)
                    // Hierarchical softmax:
                    // walk from the root down the Huffman path of `word`, which has codeLen inner nodes
                    var d = 0
                    while (d < bcVocab.value(word).codeLen) {
                      val inner = bcVocab.value(word).point(d)  // the inner node reached at step d
                      val l2 = inner * vectorSize
                      // Propagate hidden -> output
                      var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)  // dot product of the syn0 and syn1 rows
                      if (f > -MAX_EXP && f < MAX_EXP) {
                        val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                        f = expTable.value(ind)  // look up sigmoid(f)
                        val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
                        blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)  // neu1e = g * syn1 + neu1e
                        blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)  // syn1 = g * syn0 + syn1
                        syn1Modify(inner) += 1
                      }
                      d += 1
                    }
                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)  // syn0 = 1.0f * neu1e + syn0
                    syn0Modify(lastWord) += 1
                  }
                }
                a += 1
              }
              pos += 1
            }
            (syn0, syn1, lwc, wc)
        }
        val syn0Local = model._1  // syn0: leaf-node vectors, i.e. the word vectors
        val syn1Local = model._2  // syn1: inner-node vectors, i.e. the parameter vectors
        // Only output modified vectors.
        Iterator.tabulate(vocabSize) { index =>
          if (syn0Modify(index) > 0) {
            Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
          } else {
            None
          }
        }.flatten ++ Iterator.tabulate(vocabSize) { index =>
          if (syn1Modify(index) > 0) {
            Some((index + vocabSize,
              syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
          } else {
            None
          }
        }.flatten
      }
      // combine the partial results: vectors emitted for the same word / node index are summed
      val synAgg = partial.reduceByKey { case (v1, v2) =>
        blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)  // v1 = v1 + v2
        v1
      }.collect()
      var i = 0
      while (i < synAgg.length) {
        val index = synAgg(i)._1
        if (index < vocabSize) {
          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
        } else {
          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
        }
        i += 1
      }
      bcSyn0Global.unpersist(false)
      bcSyn1Global.unpersist(false)
    }
    newSentences.unpersist()
    expTable.unpersist()
    bcVocab.unpersist()
    bcVocabHash.unpersist()

    val wordArray = vocab.map(_.word)
    new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
  }
}

class Word2VecModel(
    val wordIndex: Map[String, Int],
    val wordVectors: Array[Float]) extends Serializable {

  private val numWords = wordIndex.size
  private val vectorSize = wordVectors.length / numWords

  private val wordList: Array[String] = {
    val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
    wl.toArray
  }

  private val wordVecNorms: Array[Double] = {
    val wordVecNorms = new Array[Double](numWords)
    var i = 0
    while (i < numWords) {
      val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
      wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
      i += 1
    }
    wordVecNorms
  }

  def transform(word: String): Vector = {
    wordIndex.get(word) match {
      case Some(ind) =>
        val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
        Vectors.dense(vec.map(_.toDouble))
      case None =>
        throw new IllegalStateException(s"$word not in vocabulary")
    }
  }

  def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
    val vector = transform(word)
    findSynonyms(vector, num)
  }

  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should > 0")
    // TODO: optimize top-k
    val fVector = vector.toArray.map(_.toFloat)
    val cosineVec = Array.fill[Float](numWords)(0)
    val alpha: Float = 1
    val beta: Float = 0
    // Normalize input vector before blas.sgemv to avoid Inf value
    val vecNorm = blas.snrm2(vectorSize, fVector, 1)
    if (vecNorm != 0.0f) {
      blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1)
    }
    blas.sgemv(
      "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
    val cosVec = cosineVec.map(_.toDouble)
    var ind = 0
    while (ind < numWords) {
      val norm = wordVecNorms(ind)
      if (norm == 0.0) {
        cosVec(ind) = 0.0
      } else {
        cosVec(ind) /= norm
      }
      ind += 1
    }
    wordList.zip(cosVec)
      .toSeq
      .sortBy(-_._2)
      .take(num + 1)
      .tail
      .toArray
  }
}

private class XORShiftRandom(init: Long) extends JavaRandom(init) {
  private var seed = hashSeed(init)

  private def hashSeed(seed: Long): Long = {
    val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array()
    val lowBits = MurmurHash3.bytesHash(bytes)
    val highBits = MurmurHash3.bytesHash(bytes, lowBits)
    (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL)
  }

  // we need to just override next - this will be called by nextInt, nextDouble,
  // nextGaussian, nextLong, etc.
  override protected def next(bits: Int): Int = {
    var nextSeed = seed ^ (seed << 21)
    nextSeed ^= (nextSeed >>> 35)
    nextSeed ^= (nextSeed << 4)
    seed = nextSeed
    (nextSeed & ((1L << bits) - 1)).asInstanceOf[Int]
  }
}
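
To connect the inner training loop back to the math in the referenced write-up, the hierarchical-softmax step it performs for one (center word w, context word) pair can be summarized as below. This is only a sketch of what the blas calls compute, written in the notation of 《Word2Vec_中的數學原理詳解》 rather than in code identifiers: v_u is the syn0 row of the context word lastWord, θ_j the syn1 row of the j-th inner node on w's Huffman path, d_j the corresponding code bit, α the learning rate, and e the accumulator neu1e.

\begin{aligned}
f_j      &= \sigma\!\left(v_u^{\top}\theta_j\right)   &&\text{blas.sdot followed by the expTable lookup} \\
g_j      &= \alpha\,\bigl(1 - d_j - f_j\bigr)         &&\text{scalar multiplier} \\
e        &\leftarrow e + g_j\,\theta_j                &&\text{neu1e = g * syn1 + neu1e} \\
\theta_j &\leftarrow \theta_j + g_j\,v_u              &&\text{syn1 = g * syn0 + syn1} \\
v_u      &\leftarrow v_u + e \ \text{(after all } j\text{)} &&\text{syn0 = 1.0f * neu1e + syn0}
\end{aligned}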