1. 程式人生 > >自定義開發Spark ML機器學習類

自定義開發Spark ML機器學習類

初窺門徑

Spark的MLlib元件內建實現了很多常見的機器學習演算法,包括資料抽取,分類,聚類,關聯分析,協同過濾等等.
然鵝,內建的演算法並不能滿足我們所有的需求,所以我們還是經常需要自定義ML演算法.

MLlib提供的API分為兩類:
- 1.基於DataFrame的API,屬於spark.ml包.
- 2.基於RDD的API, 屬於spark.mllib包.

從Spark 2.0開始,Spark的API全面從RDD轉向DataFrame,MLlib也是如此,官網原話如下:

Announcement: DataFrame-based API is primary API

The MLlib RDD-based API is now in maintenance mode.

所以本文將介紹基於DataFrame的自定義ml類編寫方法.不涉及具體演算法,只講擴充套件ml類的方法.

略知一二

官方文件並沒有介紹如何自定義ml類,所以只有從原始碼入手,看看原始碼裡面是怎麼實現的.

找一個最簡單的內建演算法入手,這個演算法就是內建的分詞器,Tokenizer.

Tokenizer只是簡單的將文字以空白部分進行分割,只適合給英文進行分詞,所以它的實現及其簡短,原始碼如下:

package org.apache.spark.ml.feature

import
org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** * A tokenizer that converts the input string to lowercase and then splits it by white spaces. * * @see
[[RegexTokenizer]] */
@Since("1.2.0") class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { @Since("1.2.0") def this() = this(Identifiable.randomUID("tok")) override protected def createTransformFunc: String => Seq[String] = { _.toLowerCase.split("\\s") } override protected def validateInputType(inputType: DataType): Unit = { require(inputType == StringType, s"Input type must be string type but got $inputType.") } override protected def outputDataType: DataType = new ArrayType(StringType, true) @Since("1.4.1") override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) } @Since("1.6.0") object Tokenizer extends DefaultParamsReadable[Tokenizer] { @Since("1.6.0") override def load(path: String): Tokenizer = super.load(path) }

簡單分析下原始碼:
- Tokenizer繼承了UnaryTransformer類.unary是’一元’的意思,也是說這個類實現的是類似一元函式的功能,一個輸入變數,一個輸出.直接看UnaryTransformer的原始碼註釋:

/**
* :: DeveloperApi ::
* Abstract class for transformers that take one input column, apply transformation, and output the
* result as a new column.
*/

DeveloperApi表明這是一個開發級API,開發者可以用,不會有許可權問題(原始碼中有很多private[spark]的類,是不允許外部呼叫的).
註釋的大意就是:這是一個為實現transformers準備的抽象類,以一個欄位(列)為輸入,輸出一個新欄位(列).

所以實際上就是實現一個Transformer,只是這個Transformer有指定的輸入欄位和輸出欄位.

  • UnaryTransformer類中只有兩個抽象方法.
    一個是createTransformFunc,是最核心的方法,這個方法需要返回一個函式,這個函式的引數即Transformer的輸入欄位的值,返回值為Transformer的輸出欄位的值.看看Tokenizer中的實現,就明白了.

另一個是outputDataType,這個方法用來返回輸出欄位的型別.

  • validateInputType方法是用來檢查輸入欄位型別的,看需要實現.

  • Tokenizer混入了DefaultParamsWritable特質,使得自己可以被儲存.
    對應的object Tokenizer伴生物件,用來讀取已儲存的Tokenizer.

  • 值得注意的是,Transformer類是PipelineStage類的子類,所以Transformer的子類,包括我們自定義的,是可以直接用在ML Pipelines中的.這就厲害了,說明自定義的演算法類,可以無縫與內建機器學習演算法打配合,還能利用Pipeline的調優工具(model selection,Cross-Validation等).

初出茅廬

看完原始碼,基本套路已經明瞭,不如動手抄一個,不,敲一個.
依葫蘆畫瓢,實現一個正則提取的Transformer.

import util.matching.Regex

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types._

/**
  * 正則提取器
  * 將匹配指定正則表示式的全部子字串,提取到array[string]中.
  */
class RegexExtractor(override val uid: String)
  extends UnaryTransformer[String, Seq[String], RegexExtractor] {

  def this() = this(Identifiable.randomUID("RegexExtractor"))

  /**
    * 引數:正則表示式
    *
    * @group param
    */
  final val regex = new Param[Regex](this, "RegexExpr", "正則表示式")

  /** @group setParam */
  def setRegexExpr(value: String): this.type = set(regex, new Regex(value))

  override protected def outputDataType: DataType = new ArrayType(StringType, true)

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == DataTypes.StringType,
      s"Input type must be string type but got $inputType."
    )
  }

  override protected def createTransformFunc: String => Seq[String] = {
    parseContent
  }

  /**
    * 資料處理
    */
  private def parseContent(text: String): Seq[String] = {
    if (text == null || text.isEmpty) {
      return Seq.empty[String]
    }
    $(regex).findAllIn(text).toSeq
  }

}

這個類結構與Tokenizer原始碼基本差不多,多用到的Param類,是一個引數的包裝類.
作用是self-contained documentation and optionally default value.
其實就是把引數的值,文件,預設值等屬性組合成一個類,方便呼叫.

比如上面定義的regex引數,就可以用$(regex)這樣的方式直接呼叫.

另外在org.apache.spark.ml.param中有很多內建的Param類,可以直接使用.

同時org.apache.spark.ml.param.shared中有很多輔助引入引數的特質,比如HasInputCols特質,你的自定義Transformer只要混入這個特質就擁有了inputCols引數.不過目前shared中特質的作用域是private[ml],也就是說不能直接引用,而是要copy一份程式碼到自己的專案,並修改作用域才行.
關於這個作用域的問題,有人在spark的jira上提到,提議將其作為DeveloperApi開放出來,我也投了一票表示支援.後來在2017年11月終於resolved,該問題將在Spark2.3.0中解決.詳情戳我

粗懂皮毛

自定義的類寫好了,該怎麼用呢? 當然是跟內建的一樣啦.上栗子:

val regex="nidezhengze"

val tranTitle = new RegexExtractor()
    .setInputCol("title")
    .setOutputCol("title_price_texts")
    .setRegexExpr(regex)

val pipeline = new Pipeline().setStages(Array(
    tranTitle
))

val matched = pipeline.fit(data).transform(data)

打完收功

到這裡,開發簡單Transform的套路已經清楚了,不過這裡實現的功能比較類似於一個UDF,只能對dataset的一個欄位進行處理,而且是逐行處理,並不能根據多行資料進行處理,實現視窗函式類似的功能,而且也沒有涉及模型的輸出.如果要開發更復雜的演算法,甚至進行模型訓練,就需要更深入的瞭解MLlib了,閱讀原始碼是個好途徑.

下回再說.