SparkSQL自定义强类型聚合函数
自定义强类型聚合函数跟自定义无类型聚合函数的操作类似,相对的,实现自定义强类型聚合函数则要继承org.apache.spark.sql.expressions.Aggregator。强类型的优点在于:其内部与特定数据集紧密结合,增强了紧密型、安全性,但由于其紧凑的特性,降低了适用性。
准备employ.txt文件:
Michael,3000
Andy,4500
Justin,3500
Betral,4000
一、定义自定义强类型聚合函数
package com.cjs
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
//定义输入数据类型
case class Employee(name:String, salary:Long)
//定义聚合缓冲器类型
case class Average(var sum:Long, var count:Long)
//继承Aggregator类时需要指定泛型类型,依次为:传入聚合缓冲器的数据类型、聚合缓冲器的类型、返回结果的类型
object MyAggregator extends Aggregator[Employee, Average, Double]{
//类似于初始化聚合缓冲器
override def zero: Average = Average(0L,0L)
//根据传入的参数进行运算操作,最后更新buffer缓冲器,并返回
override def reduce(buffer: Average, a: Employee): Average = {
buffer.sum += a.salary
buffer.count +=1
buffer
}
//b1为主缓冲器,b2为分布式架构中各个节点的缓冲器,对b1和b2的数据进行运算,并返回b1
override def merge(b1: Average, b2: Average): Average = {
b1.sum += b2.sum
b1.count += b2.count
b1
}
//使用主缓冲器的数据进行运算,返回一个运算结果
override def finish(reduction: Average): Double = {
reduction.sum.toDouble/reduction.count
}
//指定中间值的编码器类型
override def bufferEncoder: Encoder[Average] = {
Encoders.product
}
//指定最终输出值的编码器类型
override def outputEncoder: Encoder[Double] = {
Encoders.scalaDouble
}
}
二、使用强类型聚合函数
package com.cjs
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
object TestMyAggregator {
case class Emp(name:String, salary:Long)
def main(args: Array[String]): Unit = {
Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
val conf = new SparkConf()
.set("spark.sql.warehouse.dir","file:///e:/tmp/spark-warehouse")
.set("spark.some.config.option","some-value")
val ss = SparkSession.builder()
.config(conf)
.appName("test_myAggregator")
.master("local[2]")
.getOrCreate()
val path = "E:\\IntelliJ Idea\\sparkSql_practice\\src\\main\\scala\\com\\cjs\\employee.txt"
val sc = ss.sparkContext
import ss.implicits._
val empRDD = sc.textFile(path).map(_.split(",")).map(value=>Emp(value(0),value(1).toLong))
val ds = empRDD.toDF().as[Employee]
println("DS结构:")
ds.printSchema()
println("DS数据")
ds.show()
val averSalary = MyAggregator.toColumn.name("aver_salary") //转换成Column
val result = ds.select(averSalary)
println("平均工资:")
result.show()
println("DS使用select:")
ds.select($"name",$"salary").show()
}
}
输出结果: