spark-sql自定义函数UDF和UDAF

1 UDF对每个值进行处理;

2 UDAF对分组后的每个值处理(必须分组)

    SparkConf sparkConf = new SparkConf()
                .setMaster("local")
                .setAppName("MySqlTest");

        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);

        SQLContext sqlContext = new SQLContext(javaSparkContext);

        List<String> list = new ArrayList<String>();
        list.add("2018-9-9,1,ab");
        list.add("2018-5-9,1124,abg");
        list.add("2018-9-9,1125,abc");
        list.add("2018-5-9,1126,abh");
        list.add("2016-10-9,1127,abc");
        list.add("2016-10-9,1127,abcd");
        list.add("2016-10-9,1127,abcder");

        JavaRDD<String> rdd_list = javaSparkContext.parallelize(list, 5);

        JavaRDD<Row> rdd_row_list = rdd_list.map(new Function<String, Row>() {
            @Override
            public Row call(String s) throws Exception {
                return RowFactory.create(s.split(",")[0], Long.parseLong(s.split(",")[1]), s.split(",")[2]);//转换成一个row对象
            }
        });

        List<StructField> structFieldList = new ArrayList<StructField>();
        structFieldList.add(DataTypes.createStructField("date", DataTypes.StringType, true));
        structFieldList.add(DataTypes.createStructField("s", DataTypes.LongType, true));
        structFieldList.add(DataTypes.createStructField("str", DataTypes.StringType, true));
        StructType dyType = DataTypes.createStructType(structFieldList);

        DataFrame df_dyType = sqlContext.createDataFrame(rdd_row_list, dyType);

        df_dyType.registerTempTable("tmp_req");
        df_dyType.show();

        //1,注册一个简单用户自定义函数
        sqlContext.udf().register("zzq123", new UDF1<String, Integer>() {
            @Override
            public Integer call(String str) throws Exception {
                return str.length();
            }
        }, DataTypes.IntegerType);

        DataFrame df_group = sqlContext.sql("select date,s,zzq123(date) as zzq123 from tmp_req ");//UDF如果没有指定名称,则随机名称
        df_group.show();

        //1,注册一个复杂的用户自定义聚合函数
        sqlContext.udf().register("zzq_agg", new StringLen());//zzq_agg函数计算出分组后本组所有字符串总长度
        DataFrame df_group_agg = sqlContext.sql("select date,zzq_agg(str) strSum  from tmp_req group by date ");//UDAF为聚合情况下使用
        df_group_agg.show();

UDAF实体:

public class StringLen extends UserDefinedAggregateFunction {
    @Override
    public StructType inputSchema() {//inputSchema指的是输入的数据类型
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("_string", DataTypes.StringType, true));
        return DataTypes.createStructType(fields);
    }

    @Override
    public StructType bufferSchema() {//bufferSchema指的是  中间进行聚合时  所处理的数据类型
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("_len", DataTypes.IntegerType, true));
        return DataTypes.createStructType(fields);
    }

    @Override
    public DataType dataType() {//dataType指的是函数返回值的类型
        return DataTypes.IntegerType;
    }

    @Override
    public boolean deterministic() {//一致性检验,如果为true,那么输入不变的情况下计算的结果也是不变的
        return true;
    }

    /**
     * 对于每个分组的数据进行最原始的初始化操作
     *
     * @param buffer
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0);//初始化的时候初始最开始的字符串的长度
    }

    /**
     * 用输入数据input更新buffer值,类似于combineByKey
     *
     * @param buffer
     * @param input
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {//分组后的每个值处理方法
        buffer.update(0, ((Integer) buffer.getAs(0)) + input.getAs(0).toString().length());//返回自己的长度
    }

    /**
     * 合并两个buffer,将buffer2合并到buffer1.在合并两个分区聚合结果的时候会被用到,类似于reduceByKey
     * 这里要注意该方法没有返回值,在实现的时候是把buffer2合并到buffer1中去,你需要实现这个合并细节
     *
     * @param buffer1
     * @param buffer2
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {//相当于shuffle环节,将每组在不同executor上的数据进行combiner
        buffer1.update(0, ((Integer) buffer1.getAs(0)) + ((Integer) buffer2.getAs(0)));//两次的字符串长度相加
    }

    /**
     * 计算并返回最终的聚合结果
     *
     * @param buffer
     * @return
     */
    @Override
    public Object evaluate(Row buffer) {
        return buffer.getInt(0);
    }
}

 

posted @ 2018-04-09 12:07  soft.push("zzq")  Views(672)  Comments(0Edit  收藏  举报