自定义UDAF函数(多对一函数)

package SparkSQL.fun.registerfum
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
/**
* 实现计算每个班级的平均年龄
*/
object registerfun2 {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("registfun2").setMaster("local[*]")
val session = SparkSession.builder().config(conf).getOrCreate()
import session.implicits._
val dataset: Dataset[student] = session.createDataset(Array(
student("zs", "c001", 21, "男"),
student("ls", "c001", 22, "女"),
student("ww", "c001", 23, "男"),
student("ml", "c002", 20, "女"),
student("zb", "c002", 23, "男")
))
dataset.createOrReplaceTempView("student")
session.udf.register("avg_clz_age", new MyAvg())
val frame = session.sql("select clz, avg_clz_age(age) avg from student group by clz")
frame.show()
session.stop()
}
}
// Spark SQL中自定义UDAF聚合函数 实现avg功能
class MyAvg extends UserDefinedAggregateFunction {
/**
* inputSchema 代表的是聚合函数输入的参数类型以及输入的参数的值的个数
* @return
*/
override def inputSchema: StructType = {
StructType(Array(
StructField("input", DataTypes.LongType)
))
}
/**
* bufferSchema 代表的是聚合中间结果的类型
* @return
*/
override def bufferSchema: StructType = {
StructType(Array(
StructField("sum", DataTypes.LongType),
StructField("count", DataTypes.LongType)
))
}
/**
* dataType 函数返回的数据类型
* @return
*/
override def dataType: DataType = DataTypes.DoubleType
/**
* deterministic 相同的值是否返回相同的结果
* @return
*/
override def deterministic: Boolean = true
/**
* initialize 初始化方法
* 初始化 bufferSchema定义的中间结果的缓存数据(初始化对应个数的初始值)
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//0索引代表的就是bufferSchema函数中定义的第一个数据类型 sum
buffer(0) = 0L
//1索引代表的就是bufferSchema函数中定义的第二个数据类型 count
buffer(1) = 0L
}
/**
* update函数就代表我函数被调用之后 输入了一个值 然后你如何把这个值加到缓存的结果中
* @param buffer
* @param input
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val age = input.getAs[Long](0)
//将以前累加的sum值+现在输入的值 就是新的累加和
buffer(0) = buffer.getLong(0) + age
buffer(1) = buffer.getLong(1) + 1L
}
/**
* Spark SQL是分布式运行 函数的话也是在每一个分区上更新值,更新完成之后,将多个分区的值合并起来
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getAs[Long](0)
buffer1(1) = buffer1.getLong(1) + buffer2.getAs[Long](1)
}
/**
* 这个是聚合函数的核心 根据Buffer汇总的结果计算最终的聚合结果,返会一个值
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = {
val sum = buffer.getAs[Long](0).toDouble
val count = buffer.getAs[Long](1).toDouble
sum / count
}
}
posted @   jsqup  阅读(46)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示