大数据学习日志——scala实现sparkSQL的UDAF
UDAF全称时User Defined Aggregate Function,即用户自定义集合函数,就是多个输入值一个输出值的函数。
首先思考聚合函数,怎么使用函数把多个值合成一个值。
先不考虑sparkSQL提供的开发api,考虑多个值聚合,首先得想到有个函数,接收多个单独的数值,进行函数计算,可以是一个函数一次性接收所有数值进行计算;除此之外呢,可以进行部分数值计算,然后放入缓存中,然后再对新得到的值和缓存中的值进行计算。
在这里把这两种函数列出来:
A方案:将所有要聚合的值放入一个函数中做运算
B方案:当取得部分值时候就先进行计算放入缓存,之后一取到需要运算的值,变和缓存对象进行运算,最后当所有值都运算完后,再对缓存中的值进行处理然后输出
这里讨论下A、B方案优劣性。
1.描述中就可以看出A方案实现简单,B方案需要写多个函数进行处理。
2.A方案需要将所有数据转移到一个函数中运算,要是需要聚合的值很多,便会生成一个很占内存的数据对象,要是spark是在多台机器上运算,就会有网络IO传输开销。B方案只需要多一个缓存对象,占用内存很小,对于分executor运算,也可以在各自的机器上存在缓存数据,最后对缓存和缓存间进行运算,这时候需要多写一个缓存间运算的函数。
A、B方案的优劣简单探讨后便明白,A方案除了编写代码简单以外毫无优势,实际上sparkSQL的UDAF便是采用改进过的B方案。
通过上面的优劣和实际运行情况的讨论,便知道运算时候除了缓存数据和新带运算数据的处理计算外还需要对多个缓存对象进行计算,然后生成新的缓存对象(这里经过上述讨论,不再阐述为什么不是多个缓存对象放到一个函数里进行计算而是生成新缓存对象)。
既然现在明白了采用的时改进后的B方案,那么再进一步考虑实际的问题,缓存对象要如何初始化,在上述描述中只是简单的说明了采用部分数据先运算成为缓存对象的方案,实际上在代码实现中,sparkSQL提供的函数并没有这一步,而是直接通过给缓存对象设置默认初始值的函数。
经过上述讨论已经明白了运算的基本步骤:
初始化缓存数据——>缓存数据和新对象的计算——>缓存对象的计算——>对缓存对象的处理输出
但是现在还不能直接开始写代码,对于一个稳定的,能在sql中使用的函数,还需要给整个运算搭建更稳固的环境,sparkSQL中仍需要用户重写一些函数:
输入的参数类型、缓存数据的内存对象类型、函数返回值类型、是否保持幂等性。
最后可以开始写代码,继承UserDefinedAggregateFunction函数,按照注释对其中的函数进行重写即可,不再描述写代码的步骤。
一下是一个实现了取平均值的udaf
1 package com.saltfish.run 2 3 import org.apache.spark.sql.Row 4 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} 5 import org.apache.spark.sql.types._ 6 7 object Avg_UDAF extends UserDefinedAggregateFunction { 8 //输入的参数类型 9 override def inputSchema: StructType = { 10 StructType(Array(StructField("inputValue", DoubleType))) 11 } 12 13 //缓冲区中的值,用来记录上次处理的值 14 override def bufferSchema: StructType = { 15 16 StructType(Array( 17 StructField("totalValue", DoubleType), 18 StructField("totalCount", IntegerType) 19 )) 20 } 21 22 //函数的返回值类型 23 override def dataType: DataType = { 24 25 DoubleType 26 } 27 28 //设置多次运行该函数,传入相同输入值,是否返回相同结果 29 //默认true 30 override def deterministic: Boolean = { 31 32 true 33 } 34 35 //缓冲区的初始值 36 //对应缓冲区数据位置,初始值 37 override def initialize(buffer: MutableAggregationBuffer): Unit = { 38 39 buffer.update(0, 0.0) 40 buffer.update(1, 0) 41 } 42 43 //更新缓存的值,拿出一条新记录,和缓存的值,两个值计算,放入缓存 44 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { 45 46 val iv = input.getDouble(0) 47 val bv = buffer.getDouble(0) 48 val bc = buffer.getInt(1) 49 //更新回去 50 buffer.update(0, iv + bv) 51 buffer.update(1, bc + 1) 52 } 53 54 //两个缓存中的值处理 55 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 56 57 val bv1 = buffer1.getDouble(0) 58 val bc1 = buffer1.getInt(1) 59 60 val bv2 = buffer2.getDouble(0) 61 val bc2 = buffer2.getInt(1) 62 63 buffer1.update(0, bv1 + bv2) 64 buffer1.update(1, bc1 + bc2) 65 } 66 67 //最终运算 68 override def evaluate(buffer: Row): Any = { 69 70 val bv = buffer.getDouble(0) 71 val bc = buffer.getInt(1) 72 73 bv / bc 74 } 75 }
posted on 2019-04-11 21:00 SaltFishYe 阅读(495) 评论(0) 编辑 收藏 举报