SparkSQL 自定义聚合函数[强类型]

本文的前提条件: SparkSQL in Java
参考地址:User Defined Aggregate Functions (UDAFs)

1.自定义实体类

package cn.coreqi.entity;

import java.io.Serializable;

public class Average implements Serializable {
    private long total;
    private long count;

    public Average() { }

    public Average(long total, long count) {
        this.total = total;
        this.count = count;
    }

    public long getTotal() {
        return total;
    }

    public void setTotal(long total) {
        this.total = total;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }
}

2.自定义聚合函数

package cn.coreqi.udaf;

import cn.coreqi.entity.Average;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

/**
 * IN – 输入的数据类型。
 * BUF – 缓存区的数据类型。
 * OUT – 输出的数据类型。
 */
public class MyAvgUDAF extends Aggregator<Long, Average, Long> {

    /**
     * 缓存区的初始化
     * @return
     */
    @Override
    public Average zero() {
        return new Average(0L,0L);
    }

    /**
     * 根据输入的数据更新缓存区的数据
     * @param b 缓存区数据
     * @param a 输入的数据
     * @return
     */
    @Override
    public Average reduce(Average b, Long a) {
        b.setCount(b.getCount() + 1);
        b.setTotal(b.getTotal() + a);
        return b;
    }

    /**
     * 缓存区数据合并
     * 分布式计算,会有多个缓存区,最终多个缓存区需要合并到一起
     * @param b1
     * @param b2
     * @return
     */
    @Override
    public Average merge(Average b1, Average b2) {
        b1.setTotal(b1.getTotal() + b2.getTotal());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    /**
     * 计算逻辑,此处为计算平均值
     * @param reduction
     * @return
     */
    @Override
    public Long finish(Average reduction) {
        return reduction.getTotal() / reduction.getCount();
    }

    /**
     * 分布式计算,需要将数据在网络中传输
     * 缓存区的编码操作
     * @return
     */
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    /**
     * 输出的编码操作
     * @return
     */
    @Override
    public Encoder<Long> outputEncoder() {
        return Encoders.LONG();
    }
}

3.使用

package cn.coreqi;

import cn.coreqi.udaf.MyAvgUDAF;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;

public class Main {
    public static void main(String[] args) {
        // 创建SparkConf对象
        SparkConf sparkConf = new SparkConf()
                .setMaster("local[*]")
                .setAppName("sparkSql");

        SparkSession spark = SparkSession
                .builder()
                .config(sparkConf)
                .getOrCreate();

        Dataset<Row> df = spark.read().json("datas/user.json");
        df.show();

        // DataFrames => SQL
        df.createOrReplaceTempView("user");

        spark.udf().register("ageAvg", functions.udaf(new MyAvgUDAF(), Encoders.LONG()));
        spark.sql("select ageAvg(age) from user").show();

        // 关闭
        spark.close();
    }
}
posted @ 2024-01-14 17:45  SpringCore  阅读(37)  评论(0编辑  收藏  举报