Hive三种自定义函数

Hive自定义函数包括三种UDF、UDAF、UDTF

UDF(User-Defined-Function) 一进一出

UDAF(User-Defined Aggregation Function) 聚集函数,多进一出,如count/max/min

UDTF(User-Defined Table-Generating Functions) 一进多出,如explode(常配合lateral view使用)

  1. 编写函数
  2. 打包上传到Linux
  3. 将jar添加到hive里面 hive> add jar /root/spark_scala_maven.jar;
  4. 创建临时函数(as后面写类的全限定名) hive> create temporary function strLength as 'hiveFun.UDF.UDF_GetLength';

UDF(一进一出)

如果所操作的数据类型都是基础数据类型(例如Hadoop/Hive的基本Writable类型:Text、IntWritable、LongWritable、DoubleWritable等),那么继承简单的org.apache.hadoop.hive.ql.exec.UDF就可以做到

如果所操作的数据类型是内嵌数据结构,如Map,List和Set,那么要采用org.apache.hadoop.hive.ql.udf.generic.GenericUDF

package hiveFun.UDF;

import org.apache.hadoop.hive.ql.exec.UDF;


// A simple one-in/one-out UDF:
// 1. extend the UDF class;
// 2. implement an evaluate method.
public class UDF_GetLength extends UDF {

    /**
     * Returns the length of the input string.
     *
     * <p>The method name {@code evaluate} must not be changed; Hive looks the method up by name.
     *
     * @param str the input value; {@code null} when the column value is NULL
     * @return {@code str.length()} (UTF-16 code units), or -1 when {@code str} is {@code null}
     */
    public int evaluate(String str) {
        // Explicit null check instead of catching the NullPointerException that
        // str.length() would throw; this is the only failure the old catch-all handled.
        if (str == null) {
            return -1;
        }
        return str.length();
    }
}

UDAF(多进一出)

package hiveFun.UDAF;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

@Description(name = "letters", value = "_FUNC_(expr) - 返回该列中所有字符串的字符总数")
public class UDAF_sum extends AbstractGenericUDAFResolver {

    /**
     * Validates the call's argument types and returns the evaluator.
     *
     * @param parameters type info of each argument passed to the function
     * @return a {@link TotalNumOfLettersEvaluator}
     * @throws SemanticException (as {@link UDFArgumentTypeException}) unless there is exactly
     *         one argument and it is a primitive STRING
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
            throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Exactly one argument is expected.");
        }

        ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);

        if (oi.getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0,
                    "Argument must be PRIMITIVE, but "
                            + oi.getCategory().name()
                            + " was passed.");
        }

        PrimitiveObjectInspector inputOI = (PrimitiveObjectInspector) oi;

        if (inputOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
            throw new UDFArgumentTypeException(0,
                    "Argument must be String, but "
                            + inputOI.getPrimitiveCategory().name()
                            + " was passed.");
        }

        return new TotalNumOfLettersEvaluator();
    }

    /**
     * Sums {@code String.valueOf(value).length()} over all non-null input rows.
     *
     * <p>All per-group state lives in the {@link LetterSumAgg} buffer obtained from
     * {@link #getNewAggregationBuffer()}; the evaluator's own fields hold only object
     * inspectors, since one evaluator instance may serve more than one group.
     */
    public static class TotalNumOfLettersEvaluator extends GenericUDAFEvaluator {

        // Inspector for the raw String column (PARTIAL1 / COMPLETE modes).
        PrimitiveObjectInspector inputOI;
        // Inspector describing this evaluator's Integer output in every mode.
        ObjectInspector outputOI;
        // Inspector for Integer partial sums fed into merge() (PARTIAL2 / FINAL modes).
        PrimitiveObjectInspector integerOI;

        /**
         * Records the input inspector for the given mode and declares an Integer output.
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {

            assert (parameters.length == 1);
            super.init(m, parameters);

            // PARTIAL1 (map side) and COMPLETE read the original SQL column: a String.
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
                // PARTIAL2 / FINAL receive the Integer partials produced by terminatePartial().
                integerOI = (PrimitiveObjectInspector) parameters[0];
            }

            // Every mode emits an Integer (partial or final sum).
            outputOI = ObjectInspectorFactory.getReflectionObjectInspector(Integer.class,
                    ObjectInspectorFactory.ObjectInspectorOptions.JAVA);
            return outputOI;
        }

        /**
         * Aggregation buffer holding the running character total for one group.
         * NOTE(review): int may overflow for extremely large groups.
         */
        static class LetterSumAgg implements AggregationBuffer {
            int sum = 0;

            void add(int num) {
                sum += num;
            }
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new LetterSumAgg();
        }

        /** Clears the given buffer so it can be reused for a new group. */
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            // Reset the buffer that was passed in (previously a throwaway object was created,
            // leaving the real buffer untouched).
            ((LetterSumAgg) agg).sum = 0;
        }

        /** Adds the length of one non-null input value to the buffer; NULL rows are skipped. */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            if (parameters[0] != null) {
                LetterSumAgg myagg = (LetterSumAgg) agg;
                Object p1 = inputOI.getPrimitiveJavaObject(parameters[0]);
                myagg.add(String.valueOf(p1).length());
            }
        }

        /** Returns this buffer's partial sum only. */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            // Do not accumulate into evaluator state: that leaked earlier groups' totals
            // into later groups' partial results.
            return ((LetterSumAgg) agg).sum;
        }

        /** Folds an Integer partial sum (from terminatePartial) into the buffer. */
        @Override
        public void merge(AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial != null) {
                LetterSumAgg myagg = (LetterSumAgg) agg;
                Integer partialSum = (Integer) integerOI.getPrimitiveJavaObject(partial);
                myagg.add(partialSum);
            }
        }

        /** Returns the final character total for the group. */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            return ((LetterSumAgg) agg).sum;
        }

    }
}

UDTF(一进多出)

package hiveFun.UDTF;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import java.util.ArrayList;
import java.util.List;

public class StrSplit extends GenericUDTF {

    // Reusable one-column output row; cleared before each forward().
    private List<String> dataList = new ArrayList<>();

    /**
     * Declares the output schema: a single String column named "word".
     *
     * @param argOIs inspectors for the call's arguments; exactly two are expected
     *               (the string to split and the delimiter)
     * @return a struct inspector describing the single output column
     * @throws UDFArgumentException if the function is not called with exactly two arguments
     */
    @Override
    public StructObjectInspector initialize(StructObjectInspector argOIs) throws UDFArgumentException {
        // process() reads objects[0] and objects[1]; reject other arities up front with a
        // clear message instead of an ArrayIndexOutOfBoundsException per row.
        if (argOIs.getAllStructFieldRefs().size() != 2) {
            throw new UDFArgumentException("StrSplit takes exactly two arguments: (str, delimiter)");
        }

        // Output column names.
        List<String> fieldNames = new ArrayList<>();
        fieldNames.add("word");

        // Output column types, parallel to fieldNames: one String.
        List<ObjectInspector> fieldOIs = new ArrayList<>();
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Called once per input row of {@code StrSplit(str, delimiter)}; emits one output
     * row per piece of {@code str}.
     *
     * <p>The delimiter is passed to {@link String#split(String)} and is therefore
     * interpreted as a regular expression (e.g. "|" or "." must be escaped by the caller).
     * If either argument is NULL, no rows are emitted.
     */
    @Override
    public void process(Object[] objects) throws HiveException {
        // NULL input or NULL delimiter: produce nothing instead of failing with an NPE.
        if (objects[0] == null || objects[1] == null) {
            return;
        }

        String data = objects[0].toString();
        String splitKey = objects[1].toString();

        for (String word : data.split(splitKey)) {
            dataList.clear();
            dataList.add(word);

            // Row layout must match the schema declared in initialize(): one String column.
            forward(dataList);
        }
    }

    /** Called last; this UDTF holds no resources, so there is nothing to release. */
    @Override
    public void close() throws HiveException {

    }
}
-- 添加jar包
hive (default)> add jar /root/spark_scala_maven.jar;

-- 创建函数
hive (default)> create temporary function splitStr as 'hiveFun.UDTF.StrSplit';
OK
Time taken: 0.007 seconds
-- 使用函数
hive (default)> select splitStr('a,b,c,d,e',',');
OK
word
a
b
c
d
e
posted @ 2019-10-23 16:16  会走的树  阅读(3397)  评论(0编辑  收藏  举报