1. 程式人生 > >記錄一個自己寫的hiveUDAF

記錄一個自己寫的hiveUDAF

這是一個我自己參考網上資料寫的UDAF,期間踩了各種bug,終於讓我跑通了。作用是:輸入表的欄位名稱,輸出該欄位的統計總行數、為空行數,以及去重後按出現次數排序的top 10樣例資料。方法說明都有標註,以下是完整程式碼:

package com.zh.hive;



import net.sf.json.JSONObject;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.eclipse.jetty.util.ajax.JSON;


import java.util.*;


public class QcUdf extends AbstractGenericUDAFResolver {

    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameter) throws SemanticException {
        ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameter[0]);
        PrimitiveObjectInspector inputOI = (PrimitiveObjectInspector) oi;
        return new GenericUDAFHistogramNumericEvaluator();
    }

    public static class GenericUDAFHistogramNumericEvaluator extends GenericUDAFEvaluator {
        // UDAF logic goes here!
        PrimitiveObjectInspector inputOI;
        ObjectInspector outputOI;
        PrimitiveObjectInspector integerOI;
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            super.init(m, parameters);

            //map階段讀取sql列,輸入為String基礎資料格式
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
                //其餘階段,輸入為String基礎資料格式
                integerOI = (PrimitiveObjectInspector) parameters[0];
            }
            // 指定各個階段輸出資料格式都為String型別
            outputOI = ObjectInspectorFactory.getReflectionObjectInspector(String.class,
                    ObjectInspectorFactory.ObjectInspectorOptions.JAVA);
            return outputOI;
        }
        /**
         * 儲存當前字元總數的類
         */
        static class LetterSumAgg implements AggregationBuffer {
            int sum = 0;
            int count1 = 0;
            Map<String,Integer> map =new HashMap<String,Integer>();
            void put(String str){//放進去一個欄位值
                str = str.trim();
                if(str!=null||str!=""){
                    if (map.get(str)!=null){
                        int org=map.get(str)+1;
                        map.put(str,org);
                    }else{
                        map.put(str,1);
                    }
                }else{
                    map.put("null_key",1);
                }
            }
            void put(Map<String,Integer> target_map) {//合併兩個map
                Iterator<Map.Entry<String,Integer>> target = target_map.entrySet().iterator();
                while (target.hasNext()) {
                    Map.Entry<String,Integer> next = target.next();
                    String key = next.getKey();
                    if(map.get(key)!=null){
                        map.put(key,map.get(key)+target_map.get(key));
                    } else{
                        map.put(key,target_map.get(key));
                    }
                }
            }
            void add(int num,int count){
                sum += num;
                count1 += count;
            }
            String getTop10(){
                List <String> list = new ArrayList<String>();
                String str ="";
                for(Map.Entry entry:map.entrySet()){
                    list.add(entry.getValue().toString());
                }
                Collections.sort(list);
                if(list.size()>10){
                    int count = 0;
                    for (int i=list.size()-1;i>list.size()-11;i--){
                        if (count<=10) {
                            for (Map.Entry entry : map.entrySet()) {
                                if (list.get(i).equals(entry.getValue().toString())) {
                                    count++;
                                    if (count <= 10) {
                                        str += entry.getKey().toString().replace("\n","").replace("\t","").replace("|","") + "@" + entry.getValue().toString() + ",";
                                        map.put(entry.getKey().toString(), 0);
                                    }else{ break;}
                                }
                            }
                        }
                    }
                }else{
                    for(Map.Entry entry:map.entrySet()){
                        str += entry.getKey().toString().replace("null_key","null")+"@"+entry.getValue().toString()+",";
                    }
                }
                return str;
            }
        }
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            LetterSumAgg result = new LetterSumAgg();
            return result;
        }

        public void reset(AggregationBuffer aggregationBuffer) throws HiveException {
            LetterSumAgg myagg = new LetterSumAgg();
        }
        private boolean warned = false;

        public void iterate(AggregationBuffer aggregationBuffer, Object[] objects) throws HiveException {//邏輯存放地
            assert (objects.length == 1);
            LetterSumAgg myagg = (LetterSumAgg) aggregationBuffer;
            if(myagg==null){
                myagg = new LetterSumAgg();
            }
            if (objects[0] != null&&objects[0].toString().toLowerCase().trim() !="null"&&objects[0].toString().trim() !="") {
                myagg.put(objects[0].toString());
                myagg.add(1,0);//統計總行數
            }else{
                myagg.put("null_key");
                myagg.add(1,1);//統計總行數
            }
        }

        public String terminatePartial(AggregationBuffer aggregationBuffer) throws HiveException {//單機合併
            LetterSumAgg agg = new LetterSumAgg();
            LetterSumAgg myagg = (LetterSumAgg)aggregationBuffer;
            if(myagg==null){
                myagg = new LetterSumAgg();
            }
            agg.sum += myagg.sum;
            agg.count1 += myagg.count1;
            agg.put(myagg.map);
            JSONObject jsonObject=null;
            if (agg.map!=null){
                 jsonObject = JSONObject.fromObject(agg.map);
            }
          //  JSONObject jsonObject = JSONObject.fromObject(agg.map);
            return agg.sum+"#@"+agg.count1+"#@"+jsonObject;
        }

        public void merge(AggregationBuffer aggregationBuffer, Object o) throws HiveException {//叢集合併
            if ( o!= null) {
                LetterSumAgg myagg1 = (LetterSumAgg) aggregationBuffer;
                String agg = (String) integerOI.getPrimitiveJavaObject(o);
               String result[] = agg.split("#@");
               if (result[2]!=null) {
                   Map maps = (Map) JSON.parse(result[2]);
                   myagg1.put(maps);
               }
                myagg1.add(Integer.parseInt(result[0]),Integer.parseInt(result[1]));
            }
        }

        public Object terminate(AggregationBuffer aggregationBuffer) throws HiveException {//複製最終結果
            LetterSumAgg myagg = (LetterSumAgg) aggregationBuffer;
            return myagg.sum+"|"+myagg.count1+"|"+myagg.getTop10();
        }

    }

}

各位朋友使用請直接copy即可。附上maven依賴

<dependencies>
        <dependency>
            <groupId>org.apache.hive</groupId>
            <artifactId>hive-jdbc</artifactId>
            <version>2.1.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hive</groupId>
            <artifactId>hive-exec</artifactId>
            <version>2.1.1</version>
        </dependency>
</dependencies>

大功告成,測試結果樣例如下:

38386|0|[3522963,  3383561,  3517824,  3505051,  3037673,  3523778,  3300084,  3483628,  3525325,  3514324]

執行程式碼如下:

use databases_name;
add jar /home/zhangheng/hive.jar;
create temporary function tj as 'com.zh.hive.QcUdf';
select tj(c1) ,tj(c2),tj(c3) from table;