Spark SQL 中自定义 UDF、UDAF 函数统计 UV（使用 Bitmap）

在实际工作中统计 UV 时，一般会使用 count(distinct userId) 的方式统计人数，但这样效率不高。假设需要统计多个维度的数据，某天想要上卷维度时，又得从原始层重新开始统计，数据量大时会耗费很多时间。此时可以基于最细粒度的聚合结果进行上卷统计：自定义聚合函数，把每个分组内的用户 ID 放入一个 bitmap，并将 bitmap 序列化为字节数组保存；上卷时把多个字节数组反序列化后做 OR（并集）合并，再取基数，即得到上卷后的 UV。注意 RoaringBitmap 只能存放 32 位整数 ID。

1)一次聚合

package org.shydow.UDF

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.roaringbitmap.RoaringBitmap
/**
 * Aggregator that collects integer ids into a RoaringBitmap and emits the bitmap serialized
 * as bytes (the format produced by [[BitmapUtil.serBitmap]]).
 *
 * Keeping the serialized bitmap, rather than just its cardinality, lets the result be stored
 * at the finest grain and later OR-merged to roll up to coarser dimensions.
 *
 * NOTE(review): RoaringBitmap holds 32-bit ints only and this aggregator takes `Int`. When it
 * is applied to a LongType column in SQL, the value presumably gets narrowed to Int first, so
 * ids outside the Int range would not be counted correctly — confirm ids fit in 32 bits.
 *
 * @author shydow
 * @date 2021/12/13 22:55
 */

class BitmapGenUDAF extends Aggregator[Int, Array[Byte], Array[Byte]] {

  /** Serialized empty bitmap: the identity element for `reduce` and `merge`. */
  override def zero: Array[Byte] = BitmapUtil.serBitmap(new RoaringBitmap())

  /**
   * Adds one id to the buffer.
   *
   * The buffer is kept as bytes, so it has to be deserialized on every call. When the id is
   * already present, the bitmap is unchanged and the existing bytes are returned instead of
   * paying for another serialization.
   */
  override def reduce(b: Array[Byte], a: Int): Array[Byte] = {
    val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(b)
    if (bitmap.checkedAdd(a)) BitmapUtil.serBitmap(bitmap) else b
  }

  /** Unions two partial buffers (bitwise OR of the bitmaps). */
  override def merge(b1: Array[Byte], b2: Array[Byte]): Array[Byte] = {
    val union: RoaringBitmap = BitmapUtil.deSerBitmap(b1)
    union.or(BitmapUtil.deSerBitmap(b2))
    BitmapUtil.serBitmap(union)
  }

  /** The serialized bitmap is already the final result. */
  override def finish(reduction: Array[Byte]): Array[Byte] = reduction

  override def bufferEncoder: Encoder[Array[Byte]] = Encoders.BINARY

  override def outputEncoder: Encoder[Array[Byte]] = Encoders.BINARY
}
package org.shydow.UDF

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import org.roaringbitmap.RoaringBitmap

/**
 * Helpers to convert RoaringBitmaps to and from byte arrays, so bitmaps can be carried in
 * Spark BINARY columns and aggregation buffers.
 *
 * @author shydow
 * @date 2021/12/13 22:45
 */
object BitmapUtil {

  /**
   * Serializes a bitmap using `RoaringBitmap.serialize(DataOutput)`.
   *
   * @param bm bitmap to serialize
   * @return the serialized bytes, readable by [[deSerBitmap]]
   */
  def serBitmap(bm: RoaringBitmap): Array[Byte] = {
    // Pre-size the buffer to the exact serialized size so it never has to grow and copy.
    val stream = new ByteArrayOutputStream(bm.serializedSizeInBytes())
    val dataOutput = new DataOutputStream(stream)
    bm.serialize(dataOutput)
    dataOutput.flush()
    stream.toByteArray
  }

  /**
   * Rebuilds a bitmap from bytes produced by [[serBitmap]].
   *
   * Any exception raised by `RoaringBitmap.deserialize` (for example on malformed or
   * truncated input) propagates to the caller. A null `bytes` is not handled here.
   *
   * @param bytes serialized bitmap
   * @return a new bitmap holding the deserialized values
   */
  def deSerBitmap(bytes: Array[Byte]): RoaringBitmap = {
    val bm = new RoaringBitmap()
    bm.deserialize(new DataInputStream(new ByteArrayInputStream(bytes)))
    bm
  }
}
package org.shydow.UDF

import org.apache.spark.sql.{DataFrame, SparkSession, TypedColumn}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.roaringbitmap.RoaringBitmap

/**
 * Demo: computes PV and UV per eventType from data/OrderLog.csv, first with
 * count(distinct id) and then with the bitmap-based gen_bitmap / get_card functions.
 *
 * @author shydow
 * @date 2021/12/13 22:25
 */

object TestBehaviorAnalysis {
  def main(args: Array[String]): Unit = {

    val spark: SparkSession = SparkSession.builder()
      .appName("analysis")
      .master("local[*]")
      .getOrCreate()

    // Release the session even if reading the CSV or one of the queries throws.
    try {
      spark.sparkContext.setLogLevel("WARN")
      import spark.implicits._

      val schema = StructType(Seq(
        StructField("id", LongType),
        StructField("eventType", StringType),
        StructField("code", StringType),
        StructField("timestamp", LongType))
      )
      val frame: DataFrame = spark.read.schema(schema).csv("data/OrderLog.csv")
      frame.createOrReplaceTempView("order_log")

      /**
       * Baseline: UV via count(distinct).
       */
      spark.sql(
        s"""
           |select
           |  eventType,
           |  count(1) as pv,
           |  count(distinct id) as uv
           |from order_log
           |group by eventType
           |""".stripMargin).show()

      /**
       * UV via the custom bitmap UDAF.
       */
      import org.apache.spark.sql.functions.udaf
      // gen_bitmap returns the serialized bitmap (bytes), not a count; get_card below turns it
      // into the cardinality.
      // NOTE(review): BitmapGenUDAF takes Int while `id` is LongType, so ids are presumably
      // narrowed to Int before being added — confirm all ids fit in 32 bits, otherwise uv can
      // disagree with count(distinct id).
      spark.udf.register("gen_bitmap", udaf(new BitmapGenUDAF))

      /** Number of distinct ids stored in a serialized bitmap produced by gen_bitmap. */
      def card(byteArray: Array[Byte]): Int = {
        val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(byteArray)
        bitmap.getCardinality
      }
      spark.udf.register("get_card", card _)

      spark.sql(
        s"""
           |select
           |  eventType,
           |  count(1) as pv,
           |  gen_bitmap(id) as uv_arr,
           |  get_card(gen_bitmap(id)) as uv
           |from order_log
           |group by eventType
           |""".stripMargin).show()
    } finally {
      spark.stop()
    }
  }
}

 

2)上卷聚合

package org.shydow.UDF

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.roaringbitmap.RoaringBitmap

/**
 * Aggregator that OR-merges serialized RoaringBitmaps. It rolls per-group bitmaps (for example
 * the bytes produced by BitmapGenUDAF) up to a coarser grouping.
 *
 * Input, buffer and output all use the byte format of [[BitmapUtil.serBitmap]].
 *
 * @author shydow
 * @date 2021/12/14 8:36
 */

class BitmapOrMergeUDAF extends Aggregator[Array[Byte], Array[Byte], Array[Byte]] {

  /** Serialized empty bitmap: the identity element for the union. */
  override def zero: Array[Byte] = BitmapUtil.serBitmap(new RoaringBitmap())

  /**
   * Folds one input bitmap into the buffer.
   *
   * Folding an input is the same union as combining two partial buffers, so this delegates to
   * `merge`. A null input (a NULL cell in the aggregated column) is skipped, the way built-in
   * SQL aggregates ignore NULLs, instead of failing with an NPE inside deserialization.
   */
  override def reduce(b: Array[Byte], a: Array[Byte]): Array[Byte] =
    Option(a).fold(b)(merge(b, _))

  /** Unions two serialized bitmaps (bitwise OR) and returns the serialized result. */
  override def merge(b1: Array[Byte], b2: Array[Byte]): Array[Byte] = {
    val union: RoaringBitmap = BitmapUtil.deSerBitmap(b1)
    union.or(BitmapUtil.deSerBitmap(b2))
    BitmapUtil.serBitmap(union)
  }

  /** The serialized bitmap is already the final result. */
  override def finish(reduction: Array[Byte]): Array[Byte] = reduction

  override def bufferEncoder: Encoder[Array[Byte]] = Encoders.BINARY

  override def outputEncoder: Encoder[Array[Byte]] = Encoders.BINARY
}
package org.shydow.UDF

import org.apache.spark.sql.{DataFrame, SparkSession, TypedColumn}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.roaringbitmap.RoaringBitmap

/**
 * Demo of rolling up UV with bitmaps. It first aggregates per (eventType, code) into serialized
 * bitmaps (the dws_stat view), then rolls up to eventType by OR-merging those bitmaps with
 * bitmapOr and taking their cardinality. The count(distinct) query is the baseline.
 *
 * @author shydow
 * @date 2021/12/13 22:25
 */

object TestBehaviorAnalysis {
  def main(args: Array[String]): Unit = {

    val spark: SparkSession = SparkSession.builder()
      .appName("analysis")
      .master("local[*]")
      .getOrCreate()

    // Release the session even if reading the CSV or one of the queries throws.
    try {
      spark.sparkContext.setLogLevel("WARN")
      import spark.implicits._

      val schema = StructType(Seq(
        StructField("id", LongType),
        StructField("eventType", StringType),
        StructField("code", StringType),
        StructField("timestamp", LongType))
      )
      val frame: DataFrame = spark.read.schema(schema).csv("data/OrderLog.csv")
      frame.createOrReplaceTempView("order_log")

      /**
       * Baseline: UV via count(distinct) at the (eventType, code) grain.
       */
      spark.sql(
        s"""
           |select
           |  eventType,
           |  code,
           |  count(1) as pv,
           |  count(distinct id) as uv
           |from order_log
           |where code is not null
           |group by eventType, code
           |""".stripMargin).show()

      /**
       * Finest-grain aggregation with the custom bitmap UDAF.
       */
      import org.apache.spark.sql.functions.udaf
      // gen_bitmap returns the serialized bitmap (bytes), not a count; get_card below turns it
      // into the cardinality.
      // NOTE(review): BitmapGenUDAF takes Int while `id` is LongType, so ids are presumably
      // narrowed to Int before being added — confirm all ids fit in 32 bits.
      spark.udf.register("gen_bitmap", udaf(new BitmapGenUDAF))

      /** Number of distinct ids stored in a serialized bitmap. */
      def card(byteArray: Array[Byte]): Int = {
        val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(byteArray)
        bitmap.getCardinality
      }
      spark.udf.register("get_card", card _)

      val res: DataFrame = spark.sql(
        s"""
           |select
           |  eventType,
           |  code,
           |  count(1) as pv,
           |  gen_bitmap(id) as uv_arr,
           |  get_card(gen_bitmap(id)) as uv
           |from order_log
           |where code is not null
           |group by eventType, code
           |""".stripMargin)
      // Replace rather than create, consistent with order_log, so re-registration in the same
      // session does not fail.
      res.createOrReplaceTempView("dws_stat")

      // Roll-up: PV sums directly; UV cannot be summed, so the per-code bitmaps are OR-merged
      // and the cardinality is taken of the union.
      spark.udf.register("bitmapOr", udaf(new BitmapOrMergeUDAF))
      spark.sql(
        s"""
          |select
          | eventType,
          | sum(pv) as total_pv,
          | bitmapOr(uv_arr) as uv_arr,
          | get_card(bitmapOr(uv_arr)) as total_uv
          |from dws_stat
          |group by eventType
          |""".stripMargin).show()
    } finally {
      spark.stop()
    }
  }
}

 

posted @ 2021-12-14 08:29  Shydow  阅读(1199)  评论(0)  编辑  收藏  举报