import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryStruct;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
/**
 * Hive generic UDAF that averages a single BIGINT argument and returns the result as a
 * string formatted with the {@link DecimalFormat} pattern {@code ###,###.00}.
 *
 * <p>The partial aggregate exchanged between stages is a struct
 * {@code <sum: double, count: bigint>}; the final result is a {@link Text}.
 */
public class GenericUDAFAveragePlus extends AbstractGenericUDAFResolver {

    /**
     * Validates the argument types of the call and creates the evaluator.
     *
     * @param info type information of the arguments passed to the function
     * @return a new evaluator instance
     * @throws SemanticException (as {@link UDFArgumentException}) when the function is not
     *         called with exactly one argument, or that argument is not a primitive of
     *         category LONG
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
        if (info == null || info.length != 1) {
            throw new UDFArgumentException("该函数需要接收参数!并且只能传递一个参数!");
        }
        if (info[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            // The original message contained accidentally duplicated words; fixed here.
            throw new UDFArgumentException("该函数只能接收简单类型的参数!");
        }
        PrimitiveTypeInfo pti = (PrimitiveTypeInfo) info[0];
        if (pti.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.LONG) {
            throw new UDFArgumentException("该函数只能接收Long类型的参数");
        }
        return new MyGenericUDAFEvaluator();
    }

    /**
     * Evaluator that keeps a running sum and row count per aggregation buffer.
     *
     * <p>All per-group state lives in the {@link MyAggregationBuffer}; the evaluator itself
     * only holds the input inspector and the reusable output holders.
     */
    private static class MyGenericUDAFEvaluator extends GenericUDAFEvaluator {

        /** Per-group state: sum of the values seen so far and how many there were. */
        private static class MyAggregationBuffer extends AbstractAggregationBuffer {
            // Primitives avoid boxing on every row.
            private double sum = 0D;
            private long count = 0L;

            public double getSum() {
                return sum;
            }

            public void setSum(double sum) {
                this.sum = sum;
            }

            public long getCount() {
                return count;
            }

            public void setCount(long count) {
                this.count = count;
            }
        }

        /** Inspector for the raw argument; set by {@link #init} in PARTIAL1 and COMPLETE modes. */
        private transient PrimitiveObjectInspector inputOI;

        /** Reused holder for the partial result: {@code [sum, count]}. */
        private Object[] mapout = {new DoubleWritable(), new LongWritable()};

        /** Reused holder for the final formatted result. */
        private Text reduceout = new Text();

        /** Creates an empty buffer (sum 0, count 0). */
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            printMode("getNewAggregationBuffer");
            return new MyAggregationBuffer();
        }

        /**
         * Records the input inspector (when raw rows will be iterated) and declares the
         * output type for the given mode.
         *
         * @param m          evaluation mode
         * @param parameters inspectors of the inputs for this mode
         * @return a {@code <sum: double, count: bigint>} struct inspector for PARTIAL1 and
         *         PARTIAL2, otherwise a writable string inspector
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            printMode("init");
            super.init(m, parameters);

            // In PARTIAL1/COMPLETE iterate() receives the original argument, so remember
            // how to read it. getEvaluator() has already verified it is a LONG primitive.
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                inputOI = (PrimitiveObjectInspector) parameters[0];
            }

            if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
                List<String> structFieldNames = new ArrayList<String>();
                List<ObjectInspector> structFieldObjectInspectors = new ArrayList<ObjectInspector>();
                structFieldNames.add("sum");
                structFieldNames.add("count");
                structFieldObjectInspectors.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                structFieldObjectInspectors.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                return ObjectInspectorFactory.getStandardStructObjectInspector(
                        structFieldNames, structFieldObjectInspectors);
            }
            return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
        }

        /** Clears the given buffer back to sum 0 and count 0. */
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            printMode("reset");
            MyAggregationBuffer ab = (MyAggregationBuffer) agg;
            ab.setCount(0L);
            ab.setSum(0D);
        }

        /**
         * Adds one input value to the given buffer. NULL values are skipped.
         *
         * <p>The running totals are kept in {@code agg} itself. Keeping them on the evaluator
         * (as the previous version did) mixes the totals of different buffers, because one
         * evaluator serves every buffer passed to it.
         *
         * @param agg        buffer of the group the row belongs to
         * @param parameters the function arguments for this row; only index 0 is read
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            printMode("iterate");
            if (parameters == null || parameters.length == 0 || parameters[0] == null) {
                return;
            }
            Object value = inputOI.getPrimitiveJavaObject(parameters[0]);
            if (value == null) {
                return;
            }
            MyAggregationBuffer ab = (MyAggregationBuffer) agg;
            ab.setSum(ab.getSum() + ((Number) value).longValue());
            ab.setCount(ab.getCount() + 1L);
        }

        /**
         * Emits the buffer as the partial aggregate {@code [sum, count]}.
         *
         * @return the shared {@code mapout} array, overwritten on every call
         */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            printMode("terminatePartial");
            MyAggregationBuffer ab = (MyAggregationBuffer) agg;
            ((DoubleWritable) mapout[0]).set(ab.getSum());
            ((LongWritable) mapout[1]).set(ab.getCount());
            return mapout;
        }

        /**
         * Folds a partial aggregate produced by {@link #terminatePartial} into the buffer.
         *
         * <p>Accepts the partial as a {@link LazyBinaryStruct}, an {@code Object[]} or a
         * {@link List}; a null partial is ignored.
         *
         * @param agg     buffer to merge into
         * @param partial partial aggregate {@code <sum, count>}
         * @throws HiveException if the partial has a representation this method cannot read
         *         (previously such partials were silently dropped, corrupting the result)
         */
        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            printMode("merge");
            if (partial == null) {
                return;
            }

            Object sumField;
            Object countField;
            if (partial instanceof LazyBinaryStruct) {
                LazyBinaryStruct lbs = (LazyBinaryStruct) partial;
                sumField = lbs.getField(0);
                countField = lbs.getField(1);
            } else if (partial instanceof Object[]) {
                Object[] fields = (Object[]) partial;
                sumField = fields[0];
                countField = fields[1];
            } else if (partial instanceof List) {
                List<?> fields = (List<?>) partial;
                sumField = fields.get(0);
                countField = fields.get(1);
            } else {
                throw new HiveException(
                        "Unsupported partial aggregate representation: " + partial.getClass().getName());
            }

            if (sumField == null || countField == null) {
                return;
            }
            MyAggregationBuffer ab = (MyAggregationBuffer) agg;
            ab.setCount(ab.getCount() + ((LongWritable) countField).get());
            ab.setSum(ab.getSum() + ((DoubleWritable) sumField).get());
        }

        /**
         * Produces the final formatted average for the buffer.
         *
         * @return the shared {@code reduceout} text holding {@code sum / count} formatted as
         *         {@code ###,###.00}, or {@code null} when no value was aggregated
         */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            printMode("terminate");
            MyAggregationBuffer ab = (MyAggregationBuffer) agg;
            long count = ab.getCount();
            if (count == 0L) {
                // Nothing aggregated: avoid 0/0 (NaN) and return null instead.
                return null;
            }
            double avg = ab.getSum() / count;
            // Pattern kept as-is; it has no mandatory integer digit, so values below 1
            // are rendered without a leading zero (e.g. ".50").
            DecimalFormat df = new DecimalFormat("###,###.00");
            reduceout.set(df.format(avg));
            return reduceout;
        }

        /**
         * Writes a trace line naming the lifecycle method being executed to standard output.
         * Note that this is invoked from every lifecycle method, including once per call
         * to {@link #iterate}.
         *
         * @param mname name of the method being traced
         */
        public void printMode(String mname) {
            System.out.println("=================================== " + mname
                    + " is Running! ================================");
        }
    }
}