SparkSQL 自定义聚合函数[弱类型]
本文的前提条件: SparkSQL in Java
代码如下
1.自定义聚合函数
package cn.coreqi.udaf;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
/**
* 自定义聚合函数类[弱类型],计算年龄的平均值
* 1.集成UserDefinedAggregateFunction
* 2.重写方法
*/
public class MyAvgUDAF1 extends UserDefinedAggregateFunction {
/**
* 定义输入数据的结构 - In
* @return
*/
@Override
public StructType inputSchema() {
return new StructType().add("age", DataTypes.LongType); //年龄
}
/**
* 缓存区数据的结构 - Buffer
* 缓存区用于临时计算
* @return
*/
@Override
public StructType bufferSchema() {
return new StructType()
.add("total",DataTypes.LongType)//累计的总年龄
.add("count",DataTypes.LongType); //用户的数量
}
/**
* 函数计算结果的数据类型 - Out
* @return
*/
@Override
public DataType dataType() {
return DataTypes.LongType;
}
/**
* 函数的稳定性
* 传入相同的参数结果是否相同
* @return
*/
@Override
public boolean deterministic() {
return true;
}
/**
* 缓存区初始化
* @param buffer
*/
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0,0L); //初始化 total 为 0
buffer.update(1,0L); // 初始化 count 为 0
}
/**
* 数据加载到缓存区后如何更新缓存区中的值
* @param buffer
* @param input
*/
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
buffer.update(0,buffer.getLong(0) + input.getLong(0)); // 缓存区中的 total + 输入数据中的 age[输入结构中只有age一个参数,因此使用索引0即可取出age的值]
buffer.update(1,buffer.getLong(1) + 1); // count++
}
/**
* 缓存区数据合并
* 分布式计算,会有多个缓存区,最终多个缓存区需要合并到一起
* @param buffer1
* @param buffer2
*/
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
buffer1.update(0,buffer1.getLong(0) + buffer2.getLong(0));
buffer1.update(1,buffer1.getLong(1) + buffer2.getLong(1));
}
/**
* 计算逻辑,此处为计算平均值
* @param buffer
* @return
*/
@Override
public Object evaluate(Row buffer) {
return buffer.getLong(0) / buffer.getLong(1);
}
}
2.使用
package cn.coreqi;
import cn.coreqi.udaf.MyAvgUDAF1;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;
public class Main {
public static void main(String[] args) {
// 创建SparkConf对象
SparkConf sparkConf = new SparkConf()
.setMaster("local[*]")
.setAppName("sparkSql");
SparkSession spark = SparkSession
.builder()
.config(sparkConf)
.getOrCreate();
Dataset<Row> df = spark.read().json("datas/user.json");
df.show();
// DataFrames => SQL
df.createOrReplaceTempView("user");
spark.udf().register("ageAvg", new MyAvgUDAF1());
spark.sql("select ageAvg(age) from user").show();
// 关闭
spark.close();
}
}
作者:奇
出处:https://www.cnblogs.com/fanqisoft/p/17963774
版权:本作品采用「本文版权归作者和博客园共有,欢迎转载,但必须给出原文链接,并保留此段声明,否则保留追究法律责任的权利。」许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
2023-01-14 Docker 部署 Zabbix
2023-01-14 Centos8 部署 Zabbix