SparkSQL 自定义聚合函数[强类型] & DSL

本文的前提条件: SparkSQL in Java
参考地址:User Defined Aggregate Functions (UDAFs)

1.声明列实体类

package cn.coreqi.entity;

import java.io.Serializable;

public class User implements Serializable {
    private String username;
    private Long age;

    public String getUsername() {
        return username;
    }

    public void setUsername(String username) {
        this.username = username;
    }

    public Long getAge() {
        return age;
    }

    public void setAge(Long age) {
        this.age = age;
    }

}

2.声明聚合函数缓存区强类型

package cn.coreqi.entity;

import java.io.Serializable;

public class Average implements Serializable {
    private long total;
    private long count;

    public Average() { }

    public Average(long total, long count) {
        this.total = total;
        this.count = count;
    }

    public long getTotal() {
        return total;
    }

    public void setTotal(long total) {
        this.total = total;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }
}

3.聚合函数处理流程

package cn.coreqi.udaf;

import cn.coreqi.entity.Average;
import cn.coreqi.entity.User;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

/**
 * IN – 输入的数据类型。
 * BUF – 缓存区的数据类型。
 * OUT – 输出的数据类型。
 */
public class MyAvgUDAF2 extends Aggregator<User, Average, Long> {

    /**
     * 缓存区的初始化
     * @return
     */
    @Override
    public Average zero() {
        return new Average(0L,0L);
    }

    /**
     * 根据输入的数据更新缓存区的数据
     * @param b 缓存区数据
     * @param a 输入的数据
     * @return
     */
    @Override
    public Average reduce(Average b, User a) {
        b.setCount(b.getCount() + 1);
        b.setTotal(b.getTotal() + a.getAge());
        return b;
    }

    /**
     * 缓存区数据合并
     * 分布式计算,会有多个缓存区,最终多个缓存区需要合并到一起
     * @param b1
     * @param b2
     * @return
     */
    @Override
    public Average merge(Average b1, Average b2) {
        b1.setTotal(b1.getTotal() + b2.getTotal());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    /**
     * 计算逻辑,此处为计算平均值
     * @param reduction
     * @return
     */
    @Override
    public Long finish(Average reduction) {
        return reduction.getTotal() / reduction.getCount();
    }

    /**
     * 分布式计算,需要将数据在网络中传输
     * 缓存区的编码操作
     * @return
     */
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    /**
     * 输出的编码操作
     * @return
     */
    @Override
    public Encoder<Long> outputEncoder() {
        return Encoders.LONG();
    }
}

4.使用

package cn.coreqi;

import cn.coreqi.entity.User;
import cn.coreqi.udaf.MyAvgUDAF2;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;

public class Main {
    public static void main(String[] args) {
        // 创建SparkConf对象
        SparkConf sparkConf = new SparkConf()
                .setMaster("local[*]")
                .setAppName("sparkSql");

        SparkSession spark = SparkSession
                .builder()
                .config(sparkConf)
                .getOrCreate();

        Dataset<Row> df = spark.read().json("datas/user.json");

        // 早期版本中,spark不能在sql中使用强类型UDAF操作,早期的UDAF强类型聚合函数使用DSL语法操作
        Dataset ds = df.as(Encoders.bean(User.class));

        // 将UDAF函数转换为查询的列对象
        TypedColumn<User, Long> udafColumn = new MyAvgUDAF2().toColumn();
        ds.select(udafColumn).show();


        // 关闭
        spark.close();
    }
}
posted @ 2024-01-14 21:52  SpringCore  阅读(31)  评论(0编辑  收藏  举报