自定义UDAF2(多进一出函数)

package SparkSQL.fun.registerfum

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row, SparkSession}

/**
 * Demo entry point: registers [[MyMax]] as the SQL function `clz_age_max`
 * and uses it to compute the maximum `age` per `clz` over a small in-memory
 * data set.
 *
 * NOTE(review): the `student` case class is declared outside this block; it
 * is assumed to expose the `clz` and `age` columns referenced by the query.
 */
object registerfun3 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("registfun2").setMaster("local[*]")
    val session = SparkSession.builder().config(conf).getOrCreate()

    // Everything that can throw runs inside try/finally so the local session
    // is always stopped, even when the query or show() fails.
    try {
      import session.implicits._

      val dataset: Dataset[student] = session.createDataset(Seq(
        student("zs", "c001", 21, "男"),
        student("ls", "c001", 22, "女"),
        student("ww", "c001", 23, "男"),
        student("ml", "c002", 20, "女"),
        student("zb", "c002", 23, "男")
      ))
      dataset.createOrReplaceTempView("student")

      // Make the aggregate callable from SQL under the name `clz_age_max`.
      session.udf.register("clz_age_max", new MyMax())
      val frame = session.sql("select clz, clz_age_max(age) from student group by clz")
      frame.show()
    } finally {
      session.stop()
    }
  }
}

/**
 * User-defined aggregate function returning the maximum of a LONG column.
 *
 * The aggregation buffer holds a single nullable LONG. It starts out as NULL
 * rather than 0 so that groups containing only negative values still yield
 * their true maximum, and NULL inputs are skipped instead of being compared
 * as 0. When no non-null value was seen the result is NULL, mirroring the
 * behaviour of SQL `MAX`.
 */
class MyMax extends UserDefinedAggregateFunction {

  /** One input column, declared as LONG. */
  override def inputSchema: StructType =
    StructType(Array(
      StructField("input", DataTypes.LongType)
    ))

  /** One buffer slot: the largest value seen so far (NULL until a value arrives). */
  override def bufferSchema: StructType =
    StructType(Array(
      StructField("max", DataTypes.LongType)
    ))

  /** The aggregate produces a LONG. */
  override def dataType: DataType = DataTypes.LongType

  /** Same input always yields the same result. */
  override def deterministic: Boolean = true

  /** Resets the buffer to "no value seen yet". */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // NULL (not 0) so that negative-only groups are not clamped to 0.
    buffer(0) = null
  }

  /**
   * Folds one input row into the buffer.
   *
   * @param buffer running maximum for the current group
   * @param input  row whose column 0 is the value to consider; NULL is ignored
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val candidate = input.getLong(0)
      if (buffer.isNullAt(0) || candidate > buffer.getLong(0)) {
        buffer(0) = candidate
      }
    }
  }

  /**
   * Combines two partial results, storing the larger one in `buffer1`.
   *
   * @param buffer1 partial result that receives the merged maximum
   * @param buffer2 other partial result; left untouched, ignored when NULL
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (!buffer2.isNullAt(0)) {
      val other = buffer2.getLong(0)
      if (buffer1.isNullAt(0) || other > buffer1.getLong(0)) {
        buffer1(0) = other
      }
    }
  }

  /**
   * Produces the final result for a group.
   *
   * @return the maximum as a Long, or null when no non-null input was seen
   */
  override def evaluate(buffer: Row): Any =
    if (buffer.isNullAt(0)) null else buffer.getLong(0)
}
posted @ 2022-09-05 21:11  jsqup  阅读(36)  评论(0)  编辑  收藏  举报