SparkSQL 自定义聚合函数[弱类型]
本文的前提条件: SparkSQL in Java
代码如下
1.自定义聚合函数
package cn.coreqi.udaf;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
/**
* 自定义聚合函数类[弱类型],计算年龄的平均值
* 1.集成UserDefinedAggregateFunction
* 2.重写方法
*/
public class MyAvgUDAF1 extends UserDefinedAggregateFunction {
/**
* 定义输入数据的结构 - In
* @return
*/
@Override
public StructType inputSchema() {
return new StructType().add("age", DataTypes.LongType); //年龄
}
/**
* 缓存区数据的结构 - Buffer
* 缓存区用于临时计算
* @return
*/
@Override
public StructType bufferSchema() {
return new StructType()
.add("total",DataTypes.LongType)//累计的总年龄
.add("count",DataTypes.LongType); //用户的数量
}
/**
* 函数计算结果的数据类型 - Out
* @return
*/
@Override
public DataType dataType() {
return DataTypes.LongType;
}
/**
* 函数的稳定性
* 传入相同的参数结果是否相同
* @return
*/
@Override
public boolean deterministic() {
return true;
}
/**
* 缓存区初始化
* @param buffer
*/
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0,0L); //初始化 total 为 0
buffer.update(1,0L); // 初始化 count 为 0
}
/**
* 数据加载到缓存区后如何更新缓存区中的值
* @param buffer
* @param input
*/
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
buffer.update(0,buffer.getLong(0) + input.getLong(0)); // 缓存区中的 total + 输入数据中的 age[输入结构中只有age一个参数,因此使用索引0即可取出age的值]
buffer.update(1,buffer.getLong(1) + 1); // count++
}
/**
* 缓存区数据合并
* 分布式计算,会有多个缓存区,最终多个缓存区需要合并到一起
* @param buffer1
* @param buffer2
*/
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
buffer1.update(0,buffer1.getLong(0) + buffer2.getLong(0));
buffer1.update(1,buffer1.getLong(1) + buffer2.getLong(1));
}
/**
* 计算逻辑,此处为计算平均值
* @param buffer
* @return
*/
@Override
public Object evaluate(Row buffer) {
return buffer.getLong(0) / buffer.getLong(1);
}
}
2.使用
package cn.coreqi;
import cn.coreqi.udaf.MyAvgUDAF1;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;
public class Main {
public static void main(String[] args) {
// 创建SparkConf对象
SparkConf sparkConf = new SparkConf()
.setMaster("local[*]")
.setAppName("sparkSql");
SparkSession spark = SparkSession
.builder()
.config(sparkConf)
.getOrCreate();
Dataset<Row> df = spark.read().json("datas/user.json");
df.show();
// DataFrames => SQL
df.createOrReplaceTempView("user");
spark.udf().register("ageAvg", new MyAvgUDAF1());
spark.sql("select ageAvg(age) from user").show();
// 关闭
spark.close();
}
}