Spark SQL UDF 函数(四)
在 Spark 中,也支持 Hive 中的自定义函数。自定义函数大致可以分为三种:
- UDF(User-Defined Function):即最基本的自定义函数,一行输入对应一行输出,类似 to_char、to_date 等;
- UDAF(User-Defined Aggregation Function):用户自定义聚合函数,多行输入对应一行输出,类似在 group by 之后使用的 sum、avg 等;
- UDTF(User-Defined Table-Generating Function):用户自定义表生成函数,一行输入可以产生多行输出,有点像 stream 里面的 flatMap。
1. 初步使用 UDF 函数
scala> val df = spark.read.json("hdfs://hadoop1:9000/people.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]
// 注册使用,toUpper 为函数名称
scala> spark.udf.register("toUpper", (s: String) => s.toUpperCase)
res15: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))
scala> df.createOrReplaceTempView("people")
scala> spark.sql("select toUpper(name), age from people").show
+-----------------+----+
|UDF:toUpper(name)| age|
+-----------------+----+
| MICHAEL|null|
| ANDY| 30|
| JUSTIN| 19|
+-----------------+----+
2. 自定义 UDAF 聚合函数(注:UserDefinedAggregateFunction 自 Spark 3.0 起已标记为废弃,新代码推荐继承 Aggregator 并通过 functions.udaf 注册)
package top.midworld.spark1031.create_df
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}
// Case class for one record of people.txt: a name and an age (the age is parsed as a Double).
case class UserInfo(name: String, age: Double)
/** Demo driver: loads people.txt into a DataFrame and aggregates the ages with the [[MySum]] UDAF. */
object UDF1 {

  /**
   * Entry point.
   *
   * Each line of people.txt is split on ',' and the first two fields become
   * `UserInfo(name, age)`. Presumably the file is `name, age` per line — confirm.
   * A malformed line fails the job when the data is evaluated.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder
      .appName("udf")
      .master("local[2]")
      .getOrCreate()
    val context = session.sparkContext
    import session.implicits._

    // Parse each text line into a UserInfo record in a single pass.
    val users = context
      .textFile("hdfs://hadoop1:9000/people.txt")
      .map { line =>
        val fields = line.split(",")
        UserInfo(fields(0), fields(1).trim.toDouble)
      }

    val userDf = users.toDF()
    userDf.createOrReplaceTempView("user")

    // Make the UDAF callable from SQL under the name "mySum".
    session.udf.register("mySum", new MySum)
    session.sql("select mySum(age) as age_sum from user").show()
    userDf.show()

    context.stop()
    session.stop()
  }
}
/**
 * UDAF that sums a single Double column.
 *
 * The println calls trace when Spark invokes initialize / update / merge. The sample
 * traces in the comments below come from one run. The actual order presumably depends
 * on how the data is partitioned.
 */
class MySum extends UserDefinedAggregateFunction {

  // Input: one Double column (sample values 29 / 30 / 19).
  override def inputSchema: StructType = StructType(List(StructField("ele", DoubleType)))

  // Aggregation buffer: a single running sum.
  override def bufferSchema: StructType = StructType(List(StructField("sum", DoubleType)))

  // Type of the final aggregated result.
  override def dataType: DataType = DoubleType

  // Identical input always yields identical output.
  override def deterministic: Boolean = true

  // Start the running sum at zero.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    println(s"initialize===>$buffer") // initialize===>[null]
    buffer.update(0, 0D)
  }

  /*
   * Fold one input row into this partition's buffer. Sample trace:
   *   update===>[0.0]
   *   update===>[0.0]
   *   input===>[19.0]
   *   input===>[29.0]
   *   update===>[29.0]
   *   input===>[30.0]
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    println(s"update===>$buffer")
    println(s"input===>$input")
    input match {
      case Row(value: Double) => buffer.update(0, buffer.getDouble(0) + value)
      case _ => // anything that is not a Double (e.g. a SQL NULL) is skipped
    }
  }

  /*
   * Combine the partial sums of two partitions into buffer1. Sample trace:
   *   merge buffer1 ===>[0.0]
   *   merge buffer2 ===>[59.0]
   *   merge buffer1 ===>[59.0]
   *   merge buffer2 ===>[19.0]
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    println(s"merge buffer1 ===>$buffer1")
    println(s"merge buffer2 ===>$buffer2")
    val combined = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1.update(0, combined)
  }

  // The final result is simply the accumulated sum.
  override def evaluate(buffer: Row): Any = {
    val total = buffer.getDouble(0)
    total
  }
}
运行结果:
+-------+
|age_sum|
+-------+
| 78.0|
+-------+
+-------+----+
| name| age|
+-------+----+
|Michael|29.0|
| Andy|30.0|
| Justin|19.0|
+-------+----+
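求平均值:同样继承 UserDefinedAggregateFunction,缓冲区中同时保存累计和(sum)与计数(count),最后在 evaluate 中相除(需要额外导入 org.apache.spark.sql.types.LongType):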
/**
 * UDAF that averages a single Double column.
 *
 * The buffer holds a running sum and a row count, and the division happens only in
 * `evaluate`. Input values that are not Doubles (e.g. SQL NULL) are skipped and
 * change neither the sum nor the count. A group with no counted rows evaluates to
 * null, like SQL `avg`, rather than NaN.
 */
class MyAvg extends UserDefinedAggregateFunction {

  // Input: one Double column (sample values 29 / 30 / 19).
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)

  // Buffer layout: index 0 = running sum (Double), index 1 = rows counted (Long).
  override def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)

  // Type of the final aggregated result.
  override def dataType: DataType = DoubleType

  // Identical input always yields identical output.
  override def deterministic: Boolean = true

  // Start with an empty sum and a zero count.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0D
    buffer(1) = 0L
  }

  // Fold one input row into this partition's buffer.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = input match {
    case Row(age: Double) =>
      buffer(0) = buffer.getDouble(0) + age
      buffer(1) = buffer.getLong(1) + 1L
    case _ => // skip non-Double input such as SQL NULL
  }

  // Combine two partial buffers: add the sums and add the counts.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    // Must start from buffer1's own count; adding buffer2's count to itself loses buffer1's rows.
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Average = sum / count. If nothing was counted, return SQL NULL instead of 0.0 / 0 = NaN.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getDouble(0) / count
  }
}
3. 开窗函数(可参考以下文章):
https://blog.csdn.net/sunxiaoju/article/details/103800028
https://blog.csdn.net/liangzelei/article/details/80608302?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-4.no_search_link&spm=1001.2101.3001.4242.3