自定义UDAF函数(多对一函数)
package SparkSQL.fun.registerfum import org.apache.spark.SparkConf import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} import org.apache.spark.sql.{Dataset, Row, SparkSession} /** * 实现计算每个班级的平均年龄 */ object registerfun2 { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("registfun2").setMaster("local[*]") val session = SparkSession.builder().config(conf).getOrCreate() import session.implicits._ val dataset: Dataset[student] = session.createDataset(Array( student("zs", "c001", 21, "男"), student("ls", "c001", 22, "女"), student("ww", "c001", 23, "男"), student("ml", "c002", 20, "女"), student("zb", "c002", 23, "男") )) dataset.createOrReplaceTempView("student") session.udf.register("avg_clz_age", new MyAvg()) val frame = session.sql("select clz, avg_clz_age(age) avg from student group by clz") frame.show() session.stop() } } // Spark SQL中自定义UDAF聚合函数 实现avg功能 class MyAvg extends UserDefinedAggregateFunction { /** * inputSchema 代表的是聚合函数输入的参数类型以及输入的参数的值的个数 * @return */ override def inputSchema: StructType = { StructType(Array( StructField("input", DataTypes.LongType) )) } /** * bufferSchema 代表的是聚合中间结果的类型 * @return */ override def bufferSchema: StructType = { StructType(Array( StructField("sum", DataTypes.LongType), StructField("count", DataTypes.LongType) )) } /** * dataType 函数返回的数据类型 * @return */ override def dataType: DataType = DataTypes.DoubleType /** * deterministic 相同的值是否返回相同的结果 * @return */ override def deterministic: Boolean = true /** * initialize 初始化方法 * 初始化 bufferSchema定义的中间结果的缓存数据(初始化对应个数的初始值) * @param buffer */ override def initialize(buffer: MutableAggregationBuffer): Unit = { //0索引代表的就是bufferSchema函数中定义的第一个数据类型 sum buffer(0) = 0L //1索引代表的就是bufferSchema函数中定义的第二个数据类型 count buffer(1) = 0L } /** * update函数就代表我函数被调用之后 输入了一个值 然后你如何把这个值加到缓存的结果中 * @param buffer * @param input */ override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { val age = input.getAs[Long](0) //将以前累加的sum值+现在输入的值 就是新的累加和 buffer(0) = buffer.getLong(0) + age buffer(1) = buffer.getLong(1) + 1L } /** * Spark SQL是分布式运行 函数的话也是在每一个分区上更新值,更新完成之后,将多个分区的值合并起来 * @param buffer1 * @param buffer2 */ override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1(0) = buffer1.getLong(0) + buffer2.getAs[Long](0) buffer1(1) = buffer1.getLong(1) + buffer2.getAs[Long](1) } /** * 这个是聚合函数的核心 根据Buffer汇总的结果计算最终的聚合结果,返会一个值 * @param buffer * @return */ override def evaluate(buffer: Row): Any = { val sum = buffer.getAs[Long](0).toDouble val count = buffer.getAs[Long](1).toDouble sum / count } }
本文来自博客园,作者:jsqup,转载请注明原文链接:https://www.cnblogs.com/jsqup/p/16659625.html
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律