SparkSQL 自定义聚合函数[弱类型]

本文的前提条件: SparkSQL in Java

代码如下

1.自定义聚合函数

package cn.coreqi.udaf;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

/**
 * 自定义聚合函数类[弱类型],计算年龄的平均值
 * 1.集成UserDefinedAggregateFunction
 * 2.重写方法
 */
public class MyAvgUDAF1 extends UserDefinedAggregateFunction {

    /**
     * 定义输入数据的结构 - In
     * @return
     */
    @Override
    public StructType inputSchema() {
        return new StructType().add("age", DataTypes.LongType);   //年龄
    }

    /**
     * 缓存区数据的结构 - Buffer
     * 缓存区用于临时计算
     * @return
     */
    @Override
    public StructType bufferSchema() {
        return new StructType()
                .add("total",DataTypes.LongType)//累计的总年龄
                .add("count",DataTypes.LongType); //用户的数量
    }

    /**
     * 函数计算结果的数据类型 - Out
     * @return
     */
    @Override
    public DataType dataType() {
        return DataTypes.LongType;
    }

    /**
     * 函数的稳定性
     * 传入相同的参数结果是否相同
     * @return
     */
    @Override
    public boolean deterministic() {
        return true;
    }

    /**
     * 缓存区初始化
     * @param buffer
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0,0L);    //初始化 total 为 0
        buffer.update(1,0L);    // 初始化 count 为 0
    }

    /**
     * 数据加载到缓存区后如何更新缓存区中的值
     * @param buffer
     * @param input
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0,buffer.getLong(0) + input.getLong(0));  // 缓存区中的 total + 输入数据中的 age[输入结构中只有age一个参数,因此使用索引0即可取出age的值]
        buffer.update(1,buffer.getLong(1) + 1); // count++
    }

    /**
     * 缓存区数据合并
     * 分布式计算,会有多个缓存区,最终多个缓存区需要合并到一起
     * @param buffer1
     * @param buffer2
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0,buffer1.getLong(0) + buffer2.getLong(0));
        buffer1.update(1,buffer1.getLong(1) + buffer2.getLong(1));
    }

    /**
     * 计算逻辑,此处为计算平均值
     * @param buffer
     * @return
     */
    @Override
    public Object evaluate(Row buffer) {
        return buffer.getLong(0) / buffer.getLong(1);
    }
}

2.使用

package cn.coreqi;

import cn.coreqi.udaf.MyAvgUDAF1;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;

public class Main {
    public static void main(String[] args) {
        // 创建SparkConf对象
        SparkConf sparkConf = new SparkConf()
                .setMaster("local[*]")
                .setAppName("sparkSql");

        SparkSession spark = SparkSession
                .builder()
                .config(sparkConf)
                .getOrCreate();

        Dataset<Row> df = spark.read().json("datas/user.json");
        df.show();

        // DataFrames => SQL
        df.createOrReplaceTempView("user");

        spark.udf().register("ageAvg", new MyAvgUDAF1());
        spark.sql("select ageAvg(age) from user").show();

        // 关闭
        spark.close();
    }
}
posted @ 2024-01-14 15:33  SpringCore  阅读(13)  评论(0编辑  收藏  举报