
Spark MLlib Source Code Analysis: TF-IDF in Detail

The code below is annotated based on the Spark MLlib (version 1.6) source.
1. HashingTF stores tokens in a hash table and computes each token's frequency (TF), producing a HashMap in which the key is a token's index and the value is that token's count. When constructing a HashingTF you must set numFeatures, which is in effect the size of the hash table; if numFeatures is too small, distinct tokens will collide onto the same slot, so avoid small values. In practice a setting between 300,000 and 500,000 works well.
2. IDF counts how many documents each token appears in and takes the logarithm of the inverse ratio. When constructing an IDF you can set minDocFreq, which filters out tokens that appear in fewer than minDocFreq documents.
3. IDFModel mainly computes TF * IDF. It also holds the IDF vector, which can be persisted (i.e., saving the model); at test time you then only need the TF of each token within a test document to compute its TF-IDF. A minimal usage sketch of all three steps follows this list.
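
To see how these three pieces compose before reading the source, here is a minimal end-to-end sketch against the standard MLlib 1.6 API. The SparkContext setup, the placeholder input path data/corpus.txt, and the parameter choices (1 << 19 features, minDocFreq = 2) are illustrative assumptions, not part of the original analysis.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

object TfIdfPipeline {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("tfidf").setMaster("local[*]"))

    // One document per line; tokens are assumed to be whitespace-separated.
    // "data/corpus.txt" is a placeholder path.
    val documents: RDD[Seq[String]] = sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)

    // 1 << 19 = 524,288 buckets, in the 300k-500k range suggested above.
    val hashingTF = new HashingTF(1 << 19)
    val tf: RDD[Vector] = hashingTF.transform(documents)
    tf.cache() // fit() and transform() each make a pass over the TF vectors

    // Ignore tokens that appear in fewer than 2 documents.
    val idfModel = new IDF(minDocFreq = 2).fit(tf)
    val tfidf: RDD[Vector] = idfModel.transform(tf)

    tfidf.take(3).foreach(println)
    sc.stop()
  }
}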

package org.apache.spark.mllib.feature

import scala.collection.mutable

import breeze.linalg.{DenseVector => BDV}

import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD

class HashingTF(val numFeatures: Int) extends Serializable {
  def this() = this(1 << 20)

  def nonNegativeMod(x: Int, mod: Int): Int = { // map a raw hash code into [0, numFeatures), the hash table capacity
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }

  def indexOf(term: Any): Int = nonNegativeMod(term.##, numFeatures) // derive a token's index from its hash code
  def transform(document: Iterable[_]): Vector = { // one hash table per document, holding that document's term frequencies
    val termFrequencies = mutable.HashMap.empty[Int, Double]
    document.foreach { term =>
      val i = indexOf(term)
      // getOrElse(i, 0.0) returns the value stored at i if present, otherwise 0.0
      termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0) // note the +1 increment per occurrence
    }
    Vectors.sparse(numFeatures, termFrequencies.toSeq)
  }

  def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = {
    dataset.map(this.transform)
  }
}

class IDF(val minDocFreq: Int) {
  def this() = this(0) // minDocFreq defaults to 0; it filters out tokens that appear in too few documents

  def fit(dataset: RDD[Vector]): IDFModel = {
    val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator(minDocFreq = minDocFreq))(
      seqOp = (df, v) => df.add(v),
      combOp = (df1, df2) => df1.merge(df2)
    ).idf()
    new IDFModel(idf)
  }
}

private object IDF {

  /** Document frequency aggregator. */
  class DocumentFrequencyAggregator(val minDocFreq: Int) extends Serializable {

    /** number of documents */
    private var m = 0L
    /** document frequency vector */
    private var df: BDV[Long] = _

    def this() = this(0)

    private def isEmpty: Boolean = m == 0L

    def add(doc: Vector): this.type = { // add -> accumulate, within a partition, how many documents each token appears in
      if (isEmpty) {
        df = BDV.zeros(doc.size)
      }
      doc match {
        case SparseVector(size, indices, values) =>
          val nnz = indices.size
          var k = 0
          while (k < nnz) {
            if (values(k) > 0) { // the token at indices(k) occurs in this document
              df(indices(k)) += 1L // count one more document containing the token indices(k)
            }
            k += 1
          }
        case DenseVector(values) =>
          val n = values.size
          var j = 0
          while (j < n) {
            if (values(j) > 0.0) { // same logic as above; Spark simply has both DenseVector and SparseVector representations
              df(j) += 1L
            }
            j += 1
          }
        case other =>
          throw new UnsupportedOperationException(
            s"Only sparse and dense vectors are supported but got ${other.getClass}.")
      }
      m += 1L
      this
    }

    /** Merges another aggregator (combining the per-partition results). */
    def merge(other: DocumentFrequencyAggregator): this.type = {
      if (!other.isEmpty) {
        m += other.m
        if (df == null) {
          df = other.df.copy
        } else {
          df += other.df
        }
      }
      this
    }

    /** Returns the current IDF vector. */
    def idf(): Vector = {
      if (isEmpty) {
        throw new IllegalStateException("Haven't seen any document yet.")
      }
      val n = df.length
      val inv = new Array[Double](n)
      var j = 0
      while (j < n) {
        if (df(j) >= minDocFreq) {
          inv(j) = math.log((m + 1.0) / (df(j) + 1.0)) // smoothed IDF: log((m + 1) / (df(j) + 1))
        }
        j += 1
      }
      Vectors.dense(inv)
    }
  }
}

class IDFModel(val idf: Vector) extends Serializable { // idf holds the trained IDF vector

  def transform(dataset: RDD[Vector]): RDD[Vector] = { // dataset holds the TF vectors
    val bcIdf = dataset.context.broadcast(idf)
    dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(bcIdf.value, v)))
  }

  def transform(v: Vector): Vector = IDFModel.transform(idf, v)
}

private object IDFModel {

  def transform(idf: Vector, v: Vector): Vector = { // elementwise idf * v, where v is a TF vector
    val n = v.size
    v match {
      case SparseVector(size, indices, values) =>
        val nnz = indices.size
        val newValues = new Array[Double](nnz)
        var k = 0
        while (k < nnz) {
          newValues(k) = values(k) * idf(indices(k)) // TF * IDF, SparseVector case
          k += 1
        }
        Vectors.sparse(n, indices, newValues)
      case DenseVector(values) =>
        val newValues = new Array[Double](n)
        var j = 0
        while (j < n) {
          newValues(j) = values(j) * idf(j) // TF * IDF, DenseVector case
          j += 1
        }
        Vectors.dense(newValues)
      case other =>
        throw new UnsupportedOperationException(
          s"Only sparse and dense vectors are supported but got ${other.getClass}.")
    }
  }
}
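
As a quick sanity check on the idf() method above, the following self-contained snippet applies the same smoothed formula log((m + 1) / (df(j) + 1)) to a tiny made-up corpus of m = 4 documents; the numbers are hypothetical and only illustrate how document frequency maps to weight.

object IdfFormulaCheck {
  def main(args: Array[String]): Unit = {
    val m = 4L                     // total number of documents (made up)
    val df = Array(4L, 2L, 1L, 0L) // df(j): documents in which term j appears

    val idf = df.map(d => math.log((m + 1.0) / (d + 1.0)))
    // df = 4 (every document) -> log(5/5) = 0.000  (term carries no signal)
    // df = 2                  -> log(5/3) ~ 0.511
    // df = 1                  -> log(5/2) ~ 0.916
    // df = 0 (never seen)     -> log(5/1) ~ 1.609  (the +1 smoothing avoids division by zero)
    idf.zipWithIndex.foreach { case (w, j) => println(f"term $j: idf = $w%.3f") }
  }
}

The pattern matches the intuition behind TF-IDF: tokens occurring in every document get weight 0 and drop out of the product computed by IDFModel.transform, while rare tokens are weighted up.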