14.Spark SQL:UDAF自定義聚合函式實戰
阿新 • • 發佈:2019-01-06
UDAF自定義函式實戰
UDAF:User Defined Aggregate Function。使用者自定義聚合函式。是Spark 1.5.x引入的最新特性。
UDF,其實更多的是針對單行輸入,返回一個輸出
這裡的UDAF,則可以針對多行輸入,進行聚合計算,返回一個輸出,功能更加強大
package cn.spark.study.sql
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.IntegerType
/**
* @author Administrator
*/
class StringCount extends UserDefinedAggregateFunction {
// inputSchema,指的是,輸入資料的型別
def inputSchema: StructType = {
StructType(Array(StructField("str", StringType, true)))
}
// bufferSchema,指的是,中間進行聚合時,所處理的資料的型別
def bufferSchema: StructType = {
StructType(Array(StructField("count", IntegerType, true)))
}
// dataType,指的是,函式返回值的型別
def dataType: DataType = {
IntegerType
}
def deterministic: Boolean = {
true
}
// 為每個分組的資料執行初始化操作
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0
}
// 指的是,每個分組,有新的值進來的時候,如何進行分組對應的聚合值的計算
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getAs[Int](0) + 1
}
// 由於Spark是分散式的,所以一個分組的資料,可能會在不同的節點上進行區域性聚合,就是update
// 但是,最後一個分組,在各個節點上的聚合值,要進行merge,也就是合併
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
}
// 最後,指的是,一個分組的聚合值,如何通過中間的快取聚合值,最後返回一個最終的聚合值
def evaluate(buffer: Row): Any = {
buffer.getAs[Int](0)
}
}
package cn.spark.study.sql
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType
/**
* UDAF:User Defined Aggregate Function。使用者自定義聚合函式。是Spark 1.5.x引入的最新特性。
* UDF,其實更多的是針對單行輸入,返回一個輸出
* 這裡的UDAF,則可以針對多行輸入,進行聚合計算,返回一個輸出,功能更加強大
*
* @author Administrator
*/
object UDAF {
def main(args: Array[String]): Unit = {
val conf = new SparkConf()
.setMaster("local")
.setAppName("UDAF")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
// 構造模擬資料
val names = Array("Leo", "Marry", "Jack", "Tom", "Tom", "Tom", "Leo")
val namesRDD = sc.parallelize(names, 5)
val namesRowRDD = namesRDD.map { name => Row(name) }
val structType = StructType(Array(StructField("name", StringType, true)))
val namesDF = sqlContext.createDataFrame(namesRowRDD, structType)
// 註冊一張names表
namesDF.registerTempTable("names")
// 定義和註冊自定義函式
// 定義函式:自己寫匿名函式
// 註冊函式:SQLContext.udf.register()
sqlContext.udf.register("strCount", new StringCount)
// 使用自定義函式
sqlContext.sql("select name,strCount(name) from names group by name")
.collect()
.foreach(println)
}
}
本地執行結果: