Spark SQL UDF和UDAF示例

Spark SQL UDF和UDAF

/**
  * scala代码
  */
package com.tom.spark.sql

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}

/**
  * UDF:User Defined Function, 用户自定义的函数,函数的输入是一条具体的数据记录,实现上讲就是普通的scala函数;
  * UDAF:User Defined Aggregation Function, 用户自定义的聚合函数,函数本身作用于数据集合,能够在聚合操作的基础上进行自定义操作;
  * 实质上讲,例如说UDF会被Spark SQL中的catalyst封装成为expression,最终会通过eval方法来计算输入的输入Row,此处的Row和DataFrame
  * 中的Row没有任何关系
  */
object SparkSQLUDFUDAF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[4]").setAppName("SparkSQLUDFUDAF")
    val sc = new SparkContext(conf)

    val sqlContext = new SQLContext(sc)

    //模拟实际使用的数据
    val bigData = Array("Spark", "Spark", "Hadoop", "spark", "Hadoop", "spark", "Hadoop", "Hadoop", "spark", "spark")

    /**
      * 基于提供的数据创建DataFrame
      */
    val bigDataRdd = sc.parallelize(bigData)
    val bigDataRDDRow = bigDataRdd.map(item => {Row(item)})

    val structType =  StructType(Array(
      new StructField("word", StringType, true)
    ))
    val bigDataDF = sqlContext.createDataFrame(bigDataRDDRow, structType)

    bigDataDF.registerTempTable("bigDataTable") //注册成为临时表

    /**
      * 通过SQLContext注册UDF,在Scala 2.10.x版本UDF函数最多可以接收22个输入参数
      */
    sqlContext.udf.register("computeLength", (input: String) => input.length)

    //直接在sql中使用udf,就像使用SQL自带的内部函数一样
    sqlContext.sql("select word, computeLength(word) as length from bigDataTable").show

    sqlContext.udf.register("wordcount", new MyUDAF)

    sqlContext.sql("select word, wordcount(word) as count,computeLength(word) as length " +
      "from bigDataTable group by word").show

//    while(true){}

  }
}

/**
  * 按照模板实现UDAF
  */
class MyUDAF extends UserDefinedAggregateFunction {
  /**
    * 该方法指定具体输入数据的类型
    * @return
    */
  override def inputSchema: StructType = StructType(Array(StructField("input", StringType, true)))

  /**
    * 在进行聚合操作的时候所要处理的数据的结果的类型
    * @return
    */
  override def bufferSchema: StructType = StructType(Array(StructField("count", IntegerType, true)))

  /**
    * 指定UDAF函数计算后返回的结果类型
    * @return
    */
  override def dataType: DataType = IntegerType

  /**
    * 确保一致性,一般都用true
    * @return
    */
  override def deterministic: Boolean = true

  /**
    * 在Aggregate之前每组数据的初始化结果
    * @param buffer
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0 }

  /**
    * 在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
    * 本地的聚合操作,相当于Hadoop MapReduce模型中的Combiner
    * @param buffer
    * @param input
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  /**
    * 最后在分布式节点进行Local Reduce完成后需要进行全局级别的Merge操作
    * @param buffer1
    * @param buffer2
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  /**
    * 返回UDAF最后的计算结果
    * @param buffer
    * @return
    */
  override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
posted @ 2017-10-18 17:47  柚子=_=  阅读(165)  评论(0编辑  收藏  举报