The Two Kinds of User-Defined Aggregate Functions (UDAFs) in Spark SQL
I. Overview
The built-in functions for DataFrames provide common aggregations such as count(), countDistinct(), avg(), max(), and min(). These functions are designed for DataFrames; Spark SQL also offers type-safe versions that operate on strongly typed Datasets. Beyond the built-ins, users can define their own aggregate functions. User-defined aggregate functions come in two flavors: untyped UDAFs (for DataFrames) and type-safe UDAFs (for Datasets).
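For contrast, here is a minimal sketch of the built-in aggregations applied to a DataFrame (assuming the employees.json file used throughout this article):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().appName("BuiltInAgg").master("local").getOrCreate()
val df = spark.read.json("file:///E:/employees.json")
// Built-in aggregate functions require no user-defined code:
df.agg(avg("salary"), max("salary"), min("salary"), countDistinct("name")).show()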
II. The Two Kinds of UDAF
1. Untyped UDAF
Extend the UserDefinedAggregateFunction abstract class to implement an untyped user-defined aggregate function:
package com.company.sparksql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object UserDefinedUntypedAggregation {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("UserDefinedUntypedAggregation")
      .master("local")
      .getOrCreate()
    Logger.getLogger("org.apache.spark").setLevel(Level.OFF)
    Logger.getLogger("org.apache.hadoop").setLevel(Level.OFF)

    // Register the function so it can be called from SQL
    spark.udf.register("myAverage", MyAverage)

    val df = spark.read.json("file:///E:/employees.json")
    df.createOrReplaceTempView("employees")
    df.show()
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
    result.show()
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+

    spark.stop()
  }

  /*
   * Extends the UserDefinedAggregateFunction abstract class to
   * implement an untyped user-defined aggregate function.
   * This function computes an average.
   */
  object MyAverage extends UserDefinedAggregateFunction {
    // Data type of the function's input arguments (`:: Nil` builds a one-element List)
    def inputSchema: StructType = StructType(StructField("inputCol", LongType) :: Nil)

    // Data type of the aggregation buffer
    def bufferSchema: StructType =
      StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

    // Data type of the return value
    def dataType: DataType = DoubleType

    // Whether this function always returns the same output for the same input
    def deterministic: Boolean = true

    // Initialize the aggregation buffer
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L // initial value of sum
      buffer(1) = 0L // initial value of count
    }

    // Update the buffer with a new input row
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = buffer.getLong(0) + input.getLong(0) // update sum
        buffer(1) = buffer.getLong(1) + 1                // update count
      }
    }

    // Merge two aggregation buffers, storing the result back into buffer1
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    // Compute the final result
    def evaluate(buffer: Row): Double = {
      buffer.getLong(0).toDouble / buffer.getLong(1)
    }
  }
}
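Besides calling the registered function from SQL, a UserDefinedAggregateFunction object can also be applied directly as a column expression via its apply(Column*) method. A minimal sketch, reusing the df defined above:

// The MyAverage object itself yields a Column when applied to columns:
df.select(MyAverage(df("salary")).as("average_salary")).show()
// For the sample data this prints the same result, 3750.0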
2. Type-Safe UDAF
Extend the Aggregator abstract class. The example below computes the average of a field:
package com.company.sparksql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

object UserDefinedTypedAggregation {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("UserDefinedTypedAggregation")
      .master("local")
      .getOrCreate()
    Logger.getLogger("org.apache.spark").setLevel(Level.OFF)
    Logger.getLogger("org.apache.hadoop").setLevel(Level.OFF)

    // Import implicit conversions (required by .as[Employee] below)
    import spark.implicits._

    // Create the Dataset
    val employeeDS = spark.read.json("file:///E:/employees.json").as[Employee]
    employeeDS.show()

    // Convert the function to a `TypedColumn` so it can be used on a Dataset;
    // the column is aliased as average_salary
    val averageSalary = MyAverage.toColumn.name("average_salary")
    val res = employeeDS.select(averageSalary)
    res.show()

    spark.stop()
  }

  // Case class Employee, matching the JSON records: {"name":"Michael", "salary":3000}
  case class Employee(name: String, salary: Long)
  case class Average(var sum: Long, var count: Long)

  /*
   * Employee: input type of the aggregation
   * Average: intermediate (buffer) type of the aggregation
   * Double: type of the final result
   */
  object MyAverage extends Aggregator[Employee, Average, Double] {
    // The zero value; it should satisfy b + zero = b for any b
    def zero: Average = Average(0L, 0L)

    /**
     * Combine two values to produce a new value. For performance, the
     * function may modify `buffer` and return it instead of constructing
     * a new object.
     *
     * @param buffer   the intermediate value
     * @param employee the new input value
     * @return the updated Average
     */
    def reduce(buffer: Average, employee: Employee): Average = {
      buffer.sum += employee.salary
      buffer.count += 1
      buffer
    }

    /**
     * Merge two intermediate values.
     *
     * @param buffer1 the first intermediate value
     * @param buffer2 the second intermediate value
     * @return the merged Average
     */
    def merge(buffer1: Average, buffer2: Average): Average = {
      buffer1.sum += buffer2.sum
      buffer1.count += buffer2.count
      buffer1
    }

    /**
     * Transform the final merged value into the output.
     *
     * @param reduction the final merged value
     * @return the average as a Double
     */
    def finish(reduction: Average): Double = {
      reduction.sum.toDouble / reduction.count
    }

    // Encoder for the intermediate value type;
    // Encoders.product works for Scala product types (tuples, case classes, etc.)
    def bufferEncoder: Encoder[Average] = Encoders.product

    // Encoder for the final output value type (a Scala Double)
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }
}
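As a side note, Spark 3.0 deprecates UserDefinedAggregateFunction in favor of registering an Aggregator for untyped use via functions.udaf. A minimal sketch, assuming Spark 3.0+ and a SparkSession named spark with the employees view as in the examples above (MyAverageUdaf and Avg are hypothetical names; this variant takes a Long column directly instead of an Employee):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.{Encoder, Encoders}

case class Avg(var sum: Long, var count: Long)

// A Long-input variant of the Aggregator, so it can aggregate a single column
object MyAverageUdaf extends Aggregator[Long, Avg, Double] {
  def zero: Avg = Avg(0L, 0L)
  def reduce(b: Avg, x: Long): Avg = { b.sum += x; b.count += 1; b }
  def merge(b1: Avg, b2: Avg): Avg = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  def finish(r: Avg): Double = r.sum.toDouble / r.count
  def bufferEncoder: Encoder[Avg] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register the Aggregator as an untyped UDAF and call it from SQL:
spark.udf.register("myAverage", udaf(MyAverageUdaf))
spark.sql("SELECT myAverage(salary) as average_salary FROM employees").show()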
III. The Dataset Used
The contents of employees.json:
{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}
Reference: http://spark.apache.org/docs/latest/sql-programming-guide.html