spark累加器及UDTF

数据源test.json

{"username": "zhangsan","age": 20}
{"username": "lisi","age": 18}
{"username": "wangwu","age": 16}

 强类型及弱类型

DateFrame是弱类型,意思是数据只有二维结构,没有数据类型
DateSet是强类型,数据包含数据结构,数据的二维结构和当前类之间有一个映射关系

1.自定义累加器

代码

import org.apache.spark.util.AccumulatorV2
import org.apache.spark.{SparkConf, SparkContext}

object Spark01_TestSer {
  def main(args: Array[String]): Unit = {
    //1.创建SparkConf并设置App名称
    val conf: SparkConf = new SparkConf().setAppName("SparkCoreTest").setMaster("local[*]")

    //2.创建SparkContext,该对象是提交Spark App的入口
    val sc: SparkContext = new SparkContext(conf)

    var sumAc = new MyAccumulator
    sc.register(sumAc)
    sc.makeRDD(List(("zhangsan",20),("lisi",30),("wangw",40))).foreach{
      case (name,age)=>{
        sumAc.add(age)
      }
    }
    println(sumAc.value)

    // 关闭连接
    sc.stop()
  }
}

class MyAC extends AccumulatorV2[Int,Double] {
  var sum = 0
  var count = 0

  override def isZero: Boolean = {
    sum==0 && count==0
  }

  override def copy(): AccumulatorV2[Int,Double] = {
    val myAC = new MyAC()
    myAC.sum = this.sum
    myAC.count = this.count
    myAC
  }

  override def reset(): Unit = {
    sum = 0
    count = 0
  }

  override def add(v: Int): Unit = {
    sum += v
    count += 1
  }

  override def merge(other: AccumulatorV2[Int, Double]): Unit = {
    other match {
      case myAC: MyAC=>{
        this.sum += myAC.sum
        this.count += myAC.count
      }
      case _ =>
    }
  }

  override def value: Double = {
    sum/count
  }
}

class MyAccumulator extends AccumulatorV2[Int,Double] {
  var sum = 0
  var count = 0

  override def isZero: Boolean = {
    sum==0 && count==0
  }

  override def copy(): AccumulatorV2[Int, Double] = {
    val myAccumulator = new MyAccumulator
    myAccumulator.sum = this.sum
    myAccumulator.count = this.count

    myAccumulator
  }

  override def reset(): Unit = {
    sum = 0
    count = 0
  }

  override def add(v: Int): Unit = {
    sum += v
    count += 1
  }

  override def merge(other: AccumulatorV2[Int, Double]): Unit = {
    other match {
      case o:MyAccumulator =>
        sum += o.sum
        count += o.count
      case _ =>
    }
  }

  override def value: Double = {
    sum/count
  }
}

2.自定义UDTF(弱类型,DateFrame)

代码

/*
*
* 自定义UDAF(弱类型  主要应用在SQL风格的DF查询)
*
* */

object SparkSQL05_UDAF {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName("SparkSQL05_UDAF").setMaster("local[*]")
    val sparkSession: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    val df: DataFrame = sparkSession.read.json("D:\\IdeaProjects\\spark_test\\input\\test.json")
    df.createOrReplaceTempView("user")

    val myAvg1 = new MyAvg1
    sparkSession.udf.register("MyAvg1", myAvg1)

    sparkSession.sql("select MyAvg1(age) from user").show()

    sparkSession.stop()

  }
}

//自定义UDAF函数(弱类型)
class MyAvg1 extends UserDefinedAggregateFunction{

  //聚合函数的输入数据的类型
  override def inputSchema: StructType = {
    StructType(Array(StructField("age",IntegerType)))
  }

  //缓存数据的类型
  override def bufferSchema: StructType = {
    StructType(Array(StructField("sum",LongType),StructField("count",LongType)))
  }

  //聚合函数返回的数据类型
  override def dataType: DataType = DoubleType

  //稳定性  默认不处理,直接返回true    相同输入是否会得到相同的输出
  override def deterministic: Boolean = true

  //初始化  缓存设置到初始状态
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //让缓存中年龄总和归0
    buffer(0) = 0L
    //让缓存中总人数归0
    buffer(1) = 0L
  }

  //更新缓存数据
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if(!buffer.isNullAt(0)){
      buffer(0) = buffer.getLong(0) + input.getInt(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  //分区间的合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  //计算逻辑
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble/buffer.getLong(1)
  }
}

3.自定义UDTF(强类型DateSet)

代码

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, SparkSession, TypedColumn}

object SparkSQL06_UDAF {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName("SparkSQL06_UDAF").setMaster("local[*]")
    val sparkSession: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    import sparkSession.implicits._

    val df: DataFrame = sparkSession.read.json("D:\\IdeaProjects\\spark_test\\input\\test.json")
    df.createOrReplaceTempView("user")

    val myAvg2 = new MyAvg2

    val column: TypedColumn[User01, Double] = myAvg2.toColumn

    val ds: Dataset[User01] = df.as[User01]
    ds.select(column).show()

    sparkSession.stop()
  }
}

//输入类型的样例类
case class User01(username:String,age:Long)
//缓存类型,由于设计到buffer计算,注意添加var类型
case class AgeBuffer(var sum:Long,var count:Long)

//自定义UDAF函数(强类型)
//* @tparam IN 输入数据类型
//* @tparam BUF 缓存数据类型
//* @tparam OUT 输出结果数据类型
class MyAvg2 extends Aggregator[User01,AgeBuffer,Double] {
  //对缓存数据进行初始化
  override def zero: AgeBuffer = {
    AgeBuffer(0L,0L)
  }

  //对当前分区内数据进行聚合
  override def reduce(b: AgeBuffer, a: User01): AgeBuffer = {
    b.sum += a.age
    b.count += 1L
    b
  }

  //分区间合并
  override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  //返回计算结果
  override def finish(reduction: AgeBuffer): Double = {
    reduction.sum.toDouble/reduction.count.toDouble
  }

  //DataSet的编码以及解码器  ,用于进行序列化,固定写法
  //用户自定义Ref类型  product       系统值类型,根据具体类型进行选择
  override def bufferEncoder: Encoder[AgeBuffer] = {
    Encoders.product
  }

  override def outputEncoder: Encoder[Double] = {
    Encoders.scalaDouble
  }
}

  

posted @ 2021-05-11 19:28  酷酷的狐狸  阅读(88)  评论(0编辑  收藏  举报