
Custom Hive UDFs in Spark

1. A simple UDF

An ordinary Scala function becomes callable from SQL once it is registered with hiveContext.udf.register. The example below registers strlen, a null-safe string-length function, and calls it in a query against a Hive table:

package com.llcc.sparkSql.MyTimeSort

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

object UDFDemo1 {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("aa")
    val sc = new SparkContext(conf)
    val hiveContext = new HiveContext(sc)

    // Register a null-safe string-length function under the SQL name "strlen"
    hiveContext.udf.register("strlen", (str: String) => {
      if (str != null) str.length else 0
    })

    // The registered UDF can now be used like any built-in SQL function
    hiveContext.sql("select strlen(category) from xtwy.worker").show()
  }
}


2. Extending UserDefinedAggregateFunction

An aggregate function (UDAF) combines many rows into one value, so it extends UserDefinedAggregateFunction and defines the input schema, the intermediate buffer schema, and how that buffer is initialized, updated for each row, merged across partitions, and finally evaluated. The example below computes the average of the Hive salary column:

package com.llcc.sparkSql.MyTimeSort

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types._

object UDFDemo extends UserDefinedAggregateFunction {

  /**
   * The input type. We aggregate a single column (the Hive salary field),
   * but StructType expects a list of fields, hence the trailing Nil.
   */
  override def inputSchema: StructType = StructType(
    StructField("salary", DoubleType, true) :: Nil
  )

  /**
   * Names and types of the intermediate aggregation buffer fields.
   */
  override def bufferSchema: StructType = StructType(
    StructField("total", DoubleType, true) ::
    StructField("count", IntegerType, true) :: Nil
  )

  // Type of the final result
  override def dataType: DataType = DoubleType

  // The same input always produces the same output
  override def deterministic: Boolean = true

  /**
   * Initialize the buffer values.
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0)
  }

  // Fold one input row into the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val total = buffer.getDouble(0)
    val count = buffer.getInt(1)
    val currentSalary = input.getDouble(0)
    buffer.update(0, total + currentSalary)
    buffer.update(1, count + 1)
  }

  // Combine two partial buffers (e.g. from different partitions)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val total1 = buffer1.getDouble(0)
    val count1 = buffer1.getInt(1)
    val total2 = buffer2.getDouble(0)
    val count2 = buffer2.getInt(1)
    buffer1.update(0, total1 + total2)
    buffer1.update(1, count1 + count2)
  }

  // Produce the final result from the buffer
  override def evaluate(buffer: Row): Any = {
    val total = buffer.getDouble(0)
    val count = buffer.getInt(1)
    total / count
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("aa")
    val sc = new SparkContext(conf)
    val hiveContext = new HiveContext(sc)

    // Register the aggregate under the SQL name "salary_avg"
    hiveContext.udf.register("salary_avg", UDFDemo)
    hiveContext.sql("select salary_avg(salary) from xtwy.worker").show()
  }
}
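UserDefinedAggregateFunction was deprecated in Spark 3.0 in favor of the typed Aggregator API, registered through functions.udaf. Here is a rough sketch of the same average under that API (the names AvgBuffer, SalaryAvg, and UDAFDemoModern are mine, not from the original code):

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Aggregation buffer: running total and row count
case class AvgBuffer(total: Double, count: Long)

object SalaryAvg extends Aggregator[Double, AvgBuffer, Double] {
  def zero: AvgBuffer = AvgBuffer(0.0, 0L)
  def reduce(b: AvgBuffer, salary: Double): AvgBuffer =
    AvgBuffer(b.total + salary, b.count + 1)
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer =
    AvgBuffer(b1.total + b2.total, b1.count + b2.count)
  def finish(b: AvgBuffer): Double = b.total / b.count
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object UDAFDemoModern {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("salary-avg-demo")
      .enableHiveSupport()
      .getOrCreate()

    spark.udf.register("salary_avg", udaf(SalaryAvg, Encoders.scalaDouble))
    spark.sql("select salary_avg(salary) from xtwy.worker").show()
  }
}

The typed variant replaces positional buffer.getDouble(0)/getInt(1) access with a case class, which makes the update and merge logic harder to get wrong.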

The source data:

[screenshot: source data in xtwy.worker]

Computing the average salary with salary_avg; as the output shows, the result is correct:

[screenshot: output of the salary_avg query]
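As a quick sanity check (my addition, assuming the same hiveContext and table as above), the custom aggregate can be compared against Spark's built-in avg, which should return the same value:

// Both columns should show the same average salary
hiveContext.sql("select salary_avg(salary), avg(salary) from xtwy.worker").show()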