spark源码阅读--shuffle读过程源码分析

shuffle读过程源码分析

上一篇中,我们分析了shuffle在map阶段的写过程。简单回顾一下,主要是将ShuffleMapTask计算的结果数据在内存中按照分区和key进行排序,过程中由于内存限制会溢写出多个磁盘文件,最后会对所有的文件和内存中剩余的数据进行归并排序并溢写到一个文件中,同时会记录每个分区(reduce端分区)的数据在文件中的偏移,并且把分区和偏移的映射关系写到一个索引文件中。
好了,简单回顾了写过程后,我们不禁思考,reduce阶段的数据读取的具体过程是什么样的?数据读取的发生的时机是什么?

首先应该回答后一个问题:数据读取发生的时机是什么?我们知道,rdd的计算链根据shuffle被切分为不同的stage,一个stage的开始阶段一般就是从读取上一阶段的数据开始,也就是说stage读取数据的过程其实就是reduce过程,然后经过该stage的计算链后得到结果数据,再然后就会把这些数据写入到磁盘供下一个stage读取,这个写入的过程实际上就是map输出过程,而这个过程我们之前已经分析过了。本篇我们要分析的是reduce阶段读取数据的过程。

啰嗦了这么一大段,其实就是为了引出数据读取的入口,还是要回到ShuffleMapTask,这里我只贴部分代码:

  // shuffle管理器
  val manager = SparkEnv.get.shuffleManager
  // 获取一个shuffle写入器
  writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
  // 这里可以看到rdd计算的核心方法就是iterator方法
  // SortShuffleWriter的write方法可以分为几个步骤:
  // 将上游rdd计算出的数据(通过调用rdd.iterator方法)写入内存缓冲区,
  // 在写的过程中如果超过 内存阈值就会溢写磁盘文件,可能会写多个文件
  // 最后将溢写的文件和内存中剩余的数据一起进行归并排序后写入到磁盘中形成一个大的数据文件
  // 这个排序是先按分区排序,在按key排序
  // 在最后归并排序后写的过程中,没写一个分区就会手动刷写一遍,并记录下这个分区数据在文件中的位移
  // 所以实际上最后写完一个task的数据后,磁盘上会有两个文件:数据文件和记录每个reduce端partition数据位移的索引文件
  writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
  // 主要是删除中间过程的溢写文件,向内存管理器释放申请的内存
  writer.stop(success = true).get

读取数据的代码其实就是rdd.iterator(partition, context),
iterator方法主要是处理rdd缓存的逻辑,如果有缓存就会从缓存中读取(通过BlockManager),如果没有缓存就会进行实际的计算,发现最终调用RDD.compute方法进行实际的计算,这个方法是一个抽象方法,是由子类实现的具体的计算逻辑,用户代码中对于RDD做的一些变换操作实际上最终都会体现在compute方法中。
另一方面,我们知道,map,filter这类算子不是shuffle操作,不会导致stage的划分,所以我们想看shuffle读过程就要找一个Shuffle类型的操作,我们看一下RDD.groupBy,最终调用了groupByKey方法

RDD.groupByKey

def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope {
// groupByKey shouldn't use map side combine because map side combine does not
// reduce the amount of data shuffled and requires all map side data be inserted
// into a hash table, leading to more objects in the old gen.
val createCombiner = (v: V) => CompactBuffer(v)
val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v
val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2
val bufs = combineByKeyWithClassTag[CompactBuffer[V]](
  createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false)
bufs.asInstanceOf[RDD[(K, Iterable[V])]]
}

最终调用了combineByKeyWithClassTag

RDD.combineByKeyWithClassTag

做一些判断,检查一些非法情况,然后处理一下分区器,最后返回一个ShuffledRDD,所以接下来我们分析一下ShuffleRDD的compute方法

def combineByKeyWithClassTag[C](
  createCombiner: V => C,
  mergeValue: (C, V) => C,
  mergeCombiners: (C, C) => C,
  partitioner: Partitioner,
  mapSideCombine: Boolean = true,
  serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope {
require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
// 如果key是Array类型,是不支持在map端合并的
// 并且也不支持HashPartitioner
if (keyClass.isArray) {
  if (mapSideCombine) {
    throw new SparkException("Cannot use map-side combining with array keys.")
  }
  if (partitioner.isInstanceOf[HashPartitioner]) {
    throw new SparkException("HashPartitioner cannot partition array keys.")
  }
}
// 聚合器,用于对数据进行聚合
val aggregator = new Aggregator[K, V, C](
  self.context.clean(createCombiner),
  self.context.clean(mergeValue),
  self.context.clean(mergeCombiners))
// 如果分区器相同,就不需要shuffle了
if (self.partitioner == Some(partitioner)) {
  self.mapPartitions(iter => {
    val context = TaskContext.get()
    new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
  }, preservesPartitioning = true)
} else {
  // 返回一个ShuffledRDD
  new ShuffledRDD[K, V, C](self, partitioner)
    .setSerializer(serializer)
    .setAggregator(aggregator)
    .setMapSideCombine(mapSideCombine)
}
}

ShuffleRDD.compute

通过shuffleManager获取一个读取器,数据读取的逻辑在读取器里。

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
// 通过shuffleManager获取一个读取器
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
  .read()
  .asInstanceOf[Iterator[(K, C)]]
}

SortShuffleManager.getReader

无需多说,直接看BlockStoreShuffleReader

override def getReader[K, C](
  handle: ShuffleHandle,
  startPartition: Int,
  endPartition: Int,
  context: TaskContext): ShuffleReader[K, C] = {
new BlockStoreShuffleReader(
  handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}

BlockStoreShuffleReader.read

显然,这个方法才是核心所在。总结一下主要步骤:

  • 获取一个包装的迭代器ShuffleBlockFetcherIterator,它迭代的元素是blockId和这个block对应的读取流,很显然这个类就是实现reduce阶段数据读取的关键
  • 将原始读取流转换成反序列化后的迭代器
  • 将迭代器转换成能够统计度量值的迭代器,这一系列的转换和java中对于流的各种装饰器很类似
  • 将迭代器包装成能够相应中断的迭代器。每读一条数据就会检查一下任务有没有被杀死,这种做法是为了尽量及时地响应杀死任务的请求,比如从driver端发来杀死任务的消息。
  • 利用聚合器对结果进行聚合。这里再次利用了AppendonlyMap这个数据结构,前面shuffle写阶段也用到这个数据结构,它的内部是一个以数组作为底层数据结构的,以线性探测法线性的hash表。
  • 最后对结果进行排序。

所以很显然,我们想知道的shuffle读取数据的具体逻辑就藏在ShuffleBlockFetcherIterator中

    private[spark] class BlockStoreShuffleReader[K, C](
        handle: BaseShuffleHandle[K, _, C],
        startPartition: Int,
        endPartition: Int,
        context: TaskContext,
        serializerManager: SerializerManager = SparkEnv.get.serializerManager,
        blockManager: BlockManager = SparkEnv.get.blockManager,
        mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
      extends ShuffleReader[K, C] with Logging {
    
      private val dep = handle.dependency
    
      /** Read the combined key-values for this reduce task */
      override def read(): Iterator[Product2[K, C]] = {
        // 获取一个包装的迭代器,它迭代的元素是blockId和这个block对应的读取流
        val wrappedStreams = new ShuffleBlockFetcherIterator(
          context,
          blockManager.shuffleClient,
          blockManager,
          mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
          serializerManager.wrapStream,
          // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
          SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
          SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
          SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
          SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
          SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
    
        val serializerInstance = dep.serializer.newInstance()
    
        // Create a key/value iterator for each stream
        // 将原始读取流转换成反序列化后的迭代器
        val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
          // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
          // NextIterator. The NextIterator makes sure that close() is called on the
          // underlying InputStream when all records have been read.
          serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
        }
    
        // Update the context task metrics for each record read.
        val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
        // 转换成能够统计度量值的迭代器,这一系列的转换和java中对于流的各种装饰器很类似
        val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
          recordIter.map { record =>
            readMetrics.incRecordsRead(1)
            record
          },
          context.taskMetrics().mergeShuffleReadMetrics())
    
        // An interruptible iterator must be used here in order to support task cancellation
        // 每读一条数据就会检查一下任务有没有被杀死,
        // 这种做法是为了尽量及时地响应杀死任务的请求,比如从driver端发来杀死任务的消息
        val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
    
        val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
          // 利用聚合器对结果进行聚合
          if (dep.mapSideCombine) {
            // We are reading values that are already combined
            val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
            dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
          } else {
            // We don't know the value type, but also don't care -- the dependency *should*
            // have made sure its compatible w/ this aggregator, which will convert the value
            // type to the combined type C
            val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
            dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
          }
        } else {
          require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
          interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
        }
    
        // Sort the output if there is a sort ordering defined.
        // 最后对结果进行排序
        dep.keyOrdering match {
          case Some(keyOrd: Ordering[K]) =>
            // Create an ExternalSorter to sort the data.
            val sorter =
              new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
            sorter.insertAll(aggregatedIter)
            context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
            context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
            context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
            CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
          case None =>
            aggregatedIter
        }
      }
    }

ShuffleBlockFetcherIterator

这个类比较复杂,仔细看在类初始化的代码中会调用initialize方法。
其次,我们应该注意它的构造器中的参数,

    val wrappedStreams = new ShuffleBlockFetcherIterator(
    context,
    // 如果没有启用外部shuffle服务,就是BlockTransferService
    blockManager.shuffleClient,
    blockManager,
    // 通过mapOutputTracker组件获取每个分区对应的数据block的物理位置
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    serializerManager.wrapStream,
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    // 获取几个配置参数
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
    SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
    SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

ShuffleBlockFetcherIterator.initialize

  • 首先将本地的block和远程的block分隔开
  • 然后开始发送请求拉取远程数据。这个过程中会有一些约束条件限制拉取数据请求的数量,主要是正在获取的总数据量的限制,请求并发数限制;每个远程地址同时拉取的块数也会有限制,但是这个阈值默认是Integer.MAX_VALUE
  • 获取本地的block数据

其中,获取本地数据较为简单,主要就是通过本节点的BlockManager来获取块数据,并通过索引文件获取数据指定分区的数据。
我们着重分析远程拉取的部分

private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
// 向TaskContext中添加一个回调,在任务完成时做一些清理工作
context.addTaskCompletionListener(_ => cleanup())

// Split local and remote blocks.
// 将本地的block和远程的block分隔开
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
assert ((0 == reqsInFlight) == (0 == bytesInFlight),
  "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
  ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

// Send out initial requests for blocks, up to our maxBytesInFlight
// 发送远程拉取数据的请求
// 尽可能多地发送请求
// 但是会有一定的约束:
// 全局性的约束,全局拉取数据的rpc线程并发数,全局拉取数据的数据量限制
// 每个远程地址的限制:每个远程地址同时拉取的块数不能超过一定阈值
fetchUpToMaxBytes()

// 记录已经发送的请求个数,仍然会有一部分没有发送请求
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

// Get Local Blocks
// 获取本地的block数据
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}

ShuffleBlockFetcherIterator.splitLocalRemoteBlocks

我们首先来看如何切分远程和本地的数据块,总结一下这个方法:

  • 首先将同时拉取的数据量的大小除以5作为每次请求拉取的数据量的限制,这么做的原因是为了允许同时从5个节点拉取数据,因为节点的网络环境可能并不稳定,同时从多个节点拉取数据有助于减少网络波动对性能带来的影响,而对整体的同时拉取数据量的限制主要是为了限制本机网络流量的使用

  • 循环遍历每一个节点地址(这里是BlockManagerId),

  • 如果地址与本机地址相同,那么对应的blocks就是本地block

  • 对于远程block,则要根据同时拉取数据量大小的限制将每个节点的所有block切分成多个请求(FetchRequest),确保这些请求单次的拉取数据量不会太大

      private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
      // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
      // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
      // nodes, rather than blocking on reading output from one node.
      // 之所以将请求大小减小到maxBytesInFlight / 5,
      // 是为了并行化地拉取数据,最毒允许同时从5个节点拉取数据
      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
      logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
        + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)
      
      // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
      // at most maxBytesInFlight in order to limit the amount of data in flight.
      val remoteRequests = new ArrayBuffer[FetchRequest]
      
      // Tracks total number of blocks (including zero sized blocks)
      // 记录总的block数量
      var totalBlocks = 0
      for ((address, blockInfos) <- blocksByAddress) {
        totalBlocks += blockInfos.size
        // 如果地址与本地的BlockManager相同,就是本地block
        if (address.executorId == blockManager.blockManagerId.executorId) {
          // Filter out zero-sized blocks
          localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
          numBlocksToFetch += localBlocks.size
        } else {
          val iterator = blockInfos.iterator
          var curRequestSize = 0L
          var curBlocks = new ArrayBuffer[(BlockId, Long)]
          while (iterator.hasNext) {
            val (blockId, size) = iterator.next()
            // Skip empty blocks
            if (size > 0) {
              curBlocks += ((blockId, size))
              remoteBlocks += blockId
              numBlocksToFetch += 1
              curRequestSize += size
            } else if (size < 0) {
              throw new BlockException(blockId, "Negative block size " + size)
            }
            // 如果超过每次请求的数据量限制,那么创建一次请求
            if (curRequestSize >= targetRequestSize ||
                curBlocks.size >= maxBlocksInFlightPerAddress) {
              // Add this FetchRequest
              remoteRequests += new FetchRequest(address, curBlocks)
              logDebug(s"Creating fetch request of $curRequestSize at $address "
                + s"with ${curBlocks.size} blocks")
              curBlocks = new ArrayBuffer[(BlockId, Long)]
              curRequestSize = 0
            }
          }
          // Add in the final request
          // 扫尾方法,最后剩余的块创建一次请求
          if (curBlocks.nonEmpty) {
            remoteRequests += new FetchRequest(address, curBlocks)
          }
        }
      }
      logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
      remoteRequests
      }
    

ShuffleBlockFetcherIterator.fetchUpToMaxBytes

回到initialize方法中,在完成本地与远程block的切分后,我们得到了一批封装好的数据拉取请求,将这些请求加到队列中,接下来要做的是通过rpc客户端发送这些请求,

这个方法逻辑还是相对简单,主要逻辑就是两个循环,先发送延缓队列中的请求,然后发送正常的请求;之所以会有延缓队列是因为这些请求在第一次待发送时因为数据量超过阈值或者请求数量超过阈值而不能发送,所以就被放到延缓队列中,而这里的处理也是优先发送延缓队列中的请求。每个请求在发送前必须要满足下面几个条件才会被发送:

  • 当前正在拉取的数据量不能超过阈值maxReqsInFlight(默认48m);这里会有一个问题,如果某个block的数据量超过maxReqsInFlight值呢?这种情况下会等当前已经没有进行中的数据拉取请求才会发送这个请求,因为在对当前请求数据量阈值进行判断时会检查bytesInFlight == 0,如果这个条件满足就不会检查本次请求的数据量是否会超过阈值。

  • 当前正在拉取的请求数据量不能超过阈值(默认Int.MaxValue)

  • 每个远程地址的同时请求数量也会有限制(默认Int.MaxValue)

  • 最后符合条件的请求就会被发送,这里要提出的一点是如果一次请求的数据量超过maxReqSizeShuffleToMem值,那么就会写入磁盘的一个临时文件中,而这个阈值的默认值是Long.MaxValue,所以默认情况下是没有限制的。

      // 发送请求
      // 尽可能多地发送请求
      // 但是会有一定的约束:
      // 全局性的约束,全局拉取数据的rpc线程并发数,全局拉取数据的数据量限制
      // 每个远程地址的限制:每个远程地址同时拉取的块数不能超过一定阈值
      private def fetchUpToMaxBytes(): Unit = {
      // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
      // immediately, defer the request until the next time it can be processed.
      
      // Process any outstanding deferred fetch requests if possible.
      if (deferredFetchRequests.nonEmpty) {
        for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
          while (isRemoteBlockFetchable(defReqQueue) &&
              !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
            val request = defReqQueue.dequeue()
            logDebug(s"Processing deferred fetch request for $remoteAddress with "
              + s"${request.blocks.length} blocks")
            send(remoteAddress, request)
            if (defReqQueue.isEmpty) {
              deferredFetchRequests -= remoteAddress
            }
          }
        }
      }
      
      // Process any regular fetch requests if possible.
      while (isRemoteBlockFetchable(fetchRequests)) {
        val request = fetchRequests.dequeue()
        val remoteAddress = request.address
        // 如果超过了同时拉取的块数的限制,那么将这个请求放到延缓队列中,留待下次请求
        if (isRemoteAddressMaxedOut(remoteAddress, request)) {
          logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
          val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
          defReqQueue.enqueue(request)
          deferredFetchRequests(remoteAddress) = defReqQueue
        } else {
          send(remoteAddress, request)
        }
      }
      
      // 发送一个请求,并且累加记录请求的块的数量,
      // 以用于在下次请求时检查请求块的数量是否超过阈值
      def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
        sendRequest(request)
        numBlocksInFlightPerAddress(remoteAddress) =
          numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
      }
      
      // 这个限制是对所有的请求而言,不分具体是哪个远程节点
      // 检查当前的请求的数量是否还有余量
      // 当前请求的大小是否还有余量
      // 这主要是为了限制并发数和网络流量的使用
      def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
        fetchReqQueue.nonEmpty &&
          (bytesInFlight == 0 ||
            (reqsInFlight + 1 <= maxReqsInFlight &&
              bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
      }
      
      // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
      // given remote address.
      // 检测正在拉取的块的数量是否超过阈值
      // 每个地址都有一个同事拉取块数的限制
      def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
          maxBlocksInFlightPerAddress
      }
      }
    

ShuffleBlockFetcherIterator.next

通过上一个方法的分析,我们能够看出来,初始化时发起的拉取数据的请求并未将所有请求全部发送出去,并且还会有请求因为超过阈值而被放入延缓队列中,那么这些未发送的请求是什么时候被再次发送的呢?答案就在next方法中。我们知道ShuffleBlockFetcherIterator是一个迭代器,所以外部调用者对元素的访问是通过next方法,所以很容易想到next方法中肯定会有发送拉取数据请求的逻辑。
总结一下:

  • 首先从结果队列中获取一个拉取成功的结果(结果队列是一个阻塞队列,如果没有拉取成功的结果会阻塞调用者)

  • 拿到一个结果后检查这个结果是拉取成功还是拉取失败,如果失败则直接抛异常(重试的逻辑实在rpc客户端实现的,不是在这里实现)

  • 如果是一个成功的结果,首先要更新一下一些任务度量值,更新一些内部的簿记量,如正在拉取的数据量

  • 将拉取到的字节缓冲包装成一个字节输入流

  • 通过外部传进来的函数对流再包装一次,通过外部传进来的函数再包装一次,一般是解压缩和解密

  • 而且流被压缩或者加密过,如果块的大小比较小,那么要将这个流拷贝一份,这样就会实际出发解压缩和解密,以此来尽早暴露块损坏的 问题

  • 最后一句关键语句,再次发起一轮拉取数据请求的发 送,因为经过next处理之后,已经有拉取成功的数据了,正在拉取的数据量和请求数量可能减小了,这就为发送新的请求腾出空间

      override def next(): (BlockId, InputStream) = {
      if (!hasNext) {
        throw new NoSuchElementException
      }
      
      numBlocksProcessed += 1
      
      var result: FetchResult = null
      var input: InputStream = null
      // Take the next fetched result and try to decompress it to detect data corruption,
      // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
      // is also corrupt, so the previous stage could be retried.
      // For local shuffle block, throw FailureFetchResult for the first IOException.
      while (result == null) {
        val startFetchWait = System.currentTimeMillis()
        result = results.take()
        val stopFetchWait = System.currentTimeMillis()
        shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
      
        result match {
          case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
            if (address != blockManager.blockManagerId) {
              numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
              // 主要是更新一些度量值
              shuffleMetrics.incRemoteBytesRead(buf.size)
              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
              }
              shuffleMetrics.incRemoteBlocksFetched(1)
            }
            bytesInFlight -= size
            if (isNetworkReqDone) {
              reqsInFlight -= 1
              logDebug("Number of requests in flight " + reqsInFlight)
            }
      
            // 将字节缓冲包装成一个字节输入流
            val in = try {
              buf.createInputStream()
            } catch {
              // The exception could only be throwed by local shuffle block
              case e: IOException =>
                assert(buf.isInstanceOf[FileSegmentManagedBuffer])
                logError("Failed to create input stream from local block", e)
                buf.release()
                throwFetchFailedException(blockId, address, e)
            }
      
            // 通过外部传进来的函数再包装一次,一般是增加压缩和加密的功能
            input = streamWrapper(blockId, in)
            // Only copy the stream if it's wrapped by compression or encryption, also the size of
            // block is small (the decompressed block is smaller than maxBytesInFlight)
            // 如果块的大小比较小,而且流被压缩或者加密过,那么需要将这个流拷贝一份
            if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
              val originalInput = input
              val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
              try {
                // Decompress the whole block at once to detect any corruption, which could increase
                // the memory usage tne potential increase the chance of OOM.
                // TODO: manage the memory used here, and spill it into disk in case of OOM.
                Utils.copyStream(input, out)
                out.close()
                input = out.toChunkedByteBuffer.toInputStream(dispose = true)
              } catch {
                case e: IOException =>
                  buf.release()
                  if (buf.isInstanceOf[FileSegmentManagedBuffer]
                    || corruptedBlocks.contains(blockId)) {
                    throwFetchFailedException(blockId, address, e)
                  } else {
                    logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
                    corruptedBlocks += blockId
                    fetchRequests += FetchRequest(address, Array((blockId, size)))
                    result = null
                  }
              } finally {
                // TODO: release the buf here to free memory earlier
                originalInput.close()
                in.close()
              }
            }
      
            // 拉取失败,抛异常
            // 这里思考一下:拉取块数据肯定是有重试机制的,但是这里拉取失败之后直接抛异常是为何??
            // 答案是:重试机制并不是正在这里实现 的,而是在rpc客户端发送拉取请求时实现了重试机制
            // 也就是说如果到这里是失败的话,说明已经经过重试后还是失败的,所以这里直接抛异常就行了
          case FailureFetchResult(blockId, address, e) =>
            throwFetchFailedException(blockId, address, e)
        }
      
        // Send fetch requests up to maxBytesInFlight
        // 这里再次发送拉取请求,因为前面已经有成功拉取到的数据,
        // 所以正在拉取中的数据量就会减小,所以就能为新的请求腾出空间
        fetchUpToMaxBytes()
      }
      
      currentResult = result.asInstanceOf[SuccessFetchResult]
      (currentResult.blockId, new BufferReleasingInputStream(input, this))
      }
    

总结

到此,我们就把shuffle读的过程大概分析完了。整体下来,感觉主干逻辑不是很复杂,但是里面有很多细碎逻辑,所以上面的分析还是比较碎,这里把整个过程的主干逻辑再提炼一下,以便能有个整体的认识:

  • 首先,在一些shuffle类型的RDD中,它的计算方法compute会通过ShuffleManager获取一个block数据读取器BlockStoreShuffleReader
  • 通过BlockStoreShuffleReader中的read方法进行数据的读取,一个reduce端分区的数据一般会依赖于所有的map端输出的分区数据,所以数据一般会在多个executor(注意是executor节点,通过BlockManagerId唯一标识,一个物理节点可能会运行多个executor节点)节点上,而且每个executor节点也可能会有多个block,在shuffle写过程的分析中我们也提到,每个map最后时输出一个数据文件和索引文件,也就是一个block,但是因为一个节点
  • 这个方法通过ShuffleBlockFetcherIterator对象封装了远程拉取数据的复杂逻辑,并且最终将拉取到的数据封装成流的迭代器的形式
  • 对所有的block的流进行层层装饰,包括反序列化,任务度量值(读入数据条数)统计,每条数据可中断,
  • 对数据进行聚合
  • 对聚合后的数据进行排序

所以,从这里我们也能看出来,新版的shuffle机制中,也就是SortShuffleManager,用户代码对于shuffle之后的rdd拿到的是经过排序的数据(如果指定排序器的话)。

posted on 2019-06-16 19:50  _朱葛  阅读(589)  评论(0编辑  收藏  举报