1. 程式人生 > >自定義UDAF(多對一)

自定義UDAF(多對一)

package day01

import org.apache.spark.sql.{Row, types}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
  * 自定義一個聚合方法
  * 首先要定義一個類繼承UserDefinedAggregateFunction
  * 重寫8個方法
  *
  */
class GeometricMean extends  UserDefinedAggregateFunction{
  //UDAF與DataFrame列有關的輸入樣式,StructField的名字並沒有特別要求,完全可以認為是兩個內部結構的列名站位符
  //至於UDAF具體要操作DataFrame的那個列,取決於呼叫者,但前提是資料型別必須符合事先的設定,如這裡的Double
  override def inputSchema: StructType = StructType(List(StructField("value",DoubleType)))
  //定義儲存聚合運算時產生的中間資料結果的Schema
  override def bufferSchema: StructType = StructType(List(
   //參與運算的個數
    StructField("count",LongType),
    //參與運算的乘積
    StructField("product",DoubleType)
  ))
  //標明瞭UDAF函式的返回值型別
  override def dataType: DataType = DoubleType
  //用以標記針對給定的一組輸入,UDAF是否總是生成相同的結果
  override def deterministic: Boolean = true
  //對聚合運算中間結果的初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
  //參與運算數字的個數
    buffer(0)=0L
    //參與相乘的值
    buffer(1)=1.0
  }
  //每處理一條資料都要執行update
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
   //區域性計算
    buffer(0) =buffer.getAs[Long](0) +1
    buffer(1) =buffer.getAs[Double](1) * input.getAs[Double](0)
  }
  //負責合併兩個聚合運算的buffer,再將其儲存到MutableAggregationBuffer中
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
   //全域性  累加    累乘
    buffer1(0)=buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1)=buffer1.getAs[Double](1) * buffer2.getAs[Double](1)
  }
  //完成對聚合Buffer值得運算,得到最後的結果
  override def evaluate(buffer: Row): Any = {
    math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))
  }
}
object GeometricMean{
  def main(args: Array[String]): Unit = {
   // val r =Math.pow(1*2*3*4*5*6*7*8*9,1.toDouble/9)
   val r =Math.pow(3,1.toDouble/2)
    println(r)
  }
}

 

package day01

import java.lang

import org.apache.spark.sql.{Dataset, SparkSession}


object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[*]")
      .getOrCreate()
     //df.show()列叫id
    val df: Dataset[lang.Long] = spark.range(1,10)

    val gm = new GeometricMean
      //寫sql需要註冊檢視
    // df.createTempView("v_num")
    spark.udf.register("gm",gm)
 //   spark.sql("SELECT gm(id) as gm from v_num").show()

    //不用檢視來弄,直接使用運算元
   // df.select(expr("gm(id) as GeometricMean")).show()
   // df.groupBy().agg(gm(col("id")).as("GeometricMean")).show
  }
}

SparkSQL的自定義函式
UDF 呼叫函式式輸入一行,返回一個值, 1->1 substring
UDAF 呼叫函式時輸入N行,返回一個值 N-> 1 count(*)

使用UDFs之前要先註冊
spark.udf.register("ip2Long",(ip:String)=>{
    //返回Long型別
})

spark.udf.register("gn" new UserDefineAggregateFunction(){
    //重新八個方法
})