spark SQL (二) 聚合

      聚合内置功能DataFrames提供共同聚合,例如count()countDistinct()avg()max()min(),等。虽然这些功能是专为DataFrames,spark SQL还拥有类型安全的版本,在其中的一些 scala 和 Java使用强类型数据集的工作。而且,用户可以预定义的聚合函数,也可以创建自己自定义的聚合函数。

1, 非类型化的用户定义的聚合函数

      用户必须扩展UserDefinedAggregateFunction 抽象类来实现自定义的非类型集合函数。例如,用户定义的平均值可能如下所示:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

object UserDefinedUntypedAggregation {

  object MyAverage extends UserDefinedAggregateFunction {
    // 这集合函数的输入参数的数据类型
    def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
    // 在聚合缓冲区中的值的数据类型
    def bufferSchema: StructType = {
      StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
    }
    // 返回值的数据类型
    def dataType: DataType = DoubleType
    // 此函数是否始终在相同的输入上返回相同的输出
    def deterministic: Boolean = true
    // 初始化给定的聚合缓冲区。缓冲区本身就是一个“Row”,除了
    // 像标准方法(例如,get(),getBoolean())检索值之外,还提供
    // 更新其值的机会。请注意,缓冲区内的数组和映射仍然是
    // 不可变的。
    def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 0L
    }
    // 更新给定聚合缓冲区`与来自新的输入数据buffer``input` 
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = buffer.getLong(0) + input.getLong(0)
        buffer(1) = buffer.getLong(1) + 1
      }
    }
    // 合并两个聚合缓冲剂和存储更新的缓冲器值回`buffer1` 
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }
    // 计算最终结果
    def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("Spark SQL user-defined DataFrames aggregation example")
      .getOrCreate()

    // 注册函数来访问
    spark.udf.register("myAverage", MyAverage)

    val df = spark.read.json("employees.json")
    df.createOrReplaceTempView("employees")
    df.show()
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
    result.show()
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
    spark.stop()
  }
}
2,类型安全的用户定义的聚合函数
       用于强类型数据集的用户定义聚合围绕着Aggregator抽象类。例如,类型安全的用户定义的平均值可能如下所示:

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

object UserDefinedTypedAggregation {

  case class Employee(name: String, salary: Long)
  case class Average(var sum: Long, var count: Long)

  object MyAverage extends Aggregator[Employee, Average, Double] {
    // 这个聚合的零值。应满足以下性质:b + zero = b 
    def zero: Average = Average(0L, 0L)
    //合并两个值产生一个新的值。为了性能,函数可以修改`buffer` 
   //并返回它,而不是构造一个新的对象
    def reduce(buffer: Average, employee: Employee): Average = {
      buffer.sum += employee.salary
      buffer.count += 1
      buffer
    }
    // 合并两个中间值
    def merge(b1: Average, b2: Average): Average = {
      b1.sum += b2.sum
      b1.count += b2.count
      b1
    }
    // 变换还原的输出
    def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
    // 指定中间值类型的
    def bufferEncoder: Encoder[Average] = Encoders.product
    // 指定最终输出值类型的
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }
  // $example off:typed_custom_aggregation$

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("Spark SQL user-defined Datasets aggregation example")
      .getOrCreate()

    import spark.implicits._

    val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
    ds.show()
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    //将函数转换为“TypedColumn”,并给它一个名称
    val averageSalary = MyAverage.toColumn.name("average_salary")
    val result = ds.select(averageSalary)
    result.show()
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
    spark.stop()
  }
}








posted @ 2017-12-23 16:48  zhou_jun  阅读(785)  评论(0编辑  收藏  举报