The Two Kinds of User-Defined Aggregate Functions (UDAF) in Spark SQL

I. Overview

Spark SQL's built-in DataFrame functions already provide common aggregations such as count(), countDistinct(), avg(), max(), and min(). These functions are designed for DataFrames; Spark SQL also offers type-safe variants that operate on strongly typed Datasets. Beyond the built-ins, users can define their own aggregate functions. User-defined aggregate functions (UDAFs) come in two flavors: untyped UDAFs, which work on DataFrames, and type-safe UDAFs, which work on Datasets.
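
To make the contrast concrete, here is a minimal sketch of the built-in aggregations on a DataFrame (the object name BuiltInAggregations is only for illustration; it assumes the same employees.json file used in the examples below):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, countDistinct, max, min}

object BuiltInAggregations {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("BuiltInAggregations")
      .master("local")
      .getOrCreate()
    val df = spark.read.json("file:///E:/employees.json")
    // Built-in aggregate functions operate directly on DataFrame columns
    df.select(avg("salary"), max("salary"), min("salary"), countDistinct("name")).show()
    spark.stop()
  }
}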

II. The Two Kinds of UDAF

1. Untyped UDAF

Extend the UserDefinedAggregateFunction abstract class to implement an untyped user-defined aggregate function. The example below computes the average of the salary column:

package com.company.sparksql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object UserDefinedUntypedAggregation {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("UserDefinedUntypedAggregation")
      .master("local")
      .getOrCreate()
    Logger.getLogger("org.apache.spark").setLevel(Level.OFF)
    Logger.getLogger("org.apache.hadoop").setLevel(Level.OFF)

    // Register the function so it can be called from SQL
    spark.udf.register("myAverage", MyAverage)
    val df = spark.read.json("file:///E:/employees.json")
    df.createOrReplaceTempView("employees")
    df.show()

    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
    result.show()
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+

    spark.stop()
  }

  /*
  * Extends the UserDefinedAggregateFunction abstract class
  * to implement an untyped user-defined aggregate function.
  * This function computes an average.
  */
  object MyAverage extends UserDefinedAggregateFunction {
    // Data type of the aggregate function's input arguments (`:: Nil` prepends the field to an empty list)
    def inputSchema: StructType = StructType(StructField("inputCol", LongType) :: Nil)

    // Data types of the values in the aggregation buffer
    def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

    // Data type of the return value
    def dataType: DataType = DoubleType

    // Whether the function always returns the same output for the same input
    def deterministic: Boolean = true

    // Initialize the buffer values
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L // initialize sum
      buffer(1) = 0L // initialize count
    }

    // Update the buffer with a new input row
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = buffer.getLong(0) + input.getLong(0) // update sum
        buffer(1) = buffer.getLong(1) + 1 // update count
      }
    }

    // Merge two aggregation buffers, storing the merged result back into buffer1
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    // Compute the final result
    def evaluate(buffer: Row): Double = {
      buffer.getLong(0).toDouble / buffer.getLong(1)
    }
  }

}
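
Besides being registered and invoked from SQL, a UserDefinedAggregateFunction can also be applied directly through the DataFrame API, since its apply method returns a Column. A minimal sketch, assuming the same df and MyAverage defined above:

// Call the UDAF through the DataFrame DSL instead of a SQL string
val dslResult = df.select(MyAverage(df("salary")).as("average_salary"))
dslResult.show()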

2. Type-safe UDAF

Extend the Aggregator abstract class. The example below computes the average of the salary field:

package com.company.sparksql

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

object UserDefinedTypedAggregation {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("UserDefinedTypedAggregation")
      .master("local")
      .getOrCreate()
    Logger.getLogger("org.apache.spark").setLevel(Level.OFF)
    Logger.getLogger("org.apache.hadoop").setLevel(Level.OFF)
    // Import implicit conversions (provides the Encoders used below)
    import spark.implicits._
    // Create a Dataset
    val employeeDS = spark.read.json("file:///E:/employees.json").as[Employee]
    employeeDS.show()

    // Convert the function to a `TypedColumn` so it can be used on a Dataset;
    // the column is aliased as average_salary
    val averageSalary = MyAverage.toColumn.name("average_salary")
    val res = employeeDS.select(averageSalary)
    res.show()

    spark.stop()
  }

  // Case class Employee; the corresponding JSON format is: {"name":"Michael", "salary":3000}
  case class Employee(name: String, salary: Long)

  case class Average(var sum: Long, var count: Long)

  /*
  * Employee: the input type of the aggregation
  * Average: the type of the intermediate buffer value
  * Double: the type of the final result
  */
  object MyAverage extends Aggregator[Employee, Average, Double] {
    // The zero value for this aggregation; it must satisfy b + zero = b for any b
    def zero: Average = Average(0L, 0L)

    /**
      * Combine two values to produce a new value.
      * For performance, the function may modify `buffer` and return it
      * instead of constructing a new object.
      *
      * @param buffer   the intermediate value
      * @param employee the new input value
      * @return the updated Average
      */
    def reduce(buffer: Average, employee: Employee): Average = {
      buffer.sum += employee.salary
      buffer.count += 1
      buffer
    }

    /**
      * Merge two intermediate values.
      *
      * @param buffer1 the first intermediate value
      * @param buffer2 the second intermediate value
      * @return the merged Average
      */
    def merge(buffer1: Average, buffer2: Average): Average = {
      buffer1.sum += buffer2.sum
      buffer1.count += buffer2.count
      buffer1
    }

    /**
      * Transform the final reduction into the output value.
      *
      * @param reduction the final merged value
      * @return the average as a Double
      */
    def finish(reduction: Average): Double = {
      reduction.sum.toDouble / reduction.count
    }

    // Encoder for the intermediate value type;
    // Encoders.product works for Scala product types (tuples, case classes, etc.)
    def bufferEncoder: Encoder[Average] = Encoders.product

    // Encoder for the final output value type;
    // Encoders.scalaDouble works for Scala's Double
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }

}
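
Because a typed Aggregator yields a TypedColumn, it also composes with groupByKey for per-group aggregation. A minimal sketch, reusing employeeDS and MyAverage from the example above (it assumes import spark.implicits._ is in scope, as in main; in this sample data every name appears once, so each group's average is simply that employee's salary):

// Per-group aggregation with the same typed aggregator
val perName = employeeDS
  .groupByKey(_.name)
  .agg(MyAverage.toColumn.name("average_salary"))
perName.show()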

III. The Dataset Used

{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}

Reference: http://spark.apache.org/docs/latest/sql-programming-guide.html


End of article

         
