SparkSQL Custom Aggregate Functions [Strongly Typed]
Prerequisite for this article: SparkSQL in Java
Reference: User Defined Aggregate Functions (UDAFs)
1. Custom entity class
package cn.coreqi.entity;

import java.io.Serializable;

public class Average implements Serializable {
    private long total;
    private long count;

    public Average() { }

    public Average(long total, long count) {
        this.total = total;
        this.count = count;
    }

    public long getTotal() {
        return total;
    }

    public void setTotal(long total) {
        this.total = total;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }
}
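Note that Average is a plain JavaBean: a public no-arg constructor plus a getter and setter for every field. This shape matters, because the aggregator in step 2 serializes the buffer with Encoders.bean(Average.class), which relies on exactly those conventions.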
2. Custom aggregate function
package cn.coreqi.udaf;

import cn.coreqi.entity.Average;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

/**
 * Aggregator<IN, BUF, OUT>:
 * IN  – the input type.
 * BUF – the intermediate buffer type.
 * OUT – the output type.
 */
public class MyAvgUDAF extends Aggregator<Long, Average, Long> {
    /**
     * Initialize the buffer.
     * @return the zero value; merging it with any buffer b must leave b unchanged
     */
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    /**
     * Fold one input value into the buffer.
     * @param b the current buffer
     * @param a the input value
     * @return the updated buffer
     */
    @Override
    public Average reduce(Average b, Long a) {
        b.setCount(b.getCount() + 1);
        b.setTotal(b.getTotal() + a);
        return b;
    }

    /**
     * Merge two buffers.
     * In distributed execution each partition builds its own buffer,
     * and the partial buffers must eventually be merged into one.
     * @param b1 the first buffer
     * @param b2 the second buffer
     * @return the merged buffer
     */
    @Override
    public Average merge(Average b1, Average b2) {
        b1.setTotal(b1.getTotal() + b2.getTotal());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    /**
     * Produce the final result, here the average.
     * Guard against an empty input so we never divide by zero.
     * Note: long division truncates; switch OUT to Double for an exact average.
     * @param reduction the fully merged buffer
     * @return the average of all inputs, or 0 for an empty input
     */
    @Override
    public Long finish(Average reduction) {
        return reduction.getCount() == 0 ? 0L : reduction.getTotal() / reduction.getCount();
    }

    /**
     * Encoder for the buffer, needed to ship it across the network
     * during distributed execution.
     * @return a bean encoder for Average
     */
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    /**
     * Encoder for the output value.
     * @return the built-in encoder for Long
     */
    @Override
    public Encoder<Long> outputEncoder() {
        return Encoders.LONG();
    }
}
3. Usage
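The driver below reads datas/user.json. For reference, a minimal input file might look like the following (hypothetical contents; Spark's JSON reader expects one JSON object per line):

{"username":"zhangsan","age":20}
{"username":"lisi","age":30}
{"username":"wangwu","age":40}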
package cn.coreqi;

import cn.coreqi.udaf.MyAvgUDAF;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;

public class Main {
    public static void main(String[] args) {
        // Create the SparkConf object
        SparkConf sparkConf = new SparkConf()
                .setMaster("local[*]")
                .setAppName("sparkSql");
        SparkSession spark = SparkSession
                .builder()
                .config(sparkConf)
                .getOrCreate();

        Dataset<Row> df = spark.read().json("datas/user.json");
        df.show();

        // DataFrame => SQL
        df.createOrReplaceTempView("user");
        // Wrap the typed Aggregator as an untyped UDAF so it can be called from SQL;
        // Encoders.LONG() describes the type of the input column.
        spark.udf().register("ageAvg", functions.udaf(new MyAvgUDAF(), Encoders.LONG()));
        spark.sql("select ageAvg(age) from user").show();

        // Shut down
        spark.close();
    }
}
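Because Aggregator is strongly typed, it can also be applied directly to a typed Dataset via toColumn(), with no SQL registration at all. A minimal sketch (the sample values are made up; it assumes the same SparkSession, run before spark.close(), and needs import java.util.Arrays;):

// Build a typed Dataset<Long> of ages (illustrative values)
Dataset<Long> ages = spark.createDataset(Arrays.asList(20L, 30L, 40L), Encoders.LONG());
// toColumn() turns the Aggregator into a TypedColumn<Long, Long>
ages.select(new MyAvgUDAF().toColumn().name("avg_age")).show();

The typed route keeps compile-time type checking end to end, while the functions.udaf registration above is the bridge needed to call the same aggregator from SQL or the untyped DataFrame API.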