Flink-UDF
Flink 的 Table API 和 SQL 提供了多种自定义函数的接口,以抽象类的形式定义。当前 UDF主要有以下几类:
- 标量函数(Scalar Functions):将输入的标量值转换成一个新的标量值;
- 表函数(Table Functions):将标量值转换成一个或多个新的行数据,也就是扩展成一个表;
- 聚合函数(Aggregate Functions):将多行数据里的标量值转换成一个新的标量值;
- 表聚合函数(Table Aggregate Functions):将多行数据里的标量值转换成一个或多个新的行数据。
1.整体调用流程
要想在代码中使用自定义的函数,我们需要首先自定义对应 UDF 抽象类的实现,并在表环境中注册这个函数,然后就可以在 Table API 和 SQL 中调用了。
(1)注册函数
注册函数时需要调用表环境的 createTemporarySystemFunction()方法,传入注册的函数名以及 UDF类的 Class 对象:
// 注册函数
tableEnv.createTemporarySystemFunction("MyFunction", classOf[MyFunction])
我们自定义的 UDF 类叫作 MyFunction,它应该是上面四种 UDF 抽象类中某一个的具体实现;在环境中将它注册为名叫 MyFunction 的函数。
这里 createTemporarySystemFunction()方法的意思是创建了一个“临时系统函数”,所以MyFunction 函 数 名 是 全 局 的 , 可 以 当 作 系 统 函 数 来 使 用 ; 我 们 也 可 以 用createTemporaryFunction()方法,注册的函数就依赖于当前的数据库(database)和目录(catalog)了,所以这就不是系统函数,而是“目录函数”(catalog function),它的完整名称应该包括所属的 database 和 catalog。
一般情况下,我们直接用 createTemporarySystemFunction()方法将 UDF 注册为系统函数就可以了。
(2)使用 Table API 调用函数
在 Table API 中,需要使用 call()方法来调用自定义函数:
tableEnv.from("MyTable").select(call("MyFunction", $("myField")))
这里 call()方法有两个参数,一个是注册好的函数名 MyFunction,另一个则是函数调用时本身的参数。这里我们定义 MyFunction 在调用时,需要传入的参数是 myField 字段。
此外,在 Table API 中也可以不注册函数,直接用“内联”(inline)的方式调用 UDF:
tableEnv.from("MyTable").select(call(classOf[SubstringFunction],$("myField"))
区别只是在于 call()方法第一个参数不再是注册好的函数名,而直接就是函数类的 Class对象了。
(3)在 SQL 中调用函数
当我们将函数注册为系统函数之后,在 SQL 中的调用就与内置系统函数完全一样了:
tableEnv.sqlQuery("SELECT MyFunction(myField) FROM MyTable")
可见,SQL 的调用方式更加方便,我们后续依然会以 SQL 为例介绍 UDF 的用法。
2.标量函数(Scalar Functions)
一对一。
自定义标量函数可以把 0 个、 1 个或多个标量值转换成一个标量值,它对应的输入是一行数据中的字段,输出则是唯一的值。所以从输入和输出表中行数据的对应关系看,标量函数是“一对一”的转换。
想要实现自定义的标量函数,我们需要自定义一个类来继承抽象类 ScalarFunction,并实现叫作 eval() 的求值方法。标量函数的行为就取决于求值方法的定义,它必须是公有的(public),而且名字必须是 eval。求值方法 eval()可以重载多次,任何数据类型都可作为求值方法的参数和返回值类型。
这里需要特别说明的是,ScalarFunction 抽象类中并没有定义 eval()方法,所以我们不能直接在代码中重写(override);但 Table API 的框架底层又要求了求值方法必须名字为 eval()。这是 Table API 和 SQL 目前还显得不够完善的地方,未来的版本应该会有所改进。
ScalarFunction 以及其它所有的 UDF 接口,都在 org.apache.flink.table.functions 中。
下面我们来看一个具体的例子。我们实现一个自定义的哈希(hash)函数 HashFunction,用来求传入对象的哈希值。
package com.zhen.flink.table import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.table.api.bridge.scala.StreamTableEnvironment import org.apache.flink.table.functions.ScalarFunction /** * @Author FengZhen * @Date 10/17/22 3:52 PM * @Description TODO */ object UdfTest_ScalarFunction { def main(args: Array[String]): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment env.setParallelism(1) // 创建表环境 val tableEnv = StreamTableEnvironment.create(env) // 在创建表的DDL中指定时间属性字段 tableEnv.executeSql("CREATE TABLE eventTable (" + " uid STRING," + " url STRING," + " ts BIGINT," + " et AS TO_TIMESTAMP( FROM_UNIXTIME(ts/1000))," + " WATERMARK FOR et AS et - INTERVAL '3' SECOND " + ") WITH (" + " 'connector' = 'filesystem'," + " 'path' = '/Users/FengZhen/Desktop/accumulate/0_project/flink_learn/src/main/resources/data/input/clicks.txt', " + " 'format' = 'csv' " + ")") // 2.注册标量函数 tableEnv.createTemporarySystemFunction("MyHash", classOf[MyHash]) // 3.调用函数进行查询转换 val resultTable = tableEnv.sqlQuery("SELECT uid, MyHash(uid) FROM eventTable") /** * 4.得到的结果表打印输出 * +I[Mary, 2390779] * +I[Bob, 66965] * +I[Alice, 63350368] * +I[Mary, 2390779] * +I[Bob, 66965] */ tableEnv.toDataStream(resultTable).print() env.execute() } //实现自定义标量函数,自定义哈希函数 class MyHash extends ScalarFunction{ def eval(str: String): Int = { str.hashCode } } }
3.表函数(Table Functions)
一对多,字段扩展成表。
跟标量函数一样,表函数的输入参数也可以是 0 个、1 个或多个标量值;不同的是,它可以返回任意多行数据。“多行数据”事实上就构成了一个表,所以“表函数”可以认为就是返回一个表的函数,这是一个“一对多”的转换关系。之前我们介绍过的窗口 TVF,本质上就是表函数。
类似地,要实现自定义的表函数,需要自定义类来继承抽象类 TableFunction,内部必须要实现的也是一个名为 eval 的求值方法。与标量函数不同的是,TableFunction 类本身是有一个泛型参数T 的,这就是表函数返回数据的类型;而 eval()方法没有返回类型,内部也没有 return语句,是通过调用 collect()方法来发送想要输出的行数据的。多么熟悉的感觉——回忆一下DataStream API 中的 FlatMapFunction 和 ProcessFunction,它们的 flatMap 和 processElement 方法也没有返回值,也是通过 out.collect()来向下游发送数据的。
我们使用表函数,可以对一行数据得到一个表,这和 Hive 中的 UDTF 非常相似。那对于原先输入的整张表来说,又该得到什么呢?一个简单的想法是,就让输入表中的每一行,与它转换得到的表进行联结(join),然后再拼成一个完整的大表,这就相当于对原来的表进行了扩展。在 Hive 的 SQL 语法中,提供了“侧向视图”(lateral view,也叫横向视图)的功能,可以将表中的一行数据拆分成多行;Flink SQL 也有类似的功能,是用 LATERAL TABLE 语法来实现的。
在 SQL 中调用表函数,需要使用 LATERAL TABLE(<TableFunction>)来生成扩展的“侧向表”,然后与原始表进行联结(join)。这里的 join 操作可以是直接做交叉联结(cross join),在FROM 后用逗号分隔两个表就可以;也可以是以 ON TRUE 为条件的左联结(LEFT JOIN)。
下面是表函数的一个具体示例。我们实现了一个分隔字符串的函数 SplitFunction,可以将一个字符串转换成(字符串,长度)的二元组。
package com.zhen.flink.table import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.table.annotation.{DataTypeHint, FunctionHint} import org.apache.flink.table.api.bridge.scala.StreamTableEnvironment import org.apache.flink.table.functions.TableFunction import org.apache.flink.types.Row /** * @Author FengZhen * @Date 10/17/22 4:07 PM * @Description TODO */ object UdfTest_TableFunction { def main(args: Array[String]): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment env.setParallelism(1) // 创建表环境 val tableEnv = StreamTableEnvironment.create(env) // 在创建表的DDL中指定时间属性字段 tableEnv.executeSql("CREATE TABLE eventTable (" + " uid STRING," + " url STRING," + " ts BIGINT," + " et AS TO_TIMESTAMP( FROM_UNIXTIME(ts/1000))," + " WATERMARK FOR et AS et - INTERVAL '3' SECOND " + ") WITH (" + " 'connector' = 'filesystem'," + " 'path' = '/Users/FengZhen/Desktop/accumulate/0_project/flink_learn/src/main/resources/data/input/clicks.txt', " + " 'format' = 'csv' " + ")") // 2.注册表函数 tableEnv.createTemporarySystemFunction("MySplit", classOf[MySplit]) // 3.调用函数进行查询转换 val resultTable = tableEnv.sqlQuery( """ |SELECT | uid, url, word, len |FROM eventTable, LATERAL TABLE(MySplit(url)) AS T(word, len) |""".stripMargin) /** * 4.得到的结果表打印输出 * +I[Mary, ./home, ./home, 7] * +I[Bob, ./cart, ./cart, 7] * +I[Alice, ./prod?id=1, ./prod, 7] * +I[Alice, ./prod?id=1, id=1, 4] * +I[Mary, ./prod?id=2, ./prod, 7] * +I[Mary, ./prod?id=2, id=2, 4] * +I[Bob, ./prod?id=3, ./prod, 7] * +I[Bob, ./prod?id=3, id=3, 4] */ tableEnv.toDataStream(resultTable).print() env.execute() } // 实现自定义表函数,按照问号分隔URL字段 // 注意这里的类型标注,输出是 Row 类型,Row 中包含两个字段:word 和 length。 @FunctionHint(output = new DataTypeHint("ROW<word STRING, length INT>")) class MySplit extends TableFunction[Row] { def eval(str: String){ str.split("\\?").foreach( s => collect(Row.of(s, Int.box(s.length)))) } } }
4.聚合函数(Aggregate Functions)
多对一。
用户自定义聚合函数(User Defined AGGregate function,UDAGG)会把一行或多行数据(也就是一个表)聚合成一个标量值。这是一个标准的“多对一”的转换。
聚合函数的概念我们之前已经接触过多次,如 SUM()、MAX()、MIN()、AVG()、COUNT()都是常见的系统内置聚合函数。而如果有些需求无法直接调用系统函数解决,我们就必须自定义聚合函数来实现功能了。
自定义聚合函数需要继承抽象类 AggregateFunction。AggregateFunction 有两个泛型参数<T, ACC>,T 表示聚合输出的结果类型,ACC 则表示聚合的中间状态类型。Flink SQL 中的聚合函数的工作原理如下:
(1)首先,它需要创建一个累加器(accumulator),用来存储聚合的中间结果。这与DataStream API 中的 AggregateFunction 非常类似,累加器就可以看作是一个聚合状态。调用createAccumulator()方法可以创建一个空的累加器。
(2)对于输入的每一行数据,都会调用 accumulate()方法来更新累加器,这是聚合的核心过程。
(3)当所有的数据都处理完之后,通过调用 getValue()方法来计算并返回最终的结果。所以,每个 AggregateFunction 都必须实现以下几个方法:
- createAccumulator()
这是创建累加器的方法。没有输入参数,返回类型为累加器类型 ACC。
- accumulate()
这是进行聚合计算的核心方法,每来一行数据都会调用。它的第一个参数是确定的,就是当前的累加器,类型为 ACC,表示当前聚合的中间状态;后面的参数则是聚合函数调用时传入的参数,可以有多个,类型也可以不同。这个方法主要是更新聚合状态,所以没有返回类型。
需要注意的是,accumulate()与之前的求值方法 eval()类似,也是底层架构要求的,必须为 public,方法名必须为 accumulate,且无法直接 override、只能手动实现。
- getValue()
这是得到最终返回结果的方法。输入参数是 ACC 类型的累加器,输出类型为 T。
在遇到复杂类型时,Flink 的类型推导可能会无法得到正确的结果。所以AggregateFunction也可以专门对累加器和返回结果的类型进行声明,这是通过 getAccumulatorType()和getResultType()两个方法来指定的。
除了上面的方法,还有几个方法是可选的。这些方法有些可以让查询更加高效,有些是在某些特定场景下必须要实现的。比如,如果是对会话窗口进行聚合,merge()方法就是必须要实现的,它会定义累加器的合并操作,而且这个方法对一些场景的优化也很有用;而如果聚合函数用在 OVER 窗口聚合中,就必须实现 retract()方法,保证数据可以进行撤回操作;resetAccumulator()方法则是重置累加器,这在一些批处理场景中会比较有用。
AggregateFunction 的所有方法都必须是 公有的(public),不能是静态的(static),而且名字必须跟上面写的完全一样。 createAccumulator 、 getValue 、 getResultType 以 及getAccumulatorType 这几个方法是在抽象类 AggregateFunction 中定义的,可以 override;而其他则都是底层架构约定的方法。
下面举一个具体的示例。在常用的系统内置聚合函数里,可以用 AVG()来计算平均值;如果我们现在希望计算的是某个字段的“加权平均值”,又该怎么做呢?系统函数里没有现成的实现,所以只能自定义一个聚合函数 WeightedAvg 来计算了。
比如我们要从学生的分数表 ScoreTable 中计算每个学生的加权平均分。为了计算加权平均值,应该从输入的每行数据中提取两个值作为参数:要计算的分数值 score,以及它的权重weight。而在聚合过程中,累加器(accumulator)需要存储当前的加权总和 sum,以及目前数据 的 个 数 count 。这可以用一个二元组来表示,也可 以 单 独 定 义 一 个 类WeightedAvgAccumulator,里面包含 sum 和 count 两个属性,用它的对象实例来作为聚合的累加器。
package com.zhen.flink.table import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.table.api.bridge.scala.StreamTableEnvironment import org.apache.flink.table.functions.AggregateFunction /** * @Author FengZhen * @Date 10/17/22 4:28 PM * @Description TODO */ object UdfTest_AggregateFunction { def main(args: Array[String]): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment env.setParallelism(1) // 创建表环境 val tableEnv = StreamTableEnvironment.create(env) // 在创建表的DDL中指定时间属性字段 tableEnv.executeSql("CREATE TABLE eventTable (" + " uid STRING," + " url STRING," + " ts BIGINT," + " et AS TO_TIMESTAMP( FROM_UNIXTIME(ts/1000))," + " WATERMARK FOR et AS et - INTERVAL '3' SECOND " + ") WITH (" + " 'connector' = 'filesystem'," + " 'path' = '/Users/FengZhen/Desktop/accumulate/0_project/flink_learn/src/main/resources/data/input/clicks.txt', " + " 'format' = 'csv' " + ")") // 2.注册聚合函数 tableEnv.createTemporarySystemFunction("WeightedAvg", classOf[WeightedAvg]) // 3.调用函数进行查询转换 val resultTable = tableEnv.sqlQuery( """ |SELECT | uid, WeightedAvg(ts, 1) AS avg_ts |FROM eventTable |GROUP BY uid |""".stripMargin) /** * 4.得到的结果表打印输出 * +I[Mary, 1000] * +I[Bob, 2000] * +I[Alice, 3000] * -U[Mary, 1000] * +U[Mary, 2500] * -U[Bob, 2000] * +U[Bob, 3500] */ tableEnv.toChangelogStream(resultTable).print() env.execute() } // 单独定义样例类,用来表示就和过程中累加器的类型 case class WeightedAvgAccumulator(var sum: java.lang.Long = 0L, var count: Int = 0){} // 实现自定义的聚合函数,计算加权平均数 class WeightedAvg extends AggregateFunction[java.lang.Long, WeightedAvgAccumulator] { override def getValue(accumulator: WeightedAvgAccumulator): java.lang.Long = { if (accumulator.count == 0){ null } else{ accumulator.sum / accumulator.count } } override def createAccumulator(): WeightedAvgAccumulator = { WeightedAvgAccumulator() } // 每来一条数据,都会调用 def accumulate(accumulator: WeightedAvgAccumulator, iValue: java.lang.Long, iWeight: Int): Unit ={ accumulator.sum = accumulator.sum + (iValue * iWeight) accumulator.count = accumulator.count+ iWeight } } }
5.表聚合函数(Table Aggregate Functions)
多对多,多条数据聚合后生成表。
用户自定义表聚合函数(UDTAGG)可以把一行或多行数据(也就是一个表)聚合成另一张表,结果表中可以有多行多列。很明显,这就像表函数和聚合函数的结合体,是一个“多对多”的转换。
自定义表聚合函数需要继承抽象类 TableAggregateFunction。TableAggregateFunction 的结
构和原理与 AggregateFunction 非常类似,同样有两个泛型参数<T, ACC>,用一个 ACC 类型的累加器(accumulator)来存储聚合的中间结果。聚合函数中必须实现的三个方法,在TableAggregateFunction 中也必须对应实现:
- createAccumulator()
创建累加器的方法,与 AggregateFunction 中用法相同。
- accumulate()
聚合计算的核心方法,与 AggregateFunction 中用法相同。
- emitValue()
所有输入行处理完成后,输出最终计算结果的方法。这个方法对应着 AggregateFunction中的 getValue()方法;区别在于 emitValue 没有输出类型,而输入参数有两个:第一个是 ACC类型的累加器,第二个则是用于输出数据的“收集器”out,它的类型为 Collect<T>。所以很明显,表聚合函数输出数据不是直接 return,而是调用 out.collect()方法,调用多次就可以输出多行数据了;这一点与表函数非常相似。另外,emitValue()在抽象类中也没有定义,无法 override,必须手动实现。
表聚合函数得到的是一张表;在流处理中做持续查询,应该每次都会把这个表重新计算输出。如果输入一条数据后,只是对结果表里一行或几行进行了更新(Update),这时我们重新计算整个表、全部输出显然就不够高效了。为了提高处理效率,TableAggregateFunction 还提供了一个 emitUpdateWithRetract()方法,它可以在结果表发生变化时,以“撤回”(retract)老数 据 、 发 送 新 数 据 的 方 式 增 量 地 进 行 更 新 。 如 果 同 时 定 义 了 emitValue() 和emitUpdateWithRetract()两个方法,在进行更新操作时会优先调用 emitUpdateWithRetract()。
表聚合函数相对比较复杂,它的一个典型应用场景就是 Top N 查询。比如我们希望选出一组数据排序后的前两名,这就是最简单的 TOP-2 查询。没有现成的系统函数,那么我们就可以自定义一个表聚合函数来实现这个功能。在累加器中应该能够保存当前最大的两个值,每当来一条新数据就在 accumulate()方法中进行比较更新,最终在 emitValue()中调用两次out.collect()将前两名数据输出。
package com.zhen.flink.table import java.sql.Timestamp import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.table.api.bridge.scala.StreamTableEnvironment import org.apache.flink.table.functions.TableAggregateFunction import org.apache.flink.util.Collector import org.apache.flink.table.api.Expressions.{$, call} /** * @Author FengZhen * @Date 10/18/22 9:59 PM * @Description TODO */ object UdfTest_TableAggFunction { def main(args: Array[String]): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment env.setParallelism(1) // 创建表环境 val tableEnv = StreamTableEnvironment.create(env) // 在创建表的DDL中指定时间属性字段 tableEnv.executeSql("CREATE TABLE eventTable (" + " uid STRING," + " url STRING," + " ts BIGINT," + " et AS TO_TIMESTAMP( FROM_UNIXTIME(ts/1000))," + " WATERMARK FOR et AS et - INTERVAL '3' SECOND " + ") WITH (" + " 'connector' = 'filesystem'," + " 'path' = '/Users/FengZhen/Desktop/accumulate/0_project/flink_learn/src/main/resources/data/input/clicks.txt', " + " 'format' = 'csv' " + ")") // 2.注册表聚合函数 tableEnv.createTemporarySystemFunction("Top2", classOf[Top2]) // 3.调用函数进行查询转换 // 首先进行窗口聚合得到count值,统计每个用户的访问量 val urlCountWindowTable = tableEnv.sqlQuery( """ |SELECT uid, COUNT(url) AS cnt, window_start AS w_start, window_end AS w_end |FROM TABLE( | TUMBLE(TABLE eventTable, DESCRIPTOR(et), INTERVAL '1' HOUR) |) |GROUP BY uid, window_start, window_end | |""".stripMargin) tableEnv.createTemporaryView("urlCountWindowTable", urlCountWindowTable) // 使用Table API调用表聚合函数 val resultTable = urlCountWindowTable.groupBy($("w_end")) .flatAggregate(call("Top2", ${"uid"},${"cnt"},${"w_start"},${"w_end"})) .select(${"uid"}, ${"rank"}, ${"cnt"},${"w_end"}) // val resultTable = tableEnv.sqlQuery( // """ // |SELECT // | Top2(uid, window_start, window_end) // |FROM urlCountWindowTable // |GROUP BY uid // |""".stripMargin) /** * 4.得到的结果表打印输出 * +I[Mary, 1, 2, 1970-01-01T09:00] * -D[Mary, 1, 2, 1970-01-01T09:00] * +I[Mary, 1, 2, 1970-01-01T09:00] * +I[Alice, 2, 1, 1970-01-01T09:00] * -D[Mary, 1, 2, 1970-01-01T09:00] * -D[Alice, 2, 1, 1970-01-01T09:00] * +I[Mary, 1, 2, 1970-01-01T09:00] * +I[Bob, 2, 2, 1970-01-01T09:00] */ tableEnv.toChangelogStream(resultTable).print() env.execute() } // 定义输出结果和中间累加器的样例类 case class Top2Result(uid: String, window_start: Timestamp, window_end: Timestamp, cnt: Long, rank: Int) case class Top2Accumulator(var maxCount: Long, var secondMaxCount: Long, var uid1: String, var uid2: String, var window_start: Timestamp, var window_end: Timestamp) // 实现自定义的表聚合函数 class Top2 extends TableAggregateFunction[Top2Result, Top2Accumulator] { override def createAccumulator(): Top2Accumulator = { Top2Accumulator(Long.MinValue, Long.MinValue, null, null, null, null) } // 每来一行数据,需要使用accumulate进行聚合统计 def accumulate(acc: Top2Accumulator, uid: String, cnt: Long, window_start: Timestamp, window_end: Timestamp): Unit ={ acc.window_start = window_start acc.window_end = window_end // 判断当前count值是否排名前两位 if(cnt > acc.maxCount){ // 名次向后顺延 acc.secondMaxCount = acc.maxCount acc.uid2 = acc.uid1 acc.maxCount = cnt acc.uid1 = uid }else if(cnt > acc.secondMaxCount){ acc.secondMaxCount = cnt acc.uid2 = uid } } // 输出结果数据 def emitValue(acc: Top2Accumulator, out: Collector[Top2Result]): Unit ={ // 判断cnt值是否为初始值,如果没有更新则直接跳过不输出 if (acc.maxCount != Long.MinValue){ out.collect(Top2Result(acc.uid1, acc.window_start, acc.window_end,acc.maxCount, 1)) } if (acc.secondMaxCount != Long.MinValue){ out.collect(Top2Result(acc.uid2, acc.window_start, acc.window_end,acc.secondMaxCount, 2)) } } } }
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 记一次.NET内存居高不下排查解决与启示