Three Types of Hive Custom Functions
Hive supports three kinds of custom functions: UDF, UDAF, and UDTF.
UDF (User-Defined Function): one row in, one row out.
UDAF (User-Defined Aggregation Function): aggregate function, many rows in, one row out, e.g. count/max/min.
UDTF (User-Defined Table-Generating Function): one row in, many rows out, e.g. lateral view explode().
- Write the function
- Package it into a jar and upload it to the Linux machine where Hive runs
- Add the jar to Hive: hive> add jar /root/spark_scala_maven.jar;
- Create a temporary function: hive> create temporary function strLength as 'hiveFun.UDF.UDF_GetLength';
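Once registered, the temporary function can be called like any built-in function. A minimal usage sketch, assuming a table employee with a string column name (both names are illustrative, not part of the original):
hive> select name, strLength(name) from employee;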
UDF (one row in, one row out)
If the data involved are all primitive types (the basic Hadoop/Hive writable types such as Text, IntWritable, LongWritable, DoubleWritable, and so on), the simple org.apache.hadoop.hive.ql.exec.UDF class is sufficient.
If the data are nested structures such as Map, List, or Set, use org.apache.hadoop.hive.ql.udf.generic.GenericUDF instead (a sketch of this case follows the simple example below).
package hiveFun.UDF;

import org.apache.hadoop.hive.ql.exec.UDF;

// 1. Extend the UDF class
// 2. Implement an evaluate method (the method name must not be changed)
public class UDF_GetLength extends UDF {
    // Returns the length of the input string, or -1 if the input is null/invalid
    public int evaluate(String str) {
        try {
            return str.length();
        } catch (Exception e) {
            return -1;
        }
    }
}
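For the nested-type case mentioned above, a GenericUDF works with ObjectInspectors rather than plain Java types. Below is a minimal sketch, not part of the original example: the class name ListLength and its behavior (returning the number of elements of an array column) are assumptions for illustration.
package hiveFun.UDF;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

// Illustrative sketch: returns the number of elements in an array<...> column
public class ListLength extends GenericUDF {
    private ListObjectInspector listOI;

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        // Validate that there is exactly one list-typed argument
        if (arguments.length != 1 || !(arguments[0] instanceof ListObjectInspector)) {
            throw new UDFArgumentException("ListLength expects exactly one list argument.");
        }
        listOI = (ListObjectInspector) arguments[0];
        // The return type is an int
        return PrimitiveObjectInspectorFactory.javaIntObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        Object list = arguments[0].get();
        return list == null ? null : listOI.getListLength(list);
    }

    @Override
    public String getDisplayString(String[] children) {
        return "ListLength(" + children[0] + ")";
    }
}
It is packaged, added, and registered exactly like the simple UDF above.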
UDAF (many rows in, one row out)
A GenericUDAF has two parts: a resolver (extending AbstractGenericUDAFResolver) that validates the argument types and picks an evaluator, and an evaluator (extending GenericUDAFEvaluator) that implements init, getNewAggregationBuffer, reset, iterate, terminatePartial, merge, and terminate. The example below sums the number of characters over all strings in a column.
package hiveFun.UDAF;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
@Description(name = "letters", value = "_FUNC_(expr) - returns the total number of characters of all strings in the column")
public class UDAF_sum extends AbstractGenericUDAFResolver {
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
throws SemanticException {
if (parameters.length != 1) {
throw new UDFArgumentTypeException(parameters.length - 1,
"Exactly one argument is expected.");
}
ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
if (oi.getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentTypeException(0,
"Argument must be PRIMITIVE, but "
+ oi.getCategory().name()
+ " was passed.");
}
PrimitiveObjectInspector inputOI = (PrimitiveObjectInspector) oi;
if (inputOI.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
throw new UDFArgumentTypeException(0,
"Argument must be String, but "
+ inputOI.getPrimitiveCategory().name()
+ " was passed.");
}
return new TotalNumOfLettersEvaluator();
}
public static class TotalNumOfLettersEvaluator extends GenericUDAFEvaluator {

    PrimitiveObjectInspector inputOI;
    ObjectInspector outputOI;
    PrimitiveObjectInspector integerOI;

    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
            throws HiveException {
        assert (parameters.length == 1);
        super.init(m, parameters);
        if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
            // Map side: the input is the raw String column
            inputOI = (PrimitiveObjectInspector) parameters[0];
        } else {
            // PARTIAL2 / FINAL: the input is the Integer partial result
            integerOI = (PrimitiveObjectInspector) parameters[0];
        }
        // Every stage outputs an Integer
        outputOI = ObjectInspectorFactory.getReflectionObjectInspector(Integer.class,
                ObjectInspectorFactory.ObjectInspectorOptions.JAVA);
        return outputOI;
    }

    /**
     * Aggregation buffer holding the running character count
     */
    static class LetterSumAgg implements AggregationBuffer {
        int sum = 0;

        void add(int num) {
            sum += num;
        }
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
        return new LetterSumAgg();
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
        // Reset the buffer that was passed in; creating a new one here would have no effect
        ((LetterSumAgg) agg).sum = 0;
    }

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
            throws HiveException {
        assert (parameters.length == 1);
        if (parameters[0] != null) {
            LetterSumAgg myagg = (LetterSumAgg) agg;
            // Add the length of the current row's string to the buffer
            Object p = inputOI.getPrimitiveJavaObject(parameters[0]);
            myagg.add(String.valueOf(p).length());
        }
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
        // Return the partial sum accumulated in this buffer
        return ((LetterSumAgg) agg).sum;
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial)
            throws HiveException {
        if (partial != null) {
            LetterSumAgg myagg = (LetterSumAgg) agg;
            // The partial value is the Integer produced by terminatePartial
            Integer partialSum = (Integer) integerOI.getPrimitiveJavaObject(partial);
            myagg.add(partialSum);
        }
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
        // Return the final total
        return ((LetterSumAgg) agg).sum;
    }
}
}
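Registration and use follow the same pattern as the UDF. A sketch in which the table employee and column name are assumed for illustration (the function name letters matches the @Description above):
hive> add jar /root/spark_scala_maven.jar;
hive> create temporary function letters as 'hiveFun.UDAF.UDAF_sum';
hive> select letters(name) from employee;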
UDTF (one row in, many rows out)
A UDTF extends GenericUDTF: initialize declares the output columns, process is called once per input row and emits output rows via forward, and close runs once at the end.
package hiveFun.UDTF;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import java.util.ArrayList;
import java.util.List;
public class StrSplit extends GenericUDTF {

    private final List<String> dataList = new ArrayList<>();

    /**
     * Declares the names and types of the output columns
     */
    @Override
    public StructObjectInspector initialize(StructObjectInspector argOIs) throws UDFArgumentException {
        // Output column names
        List<String> fieldNames = new ArrayList<>();
        fieldNames.add("word");
        // Corresponding column types (a single String column)
        List<ObjectInspector> fieldOIs = new ArrayList<>();
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Called once per input row, e.g. StrSplit(fieldName, ",")
     */
    @Override
    public void process(Object[] objects) throws HiveException {
        // First argument: the string to split
        String data = objects[0].toString();
        // Second argument: the separator
        String splitKey = objects[1].toString();
        // Split the string
        String[] words = data.split(splitKey);
        // Emit one output row per word
        for (String word : words) {
            dataList.clear();
            dataList.add(word);
            // The forwarded row must match the column types declared in initialize
            forward(dataList);
        }
    }

    // Called once at the end; release any resources here
    @Override
    public void close() throws HiveException {
    }
}
-- Add the jar
hive (default)> add jar /root/spark_scala_maven.jar;
-- Create the temporary function
hive (default)> create temporary function splitStr as 'hiveFun.UDTF.StrSplit';
OK
Time taken: 0.007 seconds
-- Use the function
hive (default)> select splitStr('a,b,c,d,e',',');
OK
word
a
b
c
d
e
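In practice a UDTF is most often combined with lateral view so that the generated rows can be joined back to the source row. A sketch, assuming a table src with a string column line (both names are illustrative):
hive (default)> select t.word from src lateral view splitStr(line, ',') t as word;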