SparkSQL 自定义聚合函数[强类型]

本文的前提条件: SparkSQL in Java
参考地址:User Defined Aggregate Functions (UDAFs)

1.自定义实体类

package cn.coreqi.entity;

import java.io.Serializable;

public class Average implements Serializable {
    private long total;
    private long count;

    public Average() { }

    public Average(long total, long count) {
        this.total = total;
        this.count = count;
    }

    public long getTotal() {
        return total;
    }

    public void setTotal(long total) {
        this.total = total;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }
}

2.自定义聚合函数

package cn.coreqi.udaf;

import cn.coreqi.entity.Average;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

/**
 * IN – 输入的数据类型。
 * BUF – 缓存区的数据类型。
 * OUT – 输出的数据类型。
 */
public class MyAvgUDAF extends Aggregator<Long, Average, Long> {

    /**
     * 缓存区的初始化
     * @return
     */
    @Override
    public Average zero() {
        return new Average(0L,0L);
    }

    /**
     * 根据输入的数据更新缓存区的数据
     * @param b 缓存区数据
     * @param a 输入的数据
     * @return
     */
    @Override
    public Average reduce(Average b, Long a) {
        b.setCount(b.getCount() + 1);
        b.setTotal(b.getTotal() + a);
        return b;
    }

    /**
     * 缓存区数据合并
     * 分布式计算,会有多个缓存区,最终多个缓存区需要合并到一起
     * @param b1
     * @param b2
     * @return
     */
    @Override
    public Average merge(Average b1, Average b2) {
        b1.setTotal(b1.getTotal() + b2.getTotal());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    /**
     * 计算逻辑,此处为计算平均值
     * @param reduction
     * @return
     */
    @Override
    public Long finish(Average reduction) {
        return reduction.getTotal() / reduction.getCount();
    }

    /**
     * 分布式计算,需要将数据在网络中传输
     * 缓存区的编码操作
     * @return
     */
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    /**
     * 输出的编码操作
     * @return
     */
    @Override
    public Encoder<Long> outputEncoder() {
        return Encoders.LONG();
    }
}

3.使用

package cn.coreqi;

import cn.coreqi.udaf.MyAvgUDAF;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;

public class Main {
    public static void main(String[] args) {
        // 创建SparkConf对象
        SparkConf sparkConf = new SparkConf()
                .setMaster("local[*]")
                .setAppName("sparkSql");

        SparkSession spark = SparkSession
                .builder()
                .config(sparkConf)
                .getOrCreate();

        Dataset<Row> df = spark.read().json("datas/user.json");
        df.show();

        // DataFrames => SQL
        df.createOrReplaceTempView("user");

        spark.udf().register("ageAvg", functions.udaf(new MyAvgUDAF(), Encoders.LONG()));
        spark.sql("select ageAvg(age) from user").show();

        // 关闭
        spark.close();
    }
}

作者:奇

出处:https://www.cnblogs.com/fanqisoft/p/17963980

版权:本作品采用「本文版权归作者和博客园共有,欢迎转载,但必须给出原文链接,并保留此段声明,否则保留追究法律责任的权利。」许可协议进行许可。

posted @   SpringCore  阅读(53)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
历史上的今天:
2023-01-14 Docker 部署 Zabbix
2023-01-14 Centos8 部署 Zabbix
more_horiz
keyboard_arrow_up light_mode palette
选择主题
点击右上角即可分享
微信分享提示