SumUDAFDemo

package com.XX.udf;

import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;

public class SumUDAFDemo extends AbstractGenericUDAFResolver {


    /**
     * 获取处理逻辑类
     * @param info
     * @return
     * @throws SemanticException
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
        //判断输入参数是否合法,参数个数,参数类型
        if (info.length != 1) {
            throw new UDFArgumentLengthException("输入参数个数非法,一个参数");
        }

        return new GenericEvaluate();
    }


    //处理逻辑类
    public static class GenericEvaluate extends GenericUDAFEvaluator {
        private PrimitiveObjectInspector input;
        private DoubleWritable result ;                   //保存最终结果
        private MyAggregationBuffer myAggregationBuffer;  //自定义聚合列,保存临时结果

        //自定义AggregationBuffer
        public static class MyAggregationBuffer implements AggregationBuffer {
            Double sum;
        }

        @Override  //指定返回类型
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            result = new DoubleWritable(0);
            input = (PrimitiveObjectInspector) parameters[0];
            // 指定返回结果类型
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        @Override   //获得一个聚合的缓冲对象,每个map执行一次
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            MyAggregationBuffer myAggregationBuffer = new MyAggregationBuffer();
            reset(myAggregationBuffer);  // 重置聚合值
            return myAggregationBuffer;
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            MyAggregationBuffer newAgg = (MyAggregationBuffer) agg;
            newAgg.sum = 0.0;
        }

        @Override  // 传入参数值聚合
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            MyAggregationBuffer myAgg = (MyAggregationBuffer) agg;
            double inputNum = PrimitiveObjectInspectorUtils.getDouble(parameters[0], input);
            myAgg.sum += inputNum;
        }

        @Override  // iterate 输出中间结果
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            MyAggregationBuffer newAgg = (MyAggregationBuffer) agg;
            result.set(newAgg.sum);
            return result;
        }

        @Override  // 合并
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            double inputNum = PrimitiveObjectInspectorUtils.getDouble(partial, input);
            MyAggregationBuffer newAgg = (MyAggregationBuffer) agg;
            newAgg.sum += inputNum;
        }

        @Override  //输出最终结果
        public Object terminate(AggregationBuffer agg) throws HiveException {
            MyAggregationBuffer aggregationBuffer = (MyAggregationBuffer) agg;
            result.set(aggregationBuffer.sum);
            return result;
        }
    }
}

 

posted @ 2019-05-17 09:53  问题不大1  阅读(431)  评论(0)  编辑  收藏  举报