1. 程式人生 > >SparkSQL 如何自定義函式

SparkSQL 如何自定義函式

 

1. SparkSql如何自定義函式

2. 示例:Average

3. 型別安全的自定義函式

1. SparkSql如何自定義函式?

  spark中我們定義一個函式,需要繼承 UserDefinedAggregateFunction這個抽象類,實現這個抽象類中所定義的方法,這是一個模板設計模式? 我只要實現抽象類的中方法,具體的所有的計算步驟由內部完成。而我們可以看一下UserDefinedAggregateFunction這個抽象類。

package org.apache.spark.sql.expressions
@org.apache.spark.annotation.InterfaceStability.Stable
abstract class UserDefinedAggregateFunction() extends scala.AnyRef with scala.Serializable { def inputSchema : org.apache.spark.sql.types.StructType def bufferSchema : org.apache.spark.sql.types.StructType def dataType : org.apache.spark.sql.types.DataType def deterministic : scala.Boolean def initialize(buffer : org.apache.spark.sql.expressions.MutableAggregationBuffer) : scala.Unit def update(buffer : org.apache.spark.sql.expressions.MutableAggregationBuffer, input : org.apache.spark.sql.Row) : scala.Unit def merge(buffer1 : org.apache.spark.sql.expressions.MutableAggregationBuffer, buffer2 : org.apache.spark.sql.Row) : scala.Unit def evaluate(buffer : org.apache.spark.sql.Row) : scala.Any @scala.annotation.varargs def apply(exprs : org.apache.spark.sql.Column*) : org.apache.spark.sql.Column = { /* compiled code */ } @scala.annotation.varargs def distinct(exprs : org.apache.spark.sql.Column*) : org.apache.spark.sql.Column = { /* compiled code */ } }

  也就是說對於這幾個函式,我們只要依次實現他們的功能,其餘的交給spark就可以了。

  

2. 自定義Average函式

  首先新建一個Object類MyAvage類,繼承UserDefinedAggregateFunction。下面對每一個函式的實現進行解釋。

  def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)

  這個規定了輸入資料的資料結構

 

def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }

  這個規定了快取區的資料結構

 

  def dataType: DataType = DoubleType

  這個規定了返回值的資料型別

 

def deterministic: Boolean = true
def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }  

進行初始化,這裡要說明一下,官網中提到:

// Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
  // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
  // the opportunity to update its values. Note that arrays and maps inside the buffer are still
  // immutable.

這裡翻譯一下:

我們為我們的緩衝區設定初始值,我們不僅可以設定數字,還可以使用index getBoolen等去改變他的值,但是我們需要知道的是,在這個緩衝區中,陣列和map依然是不可變的。

其實最後一句我也是不太明白,等我以後如果能研究並理解這句話,再回來補充吧。

 

def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }

  這個是重要的update函式,對於平均值,我們可以不斷迭代輸入的值進行累加。buffer(0)統計總和,buffer(1)統計長度。

 

def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  在做完update後spark 需要將結果進行merge到我們的區域,因此有一個merge 進行覆蓋buffer

 

  def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)

  這是將最終的結果進行計算。

 

在寫完這個類以後我們在我們的sparksession裡面進行編寫測試案例。

spark.sparkContext.textFile("file:///Users/4pa/Desktop/people.txt")
      .map(_.split(","))
      .map(agg=>Person(agg(0),agg(1).trim.toInt))
      .toDF().createOrReplaceTempView("people")
spark.udf.register("myAverage",Myaverage)
val udfRes = spark.sql("select name,myAverage(age) as avgAge from people group by name")
udfRes.show()

  

3. 型別安全的自定義函式

從上面我們可以看出來,這種自定義函式不是型別安全的,因此能否實現一個安全的自定義函式呢?

個人覺得最好的例子還是官網給的例子,具體的解釋都已經給了出來,思路其實和上面是一樣的,只不過定義了兩個caseclass,用於型別的驗證。

case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {
  // 初始化
  def zero: Average = Average(0L, 0L)
  // 這個其實有點map-reduce的意思,只不過是對一個類的reduce,第一個值是和,第二個是總數
  def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }
  // 實現緩衝區的一個覆蓋
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // 計算最終數值
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  // Specifies the Encoder for the intermediate value type
  def bufferEncoder: Encoder[Average] = Encoders.product
  // 指定返回型別
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}