// SumUDAFDemo
package com.XX.udf; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.DoubleWritable; public class SumUDAFDemo extends AbstractGenericUDAFResolver { /** * 获取处理逻辑类 * @param info * @return * @throws SemanticException */ @Override public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException { //判断输入参数是否合法,参数个数,参数类型 if (info.length != 1) { throw new UDFArgumentLengthException("输入参数个数非法,一个参数"); } return new GenericEvaluate(); } //处理逻辑类 public static class GenericEvaluate extends GenericUDAFEvaluator { private PrimitiveObjectInspector input; private DoubleWritable result ; //保存最终结果 private MyAggregationBuffer myAggregationBuffer; //自定义聚合列,保存临时结果 //自定义AggregationBuffer public static class MyAggregationBuffer implements AggregationBuffer { Double sum; } @Override //指定返回类型 public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { super.init(m, parameters); result = new DoubleWritable(0); input = (PrimitiveObjectInspector) parameters[0]; // 指定返回结果类型 return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; } @Override //获得一个聚合的缓冲对象,每个map执行一次 public AggregationBuffer getNewAggregationBuffer() throws HiveException { MyAggregationBuffer myAggregationBuffer = new MyAggregationBuffer(); reset(myAggregationBuffer); // 重置聚合值 return 
myAggregationBuffer; } @Override public void reset(AggregationBuffer agg) throws HiveException { MyAggregationBuffer newAgg = (MyAggregationBuffer) agg; newAgg.sum = 0.0; } @Override // 传入参数值聚合 public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { MyAggregationBuffer myAgg = (MyAggregationBuffer) agg; double inputNum = PrimitiveObjectInspectorUtils.getDouble(parameters[0], input); myAgg.sum += inputNum; } @Override // iterate 输出中间结果 public Object terminatePartial(AggregationBuffer agg) throws HiveException { MyAggregationBuffer newAgg = (MyAggregationBuffer) agg; result.set(newAgg.sum); return result; } @Override // 合并 public void merge(AggregationBuffer agg, Object partial) throws HiveException { double inputNum = PrimitiveObjectInspectorUtils.getDouble(partial, input); MyAggregationBuffer newAgg = (MyAggregationBuffer) agg; newAgg.sum += inputNum; } @Override //输出最终结果 public Object terminate(AggregationBuffer agg) throws HiveException { MyAggregationBuffer aggregationBuffer = (MyAggregationBuffer) agg; result.set(aggregationBuffer.sum); return result; } } }