1. 程式人生 > >大資料學習之路90-sparkSQL自定義聚合函式UDAF

大資料學習之路90-sparkSQL自定義聚合函式UDAF

什麼是UDAF?UDAF(User-Defined Aggregate Function,使用者自定義聚合函式)就是輸入N行、輸出一個結果的函式,屬於聚合類操作。

接下來我們就寫一個求幾何平均數的自定義聚合函式的例子(n個數的幾何平均數,就是它們乘積的n次方根)。

我們從開頭寫起,先來看看需要進行計算的數如何產生:

package com.test.SparkSQL

import java.lang

import org.apache.spark.sql.{Dataset, SparkSession}

/** Demo entry point that generates the input numbers and prints them. */
object UDAFDemo {
  /**
   * Builds a local SparkSession, creates a Dataset of consecutive numbers
   * and shows it on the console.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[*]")
      .getOrCreate()
    try {
      // range(1, 10) excludes the upper bound, so the "id" column holds 1..9.
      val ds: Dataset[lang.Long] = spark.range(1, 10)
      ds.show()
    } finally {
      // Always release the local Spark context, even if the job above fails.
      spark.stop()
    }
  }
}

生成結果(只有一個 id 欄位,內容為 1 到 9 共 9 筆資料,不包含 10):

接下來我們使用自定義聚合函式計算幾何平均數:

package com.test.SparkSQL

import java.lang

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, Row, SparkSession, types}

/** Demo entry point that runs the geometric-mean UDAF through Spark SQL. */
object UDAFDemo {
  /**
   * Builds a local SparkSession, exposes the numbers 1..9 as the view
   * `v_num`, registers the UDAF under the SQL name `gm` and prints the
   * aggregated result.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[*]")
      .getOrCreate()
    try {
      // range(1, 10) excludes the upper bound, so the "id" column holds 1..9.
      val ds: Dataset[lang.Long] = spark.range(1, 10)
      // getOrCreate() may hand back an existing session in which "v_num" is
      // already defined; createTempView would throw in that case.
      ds.createOrReplaceTempView("v_num")
      val gm = new GeometriMean
      spark.udf.register("gm", gm)
      spark.sql("select gm(id) as gm from v_num").show()
    } finally {
      // Always release the local Spark context, even if the query fails.
      spark.stop()
    }
  }
}

/**
 * User-defined aggregate function that computes the geometric mean of a
 * numeric column: the n-th root of the product of the n non-null inputs.
 *
 * NULL inputs are skipped, and an empty group (or one holding only NULLs)
 * evaluates to SQL NULL.
 *
 * The running product is kept as a plain Double, so very long or very large
 * inputs can overflow to Infinity.
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated since Spark 3.0
 * in favour of `Aggregator` registered through `functions.udaf`; consider
 * migrating if the project targets Spark 3.x.
 */
class GeometriMean extends UserDefinedAggregateFunction{
  /** Type of the input: a single Double column. */
  override def inputSchema: StructType = StructType(List(StructField("value",DoubleType)))

  /** Type of the intermediate state: number of values seen and their running product. */
  override def bufferSchema: StructType = StructType(
    List(
      StructField("count",LongType),
      StructField("product",DoubleType)
    )
  )

  /** Type of the value returned by the UDAF. */
  override def dataType: DataType = DoubleType

  /** The same set of inputs always produces the same result. */
  override def deterministic: Boolean = true

  /** Resets the buffer to the neutral state: zero values seen, product 1.0. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 1.0
  }

  /**
   * Folds one input row into the buffer (partition-local aggregation).
   *
   * @param buffer running (count, product) state
   * @param input  row whose first field is the value to aggregate
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // A NULL cell would unbox to 0.0 and silently zero the whole product,
    // so skip it the way built-in SQL aggregates do.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + 1L
      buffer(1) = buffer.getDouble(1) * input.getDouble(0)
    }
  }

  /**
   * Merges a second partial state into the first (global aggregation).
   *
   * @param buffer1 state that receives the merged result
   * @param buffer2 partial state to merge in
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getDouble(1) * buffer2.getDouble(1)
  }

  /**
   * Produces the final result from the fully merged buffer.
   *
   * @param buffer final (count, product) state
   * @return the count-th root of the product, or null (SQL NULL) when no
   *         non-null value was aggregated
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(0)
    // With count == 0 the exponent would be 1/0 = Infinity and
    // pow(1.0, Infinity) is NaN; SQL NULL is the conventional answer here.
    if (count == 0L) null
    else math.pow(buffer.getDouble(1), 1.0 / count)
  }
}

執行結果(1 到 9 的幾何平均數,即 362880 的 9 次方根,約為 4.147):