1. 程式人生 > >SparkSQL自定義函式(實現幾何平均數)

SparkSQL自定義函式(實現幾何平均數)

SparkSQL-自定義聚合函式 (實現幾何平均數)


->建立SparkSessionparkSession

->建立自定義函式
    -1、繼承UserDefinedAggregateFunction
    -2、重寫下面的方法    
        inputSchema        -輸入資料的型別
        bufferSchema       -產生中間結果的資料型別
        dataType           -最終返回的結果型別
        deterministic      -確保一致性
        initialize         -指定初始值
        update             -每有一條資料參與運算就更新一下中間結果
        merge              -全域性聚合
        evaluate           -計算最終結果
        
    !:StructField        -哪些列,啥型別
->例項化自定義函式,並註冊自定義函式(spark.udf.register)

 

程式碼:

object Geometric {
  def main(args: Array[String]): Unit = {
    //建立sparkSession
    val sparkSession: SparkSession = SparkSession.builder().appName("Geometric").master("local[*]").getOrCreate()
    //造資料 1~10
    val range: Dataset[lang.Long] = sparkSession.range(1, 11)

    //例項化Geom類
    val geomean = new Geom
    //註冊檢視
    range.createTempView("v_range")
    //註冊自定義函式
    sparkSession.udf.register("ge", geomean)
    //執行sparkSql語句
    val res: DataFrame = sparkSession.sql("select ge(id) result from v_range")

    res.show()

    sparkSession.stop()
  }
}

class Geom extends UserDefinedAggregateFunction {
  //輸入型別
  override def inputSchema: StructType = StructType(List(StructField("value", DoubleType)))

  //中間資料
  override def bufferSchema: StructType = StructType(List(
    StructField("product", DoubleType),
    StructField("counts", LongType)
  ))

  //最終返回結果型別
  override def dataType: DataType = DoubleType

  //確保一致性
  override def deterministic: Boolean = true

  //指定初始值
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 1.0
    buffer(1) = 0L
  }

  //每有一條資料參與運算就更新一下中間結果(update相當於在每一個分割槽中的運算)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getDouble(0) * input.getDouble(0)
    buffer(1) = buffer.getLong(1) + 1L
  }

  //全域性聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) * buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)

  }

  override def evaluate(buffer: Row): Double = {

    math.pow(buffer.getDouble(0), 1.toDouble / buffer.getLong(1))

  }
}