spark任务在executor端的运行过程分析

CoarseGrainedExecutorBackend

上一篇,我们主要分析了一次作业的提交过程,严格说是在driver端的过程,作业提交之后经过DAGScheduler根据shuffle依赖关系划分成多个stage,依次提交每个stage,将每个stage创建于分区数相同数量的Task,并包装成一个任务集,交给TaskSchedulerImpl进行分配。TaskSchedulerImpl则会根据SchedulerBackEnd提供的计算资源(executor),并考虑任务本地性,黑名单,调度池的调度顺序等因素对任务按照round-robin的方式进行分配,并将Task与executor的分配关系包装成TaskDescription返回给SchedulerBackEnd。然后SchedulerBackEnd就会根据收到的TaskDescription将任务再次序列化之后发送到对应的executor上执行。本篇,我们就来分析一下Task在executor上的执行过程。

任务执行入口Executor.launchTask

首先,我们知道CoarseGrainedExecutorBackend是yarn模式下的executor的实现类,这时一个rpc服务端,所以我们根据rpc客户端也就是CoarseGraineSchedulerBackEnd发送的消息,然后在服务端找到处理对应消息的方法,顺藤摸瓜就能找到Task执行的入口。通过上一篇的分析知道发送任务时,CoarseGraineSchedulerBackEnd发送的是一个LaunchTask类型的消息,我们看一下CoarseGrainedExecutorBackend.receive方法,其中对于LaunchTask消息的处理如下:

case LaunchTask(data) =>
  if (executor == null) {
    exitExecutor(1, "Received LaunchTask command but executor was null")
  } else {
    val taskDesc = TaskDescription.decode(data.value)
    logInfo("Got assigned task " + taskDesc.taskId)
    executor.launchTask(this, taskDesc)
  }

可以看到,实际上任务时交给内部的Executor对象来处理,实际上Executor对象承担了executor端的绝大部分逻辑,可以认为CoarseGrainedExecutorBackend仅仅是充当rpc消息中转的角色,充当spark的rpc框架中端点的角色,而实际的任务执行的逻辑则是由Executor对象来完成的。

Executor概述

我们先来看一下Executor类的说明:

/**
 * Spark executor, backed by a threadpool to run tasks.
 *
 * This can be used with Mesos, YARN, and the standalone scheduler.
 * An internal RPC interface is used for communication with the driver,
 * except in the case of Mesos fine-grained mode.
 */

Executor内部有一个线程池用来运行任务,Mesos, YARN, 和 standalone模式都是用这个类作为任务运行的逻辑。此外Executor对象持有SparkEnv的引用,以此来使用spark的一些基础设施,包括rpc引用。
我们还是以任务运行为线索分析这个类的代码。

Executor.launchTask

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  val tr = new TaskRunner(context, taskDescription)
  runningTasks.put(taskDescription.taskId, tr)
  threadPool.execute(tr)
}

这个代码没什么好说的,应该没人看不懂吧。所以接下来我们就看一下TaskRunner这个类。
从这个地方也能看出来,在executor端,一个task对应一个线程。

TaskRunner.run

这个方法贼长,没有一点耐心还真不容易看完。
其中有一些统计量我就不说了,比如任务运行时间统计,cpu耗时统计,gc耗时统计等等,这里有一点可以积累的地方是MXBean,cpu,gc耗时都是通过获取jvm内置的相关的MXBean获取到的,入口类是ManagementFactory,具体的可以细看,这里不再展开。

总结一下这个方法的主要步骤:

  • 首先向driver发送一个更新任务状态的消息,通知driver这个task处于运行的状态。

  • 设置任务属性,更新依赖的文件和jar包,将新的jar包添加到类加载器的寻找路径中;注意这些信息都是从driver端跟着TaskDescription一起传过来的。

  • 对任务进行反序列化生成Task对象,根据任务类型可能是ShuffleMapTask或者ResultTask

  • 检查任务有没有被杀死,如果被杀死则跑一个异常;(driver随时都可能发送一个杀死任务的消息)

  • 调用Task.run方法执行任务的运行逻辑

  • 任务运行结束后,清除未正常释放的内存资源和block锁资源,并在需要的时候打印资源泄漏的告警日志和抛出异常

  • 再次检测任务是否被杀死

  • 将任务运行的结果数据序列化

  • 更新一些任务统计量(一些累加器),以及更新度量系统中的相关统计量

  • 收集该任务相关的所有累加器(包括内置的统计量累加器和用户注册的累加器)

  • 将累加器数据和任务结果数据封装成一个对象并在此序列化

  • 检测序列化后的体积,有两个阈值:maxResultSize和maxDirectResultSize,如果超过maxResultSize直接丢弃结果,就是不往blockmanager里面写数据,这样driver端在试图通过blockmanager远程拉取数据的时候就获取不到数据,这时driver就知道这个任务的结果数据太大,失败了;而对于体积超过maxDirectResultSize的情况,会将任务结果数据通过blockmanager写到本地内存和磁盘,然后将block信息发送给driver,driver会根据这些信息来这个节点拉取数据;如果体积小于maxDirectResultSize,则直接通过rpc接口将结果数据发送给driver。

  • 最后还会有对任务失败的各种总异常的处理。

    override def run(): Unit = {
    threadId = Thread.currentThread.getId
    Thread.currentThread.setName(threadName)
    // 监控线程运行情况的MXBean
    val threadMXBean = ManagementFactory.getThreadMXBean
    // 内存管理器
    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
    // 记录反序列化的耗时,回忆一下,我们再spark的UI界面上可以看到这个统计值,原来就是在这里统计的
    val deserializeStartTime = System.currentTimeMillis()
    // 统计反序列化的cpu耗时
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
    } else 0L
    Thread.currentThread.setContextClassLoader(replClassLoader)
    val ser = env.closureSerializer.newInstance()
    logInfo(s"Running $taskName (TID $taskId)")
    // TODO 通过executor后端向driver发送一个任务状态更新的消息
    execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    var taskStart: Long = 0
    var taskStartCpu: Long = 0
    // 依然是通过MXBean获取gc总时长
    startGCTime = computeTotalGcTime()

    try {
      // Must be set before updateDependencies() is called, in case fetching dependencies
      // requires access to properties contained within (e.g. for access control).
      Executor.taskDeserializationProps.set(taskDescription.properties)
    
      // TODO 更新依赖的文件和jar包,从driver端拉取到本地,并缓存下来
      updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
      // 对任务进行反序列化,这里却并没有进行耗时统计
      task = ser.deserialize[Task[Any]](
        taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
      // 属性集合也是从driver端跟随taskDescription一起发送过来的
      task.localProperties = taskDescription.properties
      // 设置内存管理器
      task.setTaskMemoryManager(taskMemoryManager)
    
      // If this task has been killed before we deserialized it, let's quit now. Otherwise,
      // continue executing the task.
      // 检查有没有被杀掉
      val killReason = reasonIfKilled
      if (killReason.isDefined) {
        // Throw an exception rather than returning, because returning within a try{} block
        // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
        // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
        // for the task.
        throw new TaskKilledException(killReason.get)
      }
    
      // The purpose of updating the epoch here is to invalidate executor map output status cache
      // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
      // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
      // we don't need to make any special calls here.
      //
      if (!isLocal) {
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        // 更新epoch值和map输出状态
        env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
      }
    
      // Run the actual task and measure its runtime.
      // 运行任务并统计运行时间
      taskStart = System.currentTimeMillis()
      // 统计当前线程的cpu耗时
      taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
      var threwException = true
      val value = try {
        // 调用task.run方法运行任务
        val res = task.run(
          // 任务id
          taskAttemptId = taskId,
          // 任务的尝试次数
          attemptNumber = taskDescription.attemptNumber,
          // 度量系统
          metricsSystem = env.metricsSystem)
        threwException = false
        res
      } finally {
        // 释放关于该任务的所有锁, 该任务相关的block的读写锁
        val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
        // 清除所有分配给该任务的内存空间
        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
    
        // 如果threwException为false,说明任务正常运行完成
        // 在任务正常运行完的前提下如果还能够释放出内存,
        // 说明在任务正常执行的过程中没有正确地释放使用的内存,也就是发生了内存泄漏
        if (freedMemory > 0 && !threwException) {
          val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
          if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
            throw new SparkException(errMsg)
          } else {
            logWarning(errMsg)
          }
        }
    
        // 这里对于锁资源的检测和内存资源的检测是相同的逻辑
        // spark作者认为,具体的任务应该自己负责将申请的资源(包括内存和锁资源)在使用完后释放掉,
        // 不能依赖于靠后面的补救措施
        // 如果没有正常释放,就发生了资源泄漏
        // 这里则是对锁锁资源泄漏的检查
        if (releasedLocks.nonEmpty && !threwException) {
          val errMsg =
            s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
              releasedLocks.mkString("[", ", ", "]")
          if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
            throw new SparkException(errMsg)
          } else {
            logInfo(errMsg)
          }
        }
      }
      // 打印拉取异常日志
      // 代码执行到这里说明用户并没有抛拉取异常
      // 但是框架检测到拉取异常,这说明用户把拉取异常吞了,这显然是错误的行为,
      // 因此需要打印一条错误日志提醒用户
      task.context.fetchFailed.foreach { fetchFailure =>
        // uh-oh.  it appears the user code has caught the fetch-failure without throwing any
        // other exceptions.  Its *possible* this is what the user meant to do (though highly
        // unlikely).  So we will log an error and keep going.
        logError(s"TID ${taskId} completed successfully though internally it encountered " +
          s"unrecoverable fetch failures!  Most likely this means user code is incorrectly " +
          s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
      }
      // 统计任务完成时间
      val taskFinish = System.currentTimeMillis()
      // 统计任务线程占用的cpu时间
      val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
    
      // If the task has been killed, let's fail it.
      // 再次检测任务是否被杀掉
      task.context.killTaskIfInterrupted()
    
      // 任务结果的序列化
      val resultSer = env.serializer.newInstance()
      val beforeSerialization = System.currentTimeMillis()
      val valueBytes = resultSer.serialize(value)
      val afterSerialization = System.currentTimeMillis()
    
      // Deserialization happens in two parts: first, we deserialize a Task object, which
      // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
      task.metrics.setExecutorDeserializeTime(
        (taskStart - deserializeStartTime) + task.executorDeserializeTime)
      task.metrics.setExecutorDeserializeCpuTime(
        (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
      // We need to subtract Task.run()'s deserialization time to avoid double-counting
      task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
      task.metrics.setExecutorCpuTime(
        (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
      task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
      task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)
    
      // Expose task metrics using the Dropwizard metrics system.
      // Update task metrics counters
      executorSource.METRIC_CPU_TIME.inc(task.metrics.executorCpuTime)
      executorSource.METRIC_RUN_TIME.inc(task.metrics.executorRunTime)
      executorSource.METRIC_JVM_GC_TIME.inc(task.metrics.jvmGCTime)
      executorSource.METRIC_DESERIALIZE_TIME.inc(task.metrics.executorDeserializeTime)
      executorSource.METRIC_DESERIALIZE_CPU_TIME.inc(task.metrics.executorDeserializeCpuTime)
      executorSource.METRIC_RESULT_SERIALIZE_TIME.inc(task.metrics.resultSerializationTime)
      executorSource.METRIC_SHUFFLE_FETCH_WAIT_TIME
        .inc(task.metrics.shuffleReadMetrics.fetchWaitTime)
      executorSource.METRIC_SHUFFLE_WRITE_TIME.inc(task.metrics.shuffleWriteMetrics.writeTime)
      executorSource.METRIC_SHUFFLE_TOTAL_BYTES_READ
        .inc(task.metrics.shuffleReadMetrics.totalBytesRead)
      executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ
        .inc(task.metrics.shuffleReadMetrics.remoteBytesRead)
      executorSource.METRIC_SHUFFLE_REMOTE_BYTES_READ_TO_DISK
        .inc(task.metrics.shuffleReadMetrics.remoteBytesReadToDisk)
      executorSource.METRIC_SHUFFLE_LOCAL_BYTES_READ
        .inc(task.metrics.shuffleReadMetrics.localBytesRead)
      executorSource.METRIC_SHUFFLE_RECORDS_READ
        .inc(task.metrics.shuffleReadMetrics.recordsRead)
      executorSource.METRIC_SHUFFLE_REMOTE_BLOCKS_FETCHED
        .inc(task.metrics.shuffleReadMetrics.remoteBlocksFetched)
      executorSource.METRIC_SHUFFLE_LOCAL_BLOCKS_FETCHED
        .inc(task.metrics.shuffleReadMetrics.localBlocksFetched)
      executorSource.METRIC_SHUFFLE_BYTES_WRITTEN
        .inc(task.metrics.shuffleWriteMetrics.bytesWritten)
      executorSource.METRIC_SHUFFLE_RECORDS_WRITTEN
        .inc(task.metrics.shuffleWriteMetrics.recordsWritten)
      executorSource.METRIC_INPUT_BYTES_READ
        .inc(task.metrics.inputMetrics.bytesRead)
      executorSource.METRIC_INPUT_RECORDS_READ
        .inc(task.metrics.inputMetrics.recordsRead)
      executorSource.METRIC_OUTPUT_BYTES_WRITTEN
        .inc(task.metrics.outputMetrics.bytesWritten)
      executorSource.METRIC_OUTPUT_RECORDS_WRITTEN
        .inc(task.metrics.inputMetrics.recordsRead)
      executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize)
      executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled)
      executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled)
    
      // Note: accumulator updates must be collected after TaskMetrics is updated
      // 这里手机
      val accumUpdates = task.collectAccumulatorUpdates()
      // TODO: do not serialize value twice
      val directResult = new DirectTaskResult(valueBytes, accumUpdates)
      val serializedDirectResult = ser.serialize(directResult)
      val resultSize = serializedDirectResult.limit()
    
      // directSend = sending directly back to the driver
      val serializedResult: ByteBuffer = {
        if (maxResultSize > 0 && resultSize > maxResultSize) {
          logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
            s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
            s"dropping it.")
          ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
        } else if (resultSize > maxDirectResultSize) {
          val blockId = TaskResultBlockId(taskId)
          env.blockManager.putBytes(
            blockId,
            new ChunkedByteBuffer(serializedDirectResult.duplicate()),
            StorageLevel.MEMORY_AND_DISK_SER)
          logInfo(
            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
          ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
        } else {
          logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
          serializedDirectResult
        }
      }
    
      setTaskFinishedAndClearInterruptStatus()
      execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
    
    } catch {
      case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
        val reason = task.context.fetchFailed.get.toTaskFailedReason
        if (!t.isInstanceOf[FetchFailedException]) {
          // there was a fetch failure in the task, but some user code wrapped that exception
          // and threw something else.  Regardless, we treat it as a fetch failure.
          val fetchFailedCls = classOf[FetchFailedException].getName
          logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
            s"failed, but the ${fetchFailedCls} was hidden by another " +
            s"exception.  Spark is handling this like a fetch failure and ignoring the " +
            s"other exception: $t")
        }
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
    
      case t: TaskKilledException =>
        logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
    
      case _: InterruptedException | NonFatal(_) if
          task != null && task.reasonIfKilled.isDefined =>
        val killReason = task.reasonIfKilled.getOrElse("unknown reason")
        logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(
          taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
    
      case CausedBy(cDE: CommitDeniedException) =>
        val reason = cDE.toTaskCommitDeniedReason
        setTaskFinishedAndClearInterruptStatus()
        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
    
      case t: Throwable =>
        // Attempt to exit cleanly by informing the driver of our failure.
        // If anything goes wrong (or this was a fatal exception), we will delegate to
        // the default uncaught exception handler, which will terminate the Executor.
        logError(s"Exception in $taskName (TID $taskId)", t)
    
        // SPARK-20904: Do not report failure to driver if if happened during shut down. Because
        // libraries may set up shutdown hooks that race with running tasks during shutdown,
        // spurious failures may occur and can result in improper accounting in the driver (e.g.
        // the task failure would not be ignored if the shutdown happened because of premption,
        // instead of an app issue).
        if (!ShutdownHookManager.inShutdown()) {
          // Collect latest accumulator values to report back to the driver
          val accums: Seq[AccumulatorV2[_, _]] =
            if (task != null) {
              task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
              task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
              task.collectAccumulatorUpdates(taskFailed = true)
            } else {
              Seq.empty
            }
    
          val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
    
          val serializedTaskEndReason = {
            try {
              ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
            } catch {
              case _: NotSerializableException =>
                // t is not serializable so just send the stacktrace
                ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
            }
          }
          setTaskFinishedAndClearInterruptStatus()
          execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
        } else {
          logInfo("Not reporting error to driver during JVM shutdown.")
        }
    
        // Don't forcibly exit unless the exception was inherently fatal, to avoid
        // stopping other tasks unnecessarily.
        if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) {
          uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
        }
    } finally {
      runningTasks.remove(taskId)
    }
    

    }

Task.run

final def run(
  taskAttemptId: Long,
  attemptNumber: Int,
  metricsSystem: MetricsSystem): T = {
SparkEnv.get.blockManager.registerTask(taskAttemptId)
context = new TaskContextImpl(
  stageId,
  stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
  partitionId,
  taskAttemptId,
  attemptNumber,
  taskMemoryManager,
  localProperties,
  // 度量系统就是SparkEnv的度量对象
  metricsSystem,
  metrics)
TaskContext.setTaskContext(context)
// 记录运行任务的线程
taskThread = Thread.currentThread()

// 主要是更改TaskContext中的任务杀死原因的标记变量
// 以给线程发一次中断
if (_reasonIfKilled != null) {
  kill(interruptThread = false, _reasonIfKilled)
}

new CallerContext(
  "TASK",
  SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
  appId,
  appAttemptId,
  jobId,
  Option(stageId),
  Option(stageAttemptId),
  Option(taskAttemptId),
  Option(attemptNumber)).setCurrentContext()

try {
  runTask(context)
} catch {
  case e: Throwable =>
    // Catch all errors; run task failure callbacks, and rethrow the exception.
    try {
      context.markTaskFailed(e)
    } catch {
      case t: Throwable =>
        e.addSuppressed(t)
    }
    context.markTaskCompleted(Some(e))
    throw e
} finally {
  try {
    // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
    // one is no-op.
    context.markTaskCompleted(None)
  } finally {
    try {
      Utils.tryLogNonFatalError {
        // Release memory used by this thread for unrolling blocks
        // 释放内存快管理器中该任务使用的内存,最终是通过内存管理器来释放的
        // 实际上就是更新内存管理器内部的一些用于记录内存使用情况的簿记量
        // 真正的内存回收肯定还是有gc来完成的
        SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
        SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
          MemoryMode.OFF_HEAP)
        // Notify any tasks waiting for execution memory to be freed to wake up and try to
        // acquire memory again. This makes impossible the scenario where a task sleeps forever
        // because there are no other tasks left to notify it. Since this is safe to do but may
        // not be strictly necessary, we should revisit whether we can remove this in the
        // future.
        val memoryManager = SparkEnv.get.memoryManager
        // 内存释放之后,需要通知其他在等待内存资源的 线程
        memoryManager.synchronized { memoryManager.notifyAll() }
      }
    } finally {
      // Though we unset the ThreadLocal here, the context member variable itself is still
      // queried directly in the TaskRunner to check for FetchFailedExceptions.
      TaskContext.unset()
    }
  }
}
}
  • 创建一个TaskContextImpl,并设置到一个ThreadLocal变量中
  • 检查任务是否被杀死
  • 调用runTask方法执行实际的任务逻辑
  • 最后会释放在shuffle过程中申请的用于数据unroll的内存资源

所以,接下来我们要分析的肯定就是runTask方法,而这个方法是个抽象方法,由于ResultTask很简单,我就不再分析了,这里我重点分析一下ShuffleMapTask。

ShuffleMapTask.runTask

override def runTask(context: TaskContext): MapStatus = {
// Deserialize the RDD using the broadcast variable.
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
  threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
// 反序列化RDD和shuffle, 关键的步骤
// 这里思考rdd和shuffle反序列化时,内部的SparkContext对象是怎么反序列化的
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
  ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
  threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L

var writer: ShuffleWriter[Any, Any] = null
try {
  // shuffle管理器
  val manager = SparkEnv.get.shuffleManager
  // 获取一个shuffle写入器
  writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
  // 这里可以看到rdd计算的核心方法就是iterator方法
  // SortShuffleWriter的write方法可以分为几个步骤:
  // 将上游rdd计算出的数据(通过调用rdd.iterator方法)写入内存缓冲区,
  // 在写的过程中如果超过 内存阈值就会溢写磁盘文件,可能会写多个文件
  // 最后将溢写的文件和内存中剩余的数据一起进行归并排序后写入到磁盘中形成一个大的数据文件
  // 这个排序是先按分区排序,在按key排序
  // 在最后归并排序后写的过程中,没写一个分区就会手动刷写一遍,并记录下这个分区数据在文件中的位移
  // 所以实际上最后写完一个task的数据后,磁盘上会有两个文件:数据文件和记录每个reduce端partition数据位移的索引文件
  writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
  // 主要是删除中间过程的溢写文件,向内存管理器释放申请的内存
  writer.stop(success = true).get
} catch {
  case e: Exception =>
    try {
      if (writer != null) {
        writer.stop(success = false)
      }
    } catch {
      case e: Exception =>
        log.debug("Could not stop writer", e)
    }
    throw e
}
}

这个方法还是大概逻辑还是很简单的,主要就是通过rdd的iterator方法获取当前task对应的分区的计算结果(结果一一个迭代器的形式返回)利用shuffleManager通过blockManager写入到文件block中,然后将block信息传回driver上报给BlockManagerMaster。
所以实际上重要的步骤有两个:通过RDD的计算链获取计算结果;将计算结果经过排序和分区写到文件中。
这里我先分析第二个步骤。

SortShuffleWriter.write

spark在2.0之后shuffle管理器改成了排序shuffle管理器,即SortShuffleManager,所以这里通过SortShuffleManager管理器获取到的在一般情况下都是SortShuffleWriter,当然在满足bypass条件(map端不需要combine,并且分区数小于200)的情况下会使用BypassMergeSortShuffleWriter。

override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
  // map端进行合并的情况,此时用户应该提供聚合器和顺序
  require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
  new ExternalSorter[K, V, C](
    context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
  // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
  // care whether the keys get sorted in each partition; that will be done on the reduce side
  // if the operation being run is sortByKey.
  new ExternalSorter[K, V, V](
    context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
// 将map数据全部写入排序器中,
// 这个过程中可能会生成多个溢写文件
sorter.insertAll(records)

// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
// mapId就是shuffleMap端RDD的partitionId
// 获取这个map分区的shuffle输出文件名
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
// 加一个uuid后缀
val tmp = Utils.tempFileWith(output)
try {
  val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
  // 这一步将溢写到的磁盘的文件和内存中的数据进行归并排序,
  // 并溢写到一个文件中,这一步写的文件是临时文件名
  val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
  // 这一步主要是写入索引文件,使用move方法原子第将临时索引和临时数据文件重命名为正常的文件名
  shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
  // 返回一个状态对象,包含shuffle服务Id和各个分区数据在文件中的位移
  mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
  if (tmp.exists() && !tmp.delete()) {
    logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
  }
}
}

总结一下这个方法的主要逻辑:

  • 首先获取一个排序器,并检查是否有map端的合并器
  • 将rdd计算结果数据写入排序器,过程中可能会溢写过个磁盘文件
  • 最后将多个碎小的溢写文件和内存缓冲区的数据进行归并排序,写到一个文件中
  • 将每个分区数据在文件中的偏移量写到一个索引文件中,用于reduce阶段拉取数据时使用
  • 返回一个MapStatus对象,封装了当前executor上的blockManager的id和每个分区在数据文件中的位移量

总结

本篇先分析到这里。剩下的代码都是属于排序器内部的对数据的排序和溢写文件的逻辑。这部分内容值得写一篇文章来单独分析。
总结一下任务在executor端的执行流程:

  • 首先executor端的rpc服务端点收到LaunchTask的消息,并对传过来的任务数据进行反序列化成TaskDescription
  • 将任务交给Executor对象运行
  • Executor根据传过来的TaskDescription对象创建一个TaskRunner对象,并放到线程池中运行。这里的线程池用的是Executors.newCachedThreadPool,空闲是不会有线程在跑
  • TaskRunner对任务进一步反序列化,调用Task.run方法执行任务运行逻辑
  • ShuffleMapTask类型的任务会将rdd计算结果数据经过排序合并之后写到一个文件中,并写一个索引文件
  • 任务运行完成后会更新一些任务统计量和度量系统中的一些统计量
  • 最后会根据结果序列化后的大小选择不同的方式将结果传回driver。

posted on 2019-06-04 00:48  _朱葛  阅读(5152)  评论(0编辑  收藏  举报