// 自定义UDAF函数(多对一函数) — Custom UDAF (many-to-one aggregate function)

package SparkSQL.fun.registerfum

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row, SparkSession}

/**
 * 实现计算每个班级的平均年龄
 */
/**
 * Demo driver: computes the average age per class ("clz") with the
 * custom UDAF [[MyAvg]] registered under the SQL name "avg_clz_age".
 */
object registerfun2 {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("registfun2").setMaster("local[*]")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    // Needed for the implicit Encoder[student] used by createDataset.
    import spark.implicits._

    val students = Seq(
      student("zs", "c001", 21, "男"),
      student("ls", "c001", 22, "女"),
      student("ww", "c001", 23, "男"),
      student("ml", "c002", 20, "女"),
      student("zb", "c002", 23, "男")
    )
    val ds: Dataset[student] = spark.createDataset(students)
    ds.createOrReplaceTempView("student")

    // Register the aggregate under a SQL-callable name, then query it.
    spark.udf.register("avg_clz_age", new MyAvg())
    spark.sql("select clz, avg_clz_age(age) avg from student group by clz").show()

    spark.stop()
  }
}

// Spark SQL中自定义UDAF聚合函数  实现avg功能
/**
 * Custom Spark SQL UDAF implementing the avg aggregation (sum / count).
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0
 * in favor of org.apache.spark.sql.expressions.Aggregator; kept here so the
 * existing `session.udf.register(..., new MyAvg())` call keeps working.
 */
class MyAvg extends UserDefinedAggregateFunction {
  /**
   * Input schema: the number and types of the function's arguments.
   * A single Long column; Spark implicitly casts compatible numeric
   * inputs (e.g. the Int `age` column) to Long.
   */
  override def inputSchema: StructType = {
    StructType(Array(
      StructField("input", DataTypes.LongType)
    ))
  }

  /**
   * Buffer schema: the intermediate aggregation state —
   * the running sum and the count of non-null rows.
   */
  override def bufferSchema: StructType = {
    StructType(Array(
      StructField("sum", DataTypes.LongType),
      StructField("count", DataTypes.LongType)
    ))
  }

  /** Type of the final value produced by [[evaluate]]. */
  override def dataType: DataType = DataTypes.DoubleType

  /** Deterministic: the same input rows always yield the same result. */
  override def deterministic: Boolean = true

  /**
   * Initialize the aggregation buffer declared by [[bufferSchema]].
   * @param buffer mutable intermediate state
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // running sum
    buffer(1) = 0L // row count
  }

  /**
   * Fold one input row into the buffer.
   *
   * Skips NULL inputs so they contribute to neither sum nor count,
   * matching the semantics of the built-in avg. (The previous version
   * unboxed NULL to 0 but still incremented the count, skewing the result.)
   *
   * @param buffer mutable intermediate state
   * @param input  one input row holding the single Long argument
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  /**
   * Merge two partial buffers. Spark aggregates each partition
   * independently, then combines the partial results here.
   *
   * @param buffer1 mutable buffer receiving the merged state
   * @param buffer2 read-only partial state from another partition
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /**
   * Compute the final aggregate from the fully merged buffer.
   *
   * Returns null — like the built-in avg — when no non-null rows were
   * seen, instead of producing 0/0 = NaN.
   *
   * @param buffer final merged state
   * @return the average as a Double, or null for an empty/all-null group
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null
    else buffer.getLong(0).toDouble / count.toDouble
  }
}
// posted @ 2022-09-05 21:10  jsqup  阅读(41)  评论(0编辑  收藏  举报  (blog-scrape footer; not part of the source)