SparkSQL 自定义聚合函数[强类型] & DSL
本文的前提条件: SparkSQL in Java
参考地址:User Defined Aggregate Functions (UDAFs)
1.声明列实体类
package cn.coreqi.entity;
import java.io.Serializable;
public class User implements Serializable {
private String username;
private Long age;
public String getUsername() {
return username;
}
public void setUsername(String username) {
this.username = username;
}
public Long getAge() {
return age;
}
public void setAge(Long age) {
this.age = age;
}
}
2.声明聚合函数缓存区强类型
package cn.coreqi.entity;
import java.io.Serializable;
public class Average implements Serializable {
private long total;
private long count;
public Average() { }
public Average(long total, long count) {
this.total = total;
this.count = count;
}
public long getTotal() {
return total;
}
public void setTotal(long total) {
this.total = total;
}
public long getCount() {
return count;
}
public void setCount(long count) {
this.count = count;
}
}
3.聚合函数处理流程
package cn.coreqi.udaf;
import cn.coreqi.entity.Average;
import cn.coreqi.entity.User;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;
/**
* IN – 输入的数据类型。
* BUF – 缓存区的数据类型。
* OUT – 输出的数据类型。
*/
public class MyAvgUDAF2 extends Aggregator<User, Average, Long> {
/**
* 缓存区的初始化
* @return
*/
@Override
public Average zero() {
return new Average(0L,0L);
}
/**
* 根据输入的数据更新缓存区的数据
* @param b 缓存区数据
* @param a 输入的数据
* @return
*/
@Override
public Average reduce(Average b, User a) {
b.setCount(b.getCount() + 1);
b.setTotal(b.getTotal() + a.getAge());
return b;
}
/**
* 缓存区数据合并
* 分布式计算,会有多个缓存区,最终多个缓存区需要合并到一起
* @param b1
* @param b2
* @return
*/
@Override
public Average merge(Average b1, Average b2) {
b1.setTotal(b1.getTotal() + b2.getTotal());
b1.setCount(b1.getCount() + b2.getCount());
return b1;
}
/**
* 计算逻辑,此处为计算平均值
* @param reduction
* @return
*/
@Override
public Long finish(Average reduction) {
return reduction.getTotal() / reduction.getCount();
}
/**
* 分布式计算,需要将数据在网络中传输
* 缓存区的编码操作
* @return
*/
@Override
public Encoder<Average> bufferEncoder() {
return Encoders.bean(Average.class);
}
/**
* 输出的编码操作
* @return
*/
@Override
public Encoder<Long> outputEncoder() {
return Encoders.LONG();
}
}
4.使用
package cn.coreqi;
import cn.coreqi.entity.User;
import cn.coreqi.udaf.MyAvgUDAF2;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;
public class Main {
public static void main(String[] args) {
// 创建SparkConf对象
SparkConf sparkConf = new SparkConf()
.setMaster("local[*]")
.setAppName("sparkSql");
SparkSession spark = SparkSession
.builder()
.config(sparkConf)
.getOrCreate();
Dataset<Row> df = spark.read().json("datas/user.json");
// 早期版本中,spark不能在sql中使用强类型UDAF操作,早期的UDAF强类型聚合函数使用DSL语法操作
Dataset ds = df.as(Encoders.bean(User.class));
// 将UDAF函数转换为查询的列对象
TypedColumn<User, Long> udafColumn = new MyAvgUDAF2().toColumn();
ds.select(udafColumn).show();
// 关闭
spark.close();
}
}