Spark笔记之使用UDAF(User Defined Aggregate Function)
一、UDAF简介
先解释一下什么是UDAF(User Defined Aggregate Function),即用户定义的聚合函数,聚合函数和普通函数的区别是什么呢,普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。
关于UDAF的一个误区
我们可能下意识的认为UDAF是需要和group by一起使用的,实际上UDAF可以跟group by一起使用,也可以不跟group by一起使用,这个其实比较好理解,联想到mysql中的max、min等函数,可以:
1 | select max (foo) from foobar group by bar; |
表示根据bar字段分组,然后求每个分组的最大值,这时候的分组有很多个,使用这个函数对每个分组进行处理,也可以:
1 | select max (foo) from foobar; |
这种情况可以将整张表看做是一个分组,然后在这个分组(实际上就是一整张表)中求最大值。所以聚合函数实际上是对分组做处理,而不关心分组中记录的具体数量。
二、UDAF使用
2.1 继承UserDefinedAggregateFunction
使用UserDefinedAggregateFunction的套路:
1. 自定义类继承UserDefinedAggregateFunction,对每个阶段方法做实现
2. 在spark中注册UDAF,为其绑定一个名字
3. 然后就可以在sql语句中使用上面绑定的名字调用
下面写一个计算平均值的UDAF例子,首先定义一个类继承UserDefinedAggregateFunction:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | package cc11001100.spark.sql.udaf import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction { // 聚合函数的输入数据结构 override def inputSchema: StructType = StructType(StructField( "input" , LongType) :: Nil) // 缓存区数据结构 override def bufferSchema: StructType = StructType(StructField( "sum" , LongType) :: StructField( "count" , LongType) :: Nil) // 聚合函数返回值数据结构 override def dataType: DataType = DoubleType // 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出 override def deterministic: Boolean = true // 初始化缓冲区 override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer( 0 ) = 0L buffer( 1 ) = 0L } // 给聚合函数传入一条新数据进行处理 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { if (input.isNullAt( 0 )) return buffer( 0 ) = buffer.getLong( 0 ) + input.getLong( 0 ) buffer( 1 ) = buffer.getLong( 1 ) + 1 } // 合并聚合函数缓冲区 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1( 0 ) = buffer1.getLong( 0 ) + buffer2.getLong( 0 ) buffer1( 1 ) = buffer1.getLong( 1 ) + buffer2.getLong( 1 ) } // 计算最终结果 override def evaluate(buffer: Row): Any = buffer.getLong( 0 ).toDouble / buffer.getLong( 1 ) } |
然后注册并使用它:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | package cc11001100.spark.sql.udaf import org.apache.spark.sql.SparkSession object SparkSqlUDAFDemo_001 { def main(args: Array[String]): Unit = { val spark = SparkSession.builder().master( "local[*]" ).appName( "SparkStudy" ).getOrCreate() spark.read.json( "data/user" ).createOrReplaceTempView( "v_user" ) spark.udf.register( "u_avg" , AverageUserDefinedAggregateFunction) // 将整张表看做是一个分组对求所有人的平均年龄 spark.sql( "select count(1) as count, u_avg(age) as avg_age from v_user" ).show() // 按照性别分组求平均年龄 spark.sql( "select sex, count(1) as count, u_avg(age) as avg_age from v_user group by sex" ).show() } } |
使用到的数据集:
1 2 3 4 5 6 | {"id": 1001, "name": "foo", "sex": "man", "age": 20} {"id": 1002, "name": "bar", "sex": "man", "age": 24} {"id": 1003, "name": "baz", "sex": "man", "age": 18} {"id": 1004, "name": "foo1", "sex": "woman", "age": 17} {"id": 1005, "name": "bar2", "sex": "woman", "age": 19} {"id": 1006, "name": "baz3", "sex": "woman", "age": 20} |
运行结果:
2.2 继承Aggregator
还有另一种方式就是继承Aggregator这个类,优点是可以带类型:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | package cc11001100.spark.sql.udaf import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.{Encoder, Encoders} /** * 计算平均值 * */ object AverageAggregator extends Aggregator[User, Average, Double] { // 初始化buffer override def zero: Average = Average(0L, 0L) // 处理一条新的记录 override def reduce(b: Average, a: User): Average = { b.sum += a.age b.count += 1L b } // 合并聚合buffer override def merge(b1: Average, b2: Average): Average = { b1.sum += b2.sum b1.count += b2.count b1 } // 减少中间数据传输 override def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count override def bufferEncoder: Encoder[Average] = Encoders.product // 最终输出结果的类型 override def outputEncoder: Encoder[Double] = Encoders.scalaDouble } /** * 计算平均值过程中使用的Buffer * * @param sum * @param count */ case class Average(var sum: Long, var count: Long) { } case class User(id: Long, name: String, sex: String, age: Long) { } |
调用:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | package cc11001100.spark.sql.udaf import org.apache.spark.sql.SparkSession object AverageAggregatorDemo_001 { def main(args: Array[String]): Unit = { val spark = SparkSession.builder().master( "local[*]" ).appName( "SparkStudy" ).getOrCreate() import spark.implicits._ val user = spark.read.json( "data/user" ).as[User] user.select(AverageAggregator.toColumn.name( "avg" )).show() } } |
运行结果:
.
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 智能桌面机器人:用.NET IoT库控制舵机并多方法播放表情
· Linux glibc自带哈希表的用例及性能测试
· 深入理解 Mybatis 分库分表执行原理
· 如何打造一个高并发系统?
· .NET Core GC压缩(compact_phase)底层原理浅谈
· 手把手教你在本地部署DeepSeek R1,搭建web-ui ,建议收藏!
· 新年开篇:在本地部署DeepSeek大模型实现联网增强的AI应用
· Janus Pro:DeepSeek 开源革新,多模态 AI 的未来
· 互联网不景气了那就玩玩嵌入式吧,用纯.NET开发并制作一个智能桌面机器人(三):用.NET IoT库
· 【非技术】说说2024年我都干了些啥