DStream-04: How the window functions work, with source code

DStream offers two kinds of window implementations: the plain WindowedDStream, and ReducedWindowedDStream, which is optimized for windowed aggregation.

Demo

object SocketWordCountDstreamReduceByWindow {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      .setAppName("SocketWordCountDstream")
      .setMaster("local[3]")
    val sparkStreamContext = new StreamingContext(sparkConf,Seconds(5))
    val sparkContext = sparkStreamContext.sparkContext
    sparkContext.setLogLevel("WARN")
    val dstream = sparkStreamContext.socketTextStream("localhost",9090)
    val v = dstream.flatMap(_.split(" "))
      .map((_,1))
      .reduceByKeyAndWindow((p:Int,c:Int)=>{
        p + c
      },Durations.seconds(15),Durations.seconds(5))
    v.foreachRDD(...)
    sparkStreamContext.start()
    sparkStreamContext.awaitTermination()
  }
}

Source code

DStream

Background

Every DStream stores the RDD generated for each batch in a Map, that is, in memory.

// Map that holds the generated RDDs
@transient
private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]]()

 private[streaming] final def getOrCompute(time: Time): Option[RDD[T]] = {
    // If RDD was already generated, then retrieve it from HashMap,
    // or else compute the RDD
    // If the RDD is already in the map, use it directly; otherwise create it
    generatedRDDs.get(time).orElse {
      
      if (isTimeValid(time)) {
        val rddOption = createRDDWithLocalProperties(time, displayInnerRDDOps = false) {
          SparkHadoopWriterUtils.disableOutputSpecValidation.withValue(true) {
            compute(time)
          }
        }

        rddOption.foreach { case newRDD =>
          // Register the generated RDD for caching and checkpointing
          if (storageLevel != StorageLevel.NONE) {
            newRDD.persist(storageLevel)
            logDebug(s"Persisting RDD ${newRDD.id} for time $time to $storageLevel")
          }
          if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) {
            newRDD.checkpoint()
            logInfo(s"Marking RDD ${newRDD.id} for time $time for checkpointing")
          }
          
          // Put it into the map
          generatedRDDs.put(time, newRDD)
        }
        rddOption
      } else {
        None
      }
    }
  }

The map is also cleaned up when a batch completes.

private[streaming] def clearMetadata(time: Time) {
  
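  // Remove every RDD whose time is <= current batch time - rememberDuration.
  // Example: time = 100030, rememberDuration = 10, generatedRDDs = {100020 -> rdd}:
  // 100020 <= 100030 - 10, so that RDD is removed.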
  val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
  generatedRDDs --= oldRDDs.keys
  if (unpersistData) {
    logDebug(s"Unpersisting old RDDs: ${oldRDDs.values.map(_.id).mkString(", ")}")
    oldRDDs.values.foreach { rdd =>
      rdd.unpersist(false)
      // Explicitly remove blocks of BlockRDD
      rdd match {
        case b: BlockRDD[_] =>
          logInfo(s"Removing blocks of RDD $b of time $time")
          b.removeBlocks()
        case _ =>
      }
    }
  }

  dependencies.foreach(_.clearMetadata(time))
}

Cleanup proceeds from back to front: the child DStream is cleaned first, then its parent DStreams.
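The filter in clearMetadata can be tried out on plain collections. The following is a minimal sketch, assuming times are plain Long values and RDDs are replaced by string labels; ClearMetadataSketch and clear are hypothetical names, not Spark APIs.

```scala
import scala.collection.mutable

object ClearMetadataSketch {
  // Hypothetical stand-in for DStream.clearMetadata: removes every entry whose
  // time is <= time - rememberDuration and returns the removed times.
  def clear(generated: mutable.HashMap[Long, String],
            time: Long,
            rememberDuration: Long): Set[Long] = {
    val old = generated.filter { case (t, _) => t <= time - rememberDuration }
    generated --= old.keys
    old.keySet.toSet
  }
}
```

With generated = {100020 -> "a", 100030 -> "b"}, calling clear(generated, 100030L, 10L) removes the entry for 100020 and keeps the one for 100030, which mirrors the comment in the listing above.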

// By default rememberDuration = slideDuration, i.e. the batchInterval.
private[streaming] var rememberDuration: Duration = null

// By default the parent is asked to remember for as long as the child does
private[streaming] def parentRememberDuration = rememberDuration


When a DStream is initialized:

private[streaming] def initialize(time: Time) {

  var minRememberDuration = slideDuration
  
  // checkpointDuration is null for most DStreams
  
  if (checkpointDuration != null && minRememberDuration <= checkpointDuration) {
    minRememberDuration = checkpointDuration * 2
  }
  if (rememberDuration == null || rememberDuration < minRememberDuration) {
    rememberDuration = minRememberDuration
  }

  // Initialize the dependencies
  dependencies.foreach(_.initialize(zeroTime))
}
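// Abstract method.
// Per the official comment, slideDuration is the interval at which the DStream generates an RDD.
// By default slideDuration is the batchInterval. For a WindowedDStream it is the slide interval,
// which is also the interval at which that stream produces RDDs.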
def slideDuration: Duration

Take an example: interval = 1s, window = 4s, slide = 2s. JobGenerator posts a GenerateJobs event every interval (1s), which triggers getOrCompute on the last DStream; each compute first calls parent.getOrCompute, recursing back to the first DStream. However, getOrCompute checks whether the batch's time - zeroTime is a multiple of slideDuration (zeroTime is the time the stream started; the first batch is at zeroTime + batchInterval). Only if it is will compute be called; otherwise None is returned.

private[streaming] final def getOrCompute(time: Time): Option[RDD[T]] = {
  // If RDD was already generated, then retrieve it from HashMap,
  // or else compute the RDD
  generatedRDDs.get(time).orElse {
    // Compute the RDD if time is valid (e.g. correct time in a sliding window)
    // of RDD generation, else generate nothing.
    if (isTimeValid(time)) {
		....
        rdd
      }
      rddOption
    } else {
      None
    }
  }
}

Time validation. zeroTime is the time the stream started; the first batch is at zeroTime + batchInterval.

private[streaming] def isTimeValid(time: Time): Boolean = {
  if (!isInitialized) {
    throw new SparkException (this + " has not been initialized")
  } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) {
    logInfo(s"Time $time is invalid as zeroTime is $zeroTime" +
      s" , slideDuration is $slideDuration and difference is ${time - zeroTime}")
    false
  } else {
    logDebug(s"Time $time is valid")
    true
  }
}
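The check above boils down to simple arithmetic. The following is a minimal sketch, assuming times and durations are plain Long seconds instead of Spark's Time and Duration; TimeValiditySketch and validTimes are hypothetical names used only for illustration.

```scala
object TimeValiditySketch {
  // Same condition as DStream.isTimeValid, without the initialization check:
  // the time must be later than zeroTime and a whole number of slides after it.
  def isTimeValid(time: Long, zeroTime: Long, slide: Long): Boolean =
    time > zeroTime && (time - zeroTime) % slide == 0

  // Of the first n batch times after zeroTime, keep those at which a stream
  // with the given slide produces an RDD.
  def validTimes(zeroTime: Long, batchInterval: Long, slide: Long, n: Int): Seq[Long] =
    (1 to n)
      .map(i => zeroTime + i * batchInterval)
      .filter(t => isTimeValid(t, zeroTime, slide))
}
```

For zeroTime = 1582800004, batchInterval = 1 and slide = 2, the first four batch times are 1582800005 to 1582800008, and only 1582800006 and 1582800008 pass the check.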

For example, the common MappedDStream simply uses its parent DStream's slideDuration.

override def slideDuration: Duration = parent.slideDuration 

For the first DStream, slideDuration = batchDuration.

override def slideDuration: Duration = {
  if (ssc == null) throw new Exception("ssc is null")
  if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null")
  ssc.graph.batchDuration
}

InputDStream

InputDStream (an input stream, i.e. the data source) is the parent class of DirectKafkaInputDStream, FileInputDStream, and others. All input streams inherit this method.

override def slideDuration: Duration = {
  if (ssc == null) throw new Exception("ssc is null")
  if (ssc.graph.batchDuration == null) throw new Exception("batchDuration is null")
  // The batch interval passed to new StreamingContext
  ssc.graph.batchDuration
}

So by default slideDuration = batchInterval, the interval between batches.

Consequently rememberDuration = batchInterval and parentRememberDuration = batchInterval, so by default only the most recent batch's RDD is retained.
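The rule that initialize() applies to rememberDuration can be written as a small pure function. This is a sketch under stated assumptions: durations are plain Long values, null is modelled as None, and RememberDurationSketch is a hypothetical name rather than part of Spark.

```scala
object RememberDurationSketch {
  // Mirrors the logic of the initialize() listing shown earlier:
  // start from slideDuration, double the checkpoint interval if it is at least
  // as long, and never go below that minimum.
  def rememberDuration(slide: Long,
                       checkpoint: Option[Long],
                       requested: Option[Long]): Long = {
    val min = checkpoint match {
      case Some(cp) if slide <= cp => cp * 2
      case _                       => slide
    }
    // Keep an explicitly requested value only if it is not below the minimum.
    requested.filter(_ >= min).getOrElse(min)
  }
}
```

With no checkpointing and nothing requested, the result is simply the slide duration, which is the default case described above.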

Now to the main topic.

The entry point in the source is reduceByKeyAndWindow.

PairDStreamFunctions

It is effectively dstream.reduceByKey().window().reduceByKey(). Next, open window().

def reduceByKeyAndWindow(
    reduceFunc: (V, V) => V,
    windowDuration: Duration,
    slideDuration: Duration,
    partitioner: Partitioner
  ): DStream[(K, V)] = ssc.withScope {
  self.reduceByKey(reduceFunc, partitioner)
      .window(windowDuration, slideDuration)
      .reduceByKey(reduceFunc, partitioner)
}

DStream

In fact every DStream has the window function.

def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = ssc.withScope {
  new WindowedDStream(this, windowDuration, slideDuration)
}

WindowedDStream

Follow it into the WindowedDStream class.

class WindowedDStream[T: ClassTag](
    parent: DStream[T],
    _windowDuration: Duration,
    _slideDuration: Duration)
  extends DStream[T](parent.ssc) {


  // Persist parent level by default, as those RDDs are going to be obviously reused.
  parent.persist(StorageLevel.MEMORY_ONLY_SER)
  
  // 窗口时间
  def windowDuration: Duration = _windowDuration

  override def dependencies: List[DStream[_]] = List(parent)

  // Here slideDuration is not parent.slideDuration but the slide interval passed to window().
  override def slideDuration: Duration = _slideDuration
  

  // The parent must remember rememberDuration + windowDuration (rememberDuration defaults to slideDuration)
  override def parentRememberDuration: Duration = rememberDuration + windowDuration
  
  
  override def compute(validTime: Time): Option[RDD[T]] = {
    val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime)
    
    // Get the parent's RDDs that fall within the window
    val rddsInWindow = parent.slice(currentWindow)
    // Finally, union the RDDs in the window
    Some(ssc.sc.union(rddsInWindow))
  }
}

DStream

def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = ssc.withScope {
  if (!isInitialized) {
    throw new SparkException(this + " has not been initialized")
  }

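  // The parent of a WindowedDStream is an ordinary DStream, so slideDuration = batchInterval here.
  // zeroTime is the time the stream started; the first batch is at zeroTime + batchInterval.
  // Check that (toTime - zeroTime) is a multiple of slideDuration; if not, floor it to the nearest multiple.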
  val alignedToTime = if ((toTime - zeroTime).isMultipleOf(slideDuration)) {
    toTime
  } else {
    logWarning(s"toTime ($toTime) is not a multiple of slideDuration ($slideDuration)")
    toTime.floor(slideDuration, zeroTime)
  }

  // Likewise, check that (fromTime - zeroTime) is a multiple of slideDuration; if not, floor it.
  val alignedFromTime = if ((fromTime - zeroTime).isMultipleOf(slideDuration)) {
    fromTime
  } else {
    logWarning(s"fromTime ($fromTime) is not a multiple of slideDuration ($slideDuration)")
    fromTime.floor(slideDuration, zeroTime)
  }

  logInfo(s"Slicing from $fromTime to $toTime" +
    s" (aligned to $alignedFromTime and $alignedToTime)")

  alignedFromTime.to(alignedToTime, slideDuration).flatMap { time =>
    // Step through the range by slideDuration to get the times of the earlier batches,
    // then call getOrCompute for each one, which returns the RDD kept in memory (or computes it).
    if (time >= zeroTime) getOrCompute(time) else None
  }
}

Example

Again take interval = 1s, window = 4s, slide = 2s, with zeroTime = 1582800004.000.

isTimeValid = time > zeroTime and (time - zeroTime) % slide == 0

After 1s, time = 1582800005: WindowedDStream is reached first, isTimeValid = false, so the result is None.

After 2s, time = 1582800006: WindowedDStream is reached first, isTimeValid = true:

1. val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime)

   First argument: 1582800006 - 4 + 1 = 1582800003

   Second argument: 1582800006

2. parent.slice(1582800003, 1582800006) asks for the RDDs at the four times 1582800003, 1582800004, 1582800005 and 1582800006. 1582800003 and 1582800004 are not valid batch times (neither is later than zeroTime), so they yield None. It then calls getOrCompute(1582800005) and getOrCompute(1582800006), so the RDD for 1582800005 is obtained as well: although the WindowedDStream returned None at 1582800005, that batch's RDD is picked up at 1582800006. Later windows work the same way.
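The walk-through above can be reproduced with a few lines of arithmetic. This is a minimal sketch, assuming Long seconds, a window start that is already aligned to the parent's slide, and a hypothetical helper windowTimes that only lists the parent batch times that end up contributing an RDD.

```scala
object WindowSliceSketch {
  // Batch times of the parent that contribute an RDD to the window ending at
  // validTime. Times that are not later than zeroTime are dropped, which is
  // the combined effect of slice() and isTimeValid() in the listings above.
  def windowTimes(validTime: Long,
                  window: Long,
                  parentSlide: Long,
                  zeroTime: Long): Seq[Long] = {
    val from = validTime - window + parentSlide
    (from to validTime by parentSlide).filter(_ > zeroTime)
  }
}
```

For validTime = 1582800006 this leaves 1582800005 and 1582800006; for validTime = 1582800008 the window is full and covers 1582800005 through 1582800008.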

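Demo 2

Consider a long window where the reduceByKey().window().reduceByKey() computation is slow. Recomputing the RDDs of all 4 batches for every window is wasteful (as long as the computation is not a deduplication).

Assume batchInterval = 1s, window = 4s, slide = 1s.

The numbers below stand for batch timestamps.

The first window computes the RDDs of 1, 2, 3, 4.

The second window computes the RDDs of 2, 3, 4, 5.

The third window computes the RDDs of 3, 4, 5, 6.

The pattern is that the previous result can be reused by the next window, which reduces computation. How?

Take the first result (1, 2, 3, 4), subtract the RDD of 1 and add the RDD of 5.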
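The reuse idea can be checked on plain numbers before looking at the Spark code. This is a minimal sketch, assuming each batch is reduced to a single Int and the aggregation is a sum; IncrementalWindowSketch and its methods are hypothetical names.

```scala
object IncrementalWindowSketch {
  // Full recomputation: sum every batch in the window.
  def fullSum(batches: Map[Int, Int], window: Range): Int =
    window.map(batches).sum

  // Incremental update: previous result, minus the batches that left the
  // window, plus the batches that entered it.
  def slide(previous: Int, leaving: Seq[Int], entering: Seq[Int]): Int =
    previous - leaving.sum + entering.sum
}
```

With batches 1 to 6 holding the values 3, 1, 4, 1, 5, 9, the window 1 to 4 sums to 9; sliding by one batch gives 9 - 3 + 5 = 11, the same as summing batches 2 to 5 from scratch. This only works because the sum has an inverse (subtraction).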

object SocketWordCountDstreamReduceByWindowOptimization {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      .setAppName("SocketWordCountDstream")
      .setMaster("local[3]")
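    val sparkStreamContext = new StreamingContext(sparkConf,Seconds(1))
    // Note: the variant with an inverse reduce function requires checkpointing to be enabled,
    // so a checkpoint directory must be set with sparkStreamContext.checkpoint(...) before start().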
    val sparkContext = sparkStreamContext.sparkContext
    sparkContext.setLogLevel("WARN")
    val dstream = sparkStreamContext.socketTextStream("localhost",9090)
    val v = dstream.flatMap(_.split(" "))
      .map((_,1))
      .reduceByKeyAndWindow((p:Int,c:Int)=>{
        p + c
      },(c:Int,p:Int)=>{
        c - p
      },Durations.seconds(4),Durations.seconds(2))
    v.foreachRDD(rdd => {
      ....
    })
    sparkStreamContext.start()
    sparkStreamContext.awaitTermination()
  }
}

PairDStreamFunctions

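Two key functions must be passed in: the first (reduceFunc) adds the slide/batchInterval RDDs that enter the window, and the second (invReduceFunc) subtracts the slide/batchInterval RDDs that leave it.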

def reduceByKeyAndWindow(
    reduceFunc: (V, V) => V,
    invReduceFunc: (V, V) => V,
    windowDuration: Duration,
    slideDuration: Duration,
    partitioner: Partitioner,
    filterFunc: ((K, V)) => Boolean
  ): DStream[(K, V)] = ssc.withScope {

  val cleanedReduceFunc = ssc.sc.clean(reduceFunc)
  val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc)
  val cleanedFilterFunc = if (filterFunc != null) Some(ssc.sc.clean(filterFunc)) else None
  new ReducedWindowedDStream[K, V](
    self, cleanedReduceFunc, cleanedInvReduceFunc, cleanedFilterFunc,
    windowDuration, slideDuration, partitioner
  )
}

ReducedWindowedDStream

The key members are the same as in WindowedDStream.

def windowDuration: Duration = _windowDuration

override def dependencies: List[DStream[_]] = List(reducedStream)

override def slideDuration: Duration = _slideDuration
// This stream must be checkpointed
override val mustCheckpoint = true

override def parentRememberDuration: Duration = rememberDuration + windowDuration

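How is the previous window's RDD obtained? Directly from the map: time - slideDuration is the time of the previous window.

For example, batchInterval = 1s, window = 4s, slide = 2s.

Current time: 100006

Previous window: [100001, 100002, 100003, 100004]

First: the RDDs to subtract are the range [100001, 100002] (interval notation, likewise below)

Second: the RDDs to add are the range [100005, 100006]

Expected window: [100003, 100006]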
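The four intervals in this example follow from the same arithmetic that compute() uses. This is a minimal sketch, assuming Long times and a simplified Interval case class defined here for illustration (it is not Spark's Interval); Plan and plan are hypothetical names.

```scala
object ReducedWindowIntervalsSketch {
  // Simplified closed interval [begin, end]; not Spark's Interval class.
  final case class Interval(begin: Long, end: Long) {
    def -(d: Long): Interval = Interval(begin - d, end - d)
  }

  final case class Plan(current: Interval, previous: Interval,
                        old: Interval, fresh: Interval)

  def plan(validTime: Long, window: Long, slide: Long, parentSlide: Long): Plan = {
    val current  = Interval(validTime - window + parentSlide, validTime)
    val previous = current - slide
    // Batches that were in the previous window but are not in the current one.
    val old      = Interval(previous.begin, current.begin - parentSlide)
    // Batches that are in the current window but were not in the previous one.
    val fresh    = Interval(previous.end + parentSlide, current.end)
    Plan(current, previous, old, fresh)
  }
}
```

For plan(100006, 4, 2, 1) the result is current = [100003, 100006], previous = [100001, 100004], old = [100001, 100002] and fresh = [100005, 100006], matching the figures above.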

// The parent DStream reduced per batch; this is just a transformation of the parent.
private val reducedStream = parent.reduceByKey(reduceFunc, partitioner)

override def compute(validTime: Time): Option[RDD[(K, V)]] = {
  val reduceF = reduceFunc
  val invReduceF = invReduceFunc

  val currentTime = validTime
  // Evaluates to [100006 - 4 + 1, 100006] => [100003, 100006]
  val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration,
    currentTime)
  // Evaluates to [100003 - 2, 100006 - 2] => [100001, 100004], the previous window
  val previousWindow = currentWindow - slideDuration

  //  _____________________________
  // |  previous window   _________|___________________
  // |___________________|       current window        |  --------------> Time
  //                     |_____________________________|
  //
  // |________ _________|          |________ _________|
  //          |                             |
  //          V                             V
  //       old RDDs                     new RDDs
  //

  // Get the RDDs of the reduced values in "old time steps"
  
  // Evaluates to [100001, 100003 - 1] => [100001, 100002], matching the estimate above
  val oldRDDs =
    reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
  logDebug("# old RDDs = " + oldRDDs.size)

  // Get the RDDs of the reduced values in "new time steps"
  // Evaluates to [100004 + 1, 100006] => [100005, 100006]; note parent.slideDuration = batchInterval
  val newRDDs =
    reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime)
  logDebug("# new RDDs = " + newRDDs.size)

  // Get the RDD of the reduced value of the previous window
  // Get the RDD of the previous window
  val previousWindowRDD =
    getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K, V)]()))

  // Make the list of RDDs that needs to cogrouped together for reducing their reduced values
  val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs

  // Cogroup the reduced RDDs and merge the reduced valuesRDD
  // Cogroup the three groups of RDDs (previous window, old, new) by key
  val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(K, _)]]],
    partitioner)
  // val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _

  val numOldValues = oldRDDs.size
  val numNewValues = newRDDs.size

  val mergeValues = (arrayOfValues: Array[Iterable[V]]) => {
    if (arrayOfValues.length != 1 + numOldValues + numNewValues) {
      throw new Exception("Unexpected number of sequences of reduced values")
    }
    // Getting reduced values "old time steps" that will be removed from current window
    // Values from the old RDDs, which must be subtracted
    val oldValues = (1 to numOldValues).map(i => arrayOfValues(i)).filter(!_.isEmpty).map(_.head)
    // Getting reduced values "new time steps"

    // Values from the new RDDs, which must be added
    val newValues =
      (1 to numNewValues).map(i => arrayOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head)
    // Check whether the previous window has a value for this key
    if (arrayOfValues(0).isEmpty) {
      // If previous window's reduce value does not exist, then at least new values should exist
      if (newValues.isEmpty) {
        throw new Exception("Neither previous window has value for key, nor new values found. " +
          "Are you sure your key class hashes consistently?")
      }
      // Reduce the new values
      // If the previous window has no value, just reduce the new values
      newValues.reduce(reduceF) // return
    } else {
      // Get the previous window's reduced value
      var tempValue = arrayOfValues(0).head
      // If old values exists, then inverse reduce then from previous value
      if (!oldValues.isEmpty) {
        // Subtract the old values: apply invReduceF to the previous window's value and the reduced oldValues
        tempValue = invReduceF(tempValue, oldValues.reduce(reduceF))
      }
      // If new values exists, then reduce them with previous value
      if (!newValues.isEmpty) {
        // Add the new values: apply reduceF to the running value and the reduced newValues
        tempValue = reduceF(tempValue, newValues.reduce(reduceF))
      }
      tempValue // return
    }
  }

  // Apply the function above to get the value for the current window
  val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K, Array[Iterable[V]])]]
    .mapValues(mergeValues)

  if (filterFunc.isDefined) {
    Some(mergedValuesRDD.filter(filterFunc.get))
  } else {
    Some(mergedValuesRDD)
  }
}
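The per-key logic of mergeValues is easier to follow without the cogroup plumbing. This is a minimal sketch, assuming the values for one key have already been split into the previous window's value, the old values and the new values; MergeValuesSketch and merge are hypothetical names, and the array-length check of the original is omitted.

```scala
object MergeValuesSketch {
  // previous:  the key's value in the previous window, if any
  // oldValues: the key's values in the batches that left the window
  // newValues: the key's values in the batches that entered the window
  def merge[V](previous: Option[V], oldValues: Seq[V], newValues: Seq[V])
              (reduceF: (V, V) => V, invReduceF: (V, V) => V): V =
    previous match {
      case None =>
        // Without a previous value there must be at least one new value.
        require(newValues.nonEmpty, "no previous value and no new values for key")
        newValues.reduce(reduceF)
      case Some(prev) =>
        // First remove what left the window, then add what entered it.
        val afterRemoval =
          if (oldValues.nonEmpty) invReduceF(prev, oldValues.reduce(reduceF)) else prev
        if (newValues.nonEmpty) reduceF(afterRemoval, newValues.reduce(reduceF))
        else afterRemoval
    }
}
```

For a word count, merge(Some(9), Seq(3, 1), Seq(5, 9))(_ + _, _ - _) evaluates 9 - (3 + 1) + (5 + 9) = 19. The old values are reduced with reduceF before invReduceF is applied once, just as in the listing above.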

Posted on 2020-02-27 23:11 by chouc
