spark-sql自定义函数UDF和UDAF
1 UDF对每个值进行处理;
2 UDAF对分组后的每个值处理(必须分组)
SparkConf sparkConf = new SparkConf() .setMaster("local") .setAppName("MySqlTest"); JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf); SQLContext sqlContext = new SQLContext(javaSparkContext); List<String> list = new ArrayList<String>(); list.add("2018-9-9,1,ab"); list.add("2018-5-9,1124,abg"); list.add("2018-9-9,1125,abc"); list.add("2018-5-9,1126,abh"); list.add("2016-10-9,1127,abc"); list.add("2016-10-9,1127,abcd"); list.add("2016-10-9,1127,abcder"); JavaRDD<String> rdd_list = javaSparkContext.parallelize(list, 5); JavaRDD<Row> rdd_row_list = rdd_list.map(new Function<String, Row>() { @Override public Row call(String s) throws Exception { return RowFactory.create(s.split(",")[0], Long.parseLong(s.split(",")[1]), s.split(",")[2]);//转换成一个row对象 } }); List<StructField> structFieldList = new ArrayList<StructField>(); structFieldList.add(DataTypes.createStructField("date", DataTypes.StringType, true)); structFieldList.add(DataTypes.createStructField("s", DataTypes.LongType, true)); structFieldList.add(DataTypes.createStructField("str", DataTypes.StringType, true)); StructType dyType = DataTypes.createStructType(structFieldList); DataFrame df_dyType = sqlContext.createDataFrame(rdd_row_list, dyType); df_dyType.registerTempTable("tmp_req"); df_dyType.show(); //1,注册一个简单用户自定义函数 sqlContext.udf().register("zzq123", new UDF1<String, Integer>() { @Override public Integer call(String str) throws Exception { return str.length(); } }, DataTypes.IntegerType); DataFrame df_group = sqlContext.sql("select date,s,zzq123(date) as zzq123 from tmp_req ");//UDF如果没有指定名称,则随机名称 df_group.show(); //1,注册一个复杂的用户自定义聚合函数 sqlContext.udf().register("zzq_agg", new StringLen());//zzq_agg函数计算出分组后本组所有字符串总长度 DataFrame df_group_agg = sqlContext.sql("select date,zzq_agg(str) strSum from tmp_req group by date ");//UDAF为聚合情况下使用 df_group_agg.show();
UDAF实体:
public class StringLen extends UserDefinedAggregateFunction { @Override public StructType inputSchema() {//inputSchema指的是输入的数据类型 List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("_string", DataTypes.StringType, true)); return DataTypes.createStructType(fields); } @Override public StructType bufferSchema() {//bufferSchema指的是 中间进行聚合时 所处理的数据类型 List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("_len", DataTypes.IntegerType, true)); return DataTypes.createStructType(fields); } @Override public DataType dataType() {//dataType指的是函数返回值的类型 return DataTypes.IntegerType; } @Override public boolean deterministic() {//一致性检验,如果为true,那么输入不变的情况下计算的结果也是不变的 return true; } /** * 对于每个分组的数据进行最原始的初始化操作 * * @param buffer */ @Override public void initialize(MutableAggregationBuffer buffer) { buffer.update(0, 0);//初始化的时候初始最开始的字符串的长度 } /** * 用输入数据input更新buffer值,类似于combineByKey * * @param buffer * @param input */ @Override public void update(MutableAggregationBuffer buffer, Row input) {//分组后的每个值处理方法 buffer.update(0, ((Integer) buffer.getAs(0)) + input.getAs(0).toString().length());//返回自己的长度 } /** * 合并两个buffer,将buffer2合并到buffer1.在合并两个分区聚合结果的时候会被用到,类似于reduceByKey * 这里要注意该方法没有返回值,在实现的时候是把buffer2合并到buffer1中去,你需要实现这个合并细节 * * @param buffer1 * @param buffer2 */ @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {//相当于shuffle环节,将每组在不同executor上的数据进行combiner buffer1.update(0, ((Integer) buffer1.getAs(0)) + ((Integer) buffer2.getAs(0)));//两次的字符串长度相加 } /** * 计算并返回最终的聚合结果 * * @param buffer * @return */ @Override public Object evaluate(Row buffer) { return buffer.getInt(0); } }