1. 程式人生 > >14.Spark SQL:UDAF自定義聚合函式實戰

14.Spark SQL:UDAF自定義聚合函式實戰

UDAF自定義函式實戰

UDAFUser Defined Aggregate Function。使用者自定義聚合函式。是Spark 1.5.x引入的最新特性。

UDF,其實更多的是針對單行輸入,返回一個輸出

這裡的UDAF,則可以針對多行輸入,進行聚合計算,返回一個輸出,功能更加強大

package cn.spark.study.sql
 
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.IntegerType
 
/**
 * @author Administrator
 */
class StringCount extends UserDefinedAggregateFunction {  
  
  // inputSchema,指的是,輸入資料的型別
  def inputSchema: StructType = {
    StructType(Array(StructField("str", StringType, true)))   
  }
  
  // bufferSchema,指的是,中間進行聚合時,所處理的資料的型別
  def bufferSchema: StructType = {
    StructType(Array(StructField("count", IntegerType, true)))   
  }
  
  // dataType,指的是,函式返回值的型別
  def dataType: DataType = {
    IntegerType
  }
  
  def deterministic: Boolean = {
    true
  }
 
  // 為每個分組的資料執行初始化操作
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }
  
  // 指的是,每個分組,有新的值進來的時候,如何進行分組對應的聚合值的計算
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }
  
  // 由於Spark是分散式的,所以一個分組的資料,可能會在不同的節點上進行區域性聚合,就是update
  // 但是,最後一個分組,在各個節點上的聚合值,要進行merge,也就是合併
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)  
  }
  
  // 最後,指的是,一個分組的聚合值,如何通過中間的快取聚合值,最後返回一個最終的聚合值
  def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)    
  }
  
}
package cn.spark.study.sql
 
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType
 
/**
 * UDAF:User Defined Aggregate Function。使用者自定義聚合函式。是Spark 1.5.x引入的最新特性。
 * UDF,其實更多的是針對單行輸入,返回一個輸出
 * 這裡的UDAF,則可以針對多行輸入,進行聚合計算,返回一個輸出,功能更加強大
 *
 * @author Administrator
 */
 
object UDAF {
  
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
        .setMaster("local")
        .setAppName("UDAF")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
  
    // 構造模擬資料
    val names = Array("Leo", "Marry", "Jack", "Tom", "Tom", "Tom", "Leo")  
    val namesRDD = sc.parallelize(names, 5)
    val namesRowRDD = namesRDD.map { name => Row(name) }
    val structType = StructType(Array(StructField("name", StringType, true)))  
    val namesDF = sqlContext.createDataFrame(namesRowRDD, structType)
    
    // 註冊一張names表
    namesDF.registerTempTable("names")  
    
    // 定義和註冊自定義函式
    // 定義函式:自己寫匿名函式
    // 註冊函式:SQLContext.udf.register()
    sqlContext.udf.register("strCount", new StringCount)
    
    // 使用自定義函式
    sqlContext.sql("select name,strCount(name) from names group by name")  
        .collect()
        .foreach(println)  
  }
}

本地執行結果: