Spark SQL 中自定义 UDF、UDAF 函数统计 UV（使用 Bitmap）
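在实际工作中统计 UV 时，一般会使用 count(distinct userId) 的方式统计人数，但这种方式效率不高。假设要统计多个维度的数据，当某天想对维度进行上卷时，又需要从原始层重新开始统计，数据量大时会耗费很多时间。此时可以基于最细粒度的聚合结果进行上卷统计：自定义聚合函数，把每个分组内的用户 ID 放入一个 bitmap，并将 bitmap 序列化为字节数组保存；上卷时只需对这些 bitmap 做 OR（并集）运算再求基数，同一个用户即使出现在多个细粒度分组中也只会被计数一次（这也是 UV 不能像 PV 那样直接求和的原因）。注意：下文使用的 RoaringBitmap 只能存放 32 位整数，用户 ID 需要能用 int 表示。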
1)一次聚合
package org.shydow.UDF

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.roaringbitmap.RoaringBitmap

/**
 * Collects integer ids into a RoaringBitmap and emits the bitmap in serialized (byte array) form.
 *
 * Both the aggregation buffer and the final result are the bytes produced by
 * `BitmapUtil.serBitmap`, so the output can be stored in a BINARY column and decoded again
 * with `BitmapUtil.deSerBitmap`.
 *
 * Cost note: every `reduce` call decodes and re-encodes the whole buffer, so the per-row cost
 * grows with the number of distinct ids already collected in the group.
 *
 * @author shydow
 * @date 2021/12/13 22:55
 */
class BitmapGenUDAF extends Aggregator[Int, Array[Byte], Array[Byte]] {

  /** Initial buffer: an empty bitmap, already serialized. */
  override def zero: Array[Byte] = BitmapUtil.serBitmap(new RoaringBitmap())

  /** Adds a single id `a` to the serialized bitmap `b`. */
  override def reduce(b: Array[Byte], a: Int): Array[Byte] =
    transform(b)(bitmap => bitmap.add(a))

  /** Unions two partial (serialized) bitmaps; the result is built from `b1`. */
  override def merge(b1: Array[Byte], b2: Array[Byte]): Array[Byte] =
    transform(b1)(bitmap => bitmap.or(BitmapUtil.deSerBitmap(b2)))

  /** The buffer already is the final serialized bitmap. */
  override def finish(reduction: Array[Byte]): Array[Byte] = reduction

  override def bufferEncoder: Encoder[Array[Byte]] = Encoders.BINARY

  override def outputEncoder: Encoder[Array[Byte]] = Encoders.BINARY

  /** Decodes `bytes`, applies `change` to the bitmap in place, and re-encodes the result. */
  private def transform(bytes: Array[Byte])(change: RoaringBitmap => Unit): Array[Byte] = {
    val bitmap = BitmapUtil.deSerBitmap(bytes)
    change(bitmap)
    BitmapUtil.serBitmap(bitmap)
  }
}
package org.shydow.UDF

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import org.roaringbitmap.RoaringBitmap

/**
 * Converts RoaringBitmaps to and from the byte arrays that the bitmap UDAFs keep in Spark
 * BINARY columns.
 *
 * @author shydow
 * @date 2021/12/13 22:45
 */
object BitmapUtil {

  /**
   * Serializes a bitmap with `RoaringBitmap.serialize`.
   *
   * @param bm bitmap to serialize; it is not modified
   * @return the serialized bytes
   */
  def serBitmap(bm: RoaringBitmap): Array[Byte] = {
    // Size the buffer to the exact serialized length so it never has to grow and copy.
    val stream = new ByteArrayOutputStream(bm.serializedSizeInBytes())
    val dataOutput = new DataOutputStream(stream)
    bm.serialize(dataOutput)
    // Make sure everything written through the DataOutputStream reached the byte buffer.
    dataOutput.flush()
    stream.toByteArray
  }

  /**
   * Deserializes bytes produced by [[serBitmap]] into a new bitmap.
   *
   * @param bytes serialized bitmap; `null` (e.g. a SQL NULL handed to a UDF by Spark) is
   *              treated as an empty bitmap instead of failing with a NullPointerException
   * @return a freshly allocated RoaringBitmap
   * @note `RoaringBitmap.deserialize` may throw (e.g. `java.io.IOException`) when `bytes`
   *       is not a valid serialized bitmap.
   */
  def deSerBitmap(bytes: Array[Byte]): RoaringBitmap = {
    val bm = new RoaringBitmap()
    Option(bytes).foreach { raw =>
      bm.deserialize(new DataInputStream(new ByteArrayInputStream(raw)))
    }
    bm
  }
}
package org.shydow.UDF

import org.apache.spark.sql.{DataFrame, SparkSession, TypedColumn}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.roaringbitmap.RoaringBitmap

/**
 * Compares two ways of computing PV/UV per eventType over `data/OrderLog.csv`:
 * `count(distinct id)` and the bitmap-based `gen_bitmap` UDAF plus `get_card` UDF.
 *
 * @author shydow
 * @date 2021/12/13 22:25
 */
object TestBehaviorAnalysis {

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("analysis")
      .master("local[*]")
      .getOrCreate()
    // Release the session even if reading the CSV or one of the queries throws.
    try {
      spark.sparkContext.setLogLevel("WARN")

      val schema = StructType(Seq(
        StructField("id", LongType),
        StructField("eventType", StringType),
        StructField("code", StringType),
        StructField("timestamp", LongType)
      ))
      val frame: DataFrame = spark.read.schema(schema).csv("data/OrderLog.csv")
      frame.createOrReplaceTempView("order_log")

      // Baseline: UV via count(distinct id).
      spark.sql(
        s"""
           |select
           |  eventType,
           |  count(1) as pv,
           |  count(distinct id) as uv
           |from order_log
           |group by eventType
           |""".stripMargin).show()

      // UV via the custom bitmap UDAF.
      import org.apache.spark.sql.functions.udaf
      spark.udf.register("gen_bitmap", udaf(new BitmapGenUDAF))

      // gen_bitmap returns serialized bytes; get_card decodes them and returns the cardinality.
      def card(byteArray: Array[Byte]): Int = {
        val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(byteArray)
        bitmap.getCardinality
      }
      spark.udf.register("get_card", card _)

      // NOTE(review): gen_bitmap consumes Int (BitmapGenUDAF is an Aggregator[Int, ...]) while
      // `id` is declared LongType, so Spark has to convert bigint to int for this call; ids
      // outside the Int range would not be counted correctly. Confirm ids fit in an Int.
      spark.sql(
        s"""
           |select
           |  eventType,
           |  count(1) as pv,
           |  gen_bitmap(id) as uv_arr,
           |  get_card(gen_bitmap(id)) as uv
           |from order_log
           |group by eventType
           |""".stripMargin).show()
    } finally {
      spark.close()
    }
  }
}
2)上卷聚合
package org.shydow.UDF

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.roaringbitmap.RoaringBitmap

/**
 * Rolls up serialized bitmaps by OR-ing (unioning) them, so UV can be recomputed at a coarser
 * grain without going back to the raw data.
 *
 * Input, buffer and output are all bitmaps in the byte form produced by `BitmapUtil.serBitmap`.
 * A NULL input value is skipped, so it contributes no ids.
 *
 * @author shydow
 * @date 2021/12/14 8:36
 */
class BitmapOrMergeUDAF extends Aggregator[Array[Byte], Array[Byte], Array[Byte]] {

  /** Initial buffer: an empty bitmap, already serialized. */
  override def zero: Array[Byte] = BitmapUtil.serBitmap(new RoaringBitmap())

  /**
   * Unions one input bitmap into the buffer.
   * Spark hands a NULL column value to `reduce` as `null`; such rows are ignored instead of
   * failing with a NullPointerException during deserialization.
   */
  override def reduce(b: Array[Byte], a: Array[Byte]): Array[Byte] =
    Option(a).fold(b)(input => union(b, input))

  /** Unions two partial buffers. */
  override def merge(b1: Array[Byte], b2: Array[Byte]): Array[Byte] = union(b1, b2)

  /** The buffer already is the final serialized bitmap. */
  override def finish(reduction: Array[Byte]): Array[Byte] = reduction

  override def bufferEncoder: Encoder[Array[Byte]] = Encoders.BINARY

  override def outputEncoder: Encoder[Array[Byte]] = Encoders.BINARY

  /** Decodes both serialized bitmaps, ORs `right` into `left`, and re-encodes the result. */
  private def union(left: Array[Byte], right: Array[Byte]): Array[Byte] = {
    val acc: RoaringBitmap = BitmapUtil.deSerBitmap(left)
    acc.or(BitmapUtil.deSerBitmap(right))
    BitmapUtil.serBitmap(acc)
  }
}
package org.shydow.UDF

import org.apache.spark.sql.{DataFrame, SparkSession, TypedColumn}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.roaringbitmap.RoaringBitmap

/**
 * Demonstrates rolling up UV with bitmaps.
 *
 * First, PV and a serialized id bitmap are aggregated per (eventType, code). Then those bitmaps
 * are unioned per eventType with `bitmapOr`, instead of re-scanning `order_log`.
 *
 * @author shydow
 * @date 2021/12/13 22:25
 */
object TestBehaviorAnalysis {

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("analysis")
      .master("local[*]")
      .getOrCreate()
    // Release the session even if reading the CSV or one of the queries throws.
    try {
      spark.sparkContext.setLogLevel("WARN")

      val schema = StructType(Seq(
        StructField("id", LongType),
        StructField("eventType", StringType),
        StructField("code", StringType),
        StructField("timestamp", LongType)
      ))
      val frame: DataFrame = spark.read.schema(schema).csv("data/OrderLog.csv")
      frame.createOrReplaceTempView("order_log")

      // Baseline at the fine (eventType, code) grain: UV via count(distinct id).
      spark.sql(
        s"""
           |select
           |  eventType,
           |  code,
           |  count(1) as pv,
           |  count(distinct id) as uv
           |from order_log
           |where code is not null
           |group by eventType, code
           |""".stripMargin).show()

      // UV via the custom bitmap UDAF.
      import org.apache.spark.sql.functions.udaf
      spark.udf.register("gen_bitmap", udaf(new BitmapGenUDAF))

      // gen_bitmap returns serialized bytes; get_card decodes them and returns the cardinality.
      def card(byteArray: Array[Byte]): Int = {
        val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(byteArray)
        bitmap.getCardinality
      }
      spark.udf.register("get_card", card _)

      // NOTE(review): gen_bitmap consumes Int (BitmapGenUDAF is an Aggregator[Int, ...]) while
      // `id` is declared LongType, so Spark has to convert bigint to int for this call; ids
      // outside the Int range would not be counted correctly. Confirm ids fit in an Int.
      val res: DataFrame = spark.sql(
        s"""
           |select
           |  eventType,
           |  code,
           |  count(1) as pv,
           |  gen_bitmap(id) as uv_arr,
           |  get_card(gen_bitmap(id)) as uv
           |from order_log
           |where code is not null
           |group by eventType, code
           |""".stripMargin)
      // Same registration style as order_log: a reused session (getOrCreate) that already has
      // this view must not make the job fail.
      res.createOrReplaceTempView("dws_stat")

      spark.udf.register("bitmapOr", udaf(new BitmapOrMergeUDAF))

      // Roll up to eventType: pv is additive and can be summed. uv is not additive, so the
      // per-code bitmaps are unioned and the cardinality of the union is taken. This way a
      // user seen under several codes is counted once.
      spark.sql(
        s"""
           |select
           |  eventType,
           |  sum(pv) as total_pv,
           |  bitmapOr(uv_arr) as total_uv_arr,
           |  get_card(bitmapOr(uv_arr)) as total_uv
           |from dws_stat
           |group by eventType
           |""".stripMargin).show()
    } finally {
      spark.close()
    }
  }
}