Spark source code analysis: task deserialization and execution
1. ==> Receive the message: org.apache.spark.executor.CoarseGrainedExecutorBackend#receive
case LaunchTask(data) =>
  if (executor == null) {
    exitExecutor(1, "Received LaunchTask command but executor was null")
  } else {
    val taskDesc = TaskDescription.decode(data.value)
    logInfo("Got assigned task " + taskDesc.taskId)
    executor.launchTask(this, taskDesc)
  }
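The decoded TaskDescription carries everything the executor needs to run the task. A minimal sketch of just the fields this walkthrough uses (not the real class, which has more members and a custom binary encoding):

// Sketch only: the TaskDescription fields read by the Executor code below.
case class TaskDescriptionSketch(
    taskId: Long,
    attemptNumber: Int,
    addedFiles: Map[String, Long],     // file name -> timestamp
    addedJars: Map[String, Long],      // jar name -> timestamp
    properties: java.util.Properties,
    serializedTask: java.nio.ByteBuffer)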
2. ==> org.apache.spark.executor.Executor#launchTask
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  val tr = new TaskRunner(context, taskDescription)
  runningTasks.put(taskDescription.taskId, tr)
  threadPool.execute(tr)
}
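The pattern is simple: register the TaskRunner in a concurrent map keyed by taskId, then hand it to a thread pool. A stand-alone sketch of the same pattern (not Spark code; all names here are illustrative):

import java.util.concurrent.{ConcurrentHashMap, Executors}

// Register first, execute on a pool, and drop the bookkeeping entry when the body finishes.
object LaunchPatternSketch {
  private val running = new ConcurrentHashMap[Long, Runnable]()
  private val pool = Executors.newCachedThreadPool()

  def launch(taskId: Long, body: () => Unit): Unit = {
    val runner = new Runnable {
      override def run(): Unit = try body() finally running.remove(taskId)
    }
    running.put(taskId, runner)
    pool.execute(runner)
  }
}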
3. ==> org.apache.spark.executor.Executor.TaskRunner#run
override def run(): Unit = {
  threadId = Thread.currentThread.getId
  Thread.currentThread.setName(threadName)
  val threadMXBean = ManagementFactory.getThreadMXBean
  val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
  // Download missing dependencies (files and JARs)
  updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)

  // Deserialize the actual Task object
  task = ser.deserialize[Task[Any]](
    taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
  task.localProperties = taskDescription.properties
  task.setTaskMemoryManager(taskMemoryManager)

  val value = Utils.tryWithSafeFinally {
    val res = task.run(
      taskAttemptId = taskId,
      attemptNumber = taskDescription.attemptNumber,
      metricsSystem = env.metricsSystem)
    threwException = false
    res
  } {
    val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
    val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
  }

  // Handle the execution result
  val resultSer = env.serializer.newInstance()
  val beforeSerialization = System.currentTimeMillis()
  val valueBytes = resultSer.serialize(value)
  val afterSerialization = System.currentTimeMillis()

  // Note: accumulator updates must be collected after TaskMetrics is updated
  val accumUpdates = task.collectAccumulatorUpdates()
  // TODO: do not serialize value twice
  val directResult = new DirectTaskResult(valueBytes, accumUpdates)
  val serializedDirectResult = ser.serialize(directResult)
  val resultSize = serializedDirectResult.limit()

  // directSend = sending directly back to the driver
  val serializedResult: ByteBuffer = {
    if (maxResultSize > 0 && resultSize > maxResultSize) {
      logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
        s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
        s"dropping it.")
      ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
    } else if (resultSize > maxDirectResultSize) {
      val blockId = TaskResultBlockId(taskId)
      env.blockManager.putBytes(
        blockId,
        new ChunkedByteBuffer(serializedDirectResult.duplicate()),
        StorageLevel.MEMORY_AND_DISK_SER)
      logInfo(
        s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
      ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
    } else {
      logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
      serializedDirectResult
    }
  }

  setTaskFinishedAndClearInterruptStatus()
  execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
}
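The last part of run() decides how the result travels back to the driver. A condensed sketch of just that decision, mirroring the three branches above (the returned strings are only labels for this sketch; the thresholds come from spark.driver.maxResultSize and the direct-result size limit):

def routeResult(resultSize: Long, maxResultSize: Long, maxDirectResultSize: Long): String =
  if (maxResultSize > 0 && resultSize > maxResultSize)
    "too large: dropped, only an IndirectTaskResult stub goes back"
  else if (resultSize > maxDirectResultSize)
    "stored in the BlockManager, an IndirectTaskResult(blockId, size) goes back"
  else
    "small enough: the serialized DirectTaskResult goes back inline with the status update"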
==> org.apache.spark.executor.Executor#updateDependencies
/**
 * Download any missing dependencies if we receive a new set of files and JARs from the
 * SparkContext. Also adds any new JARs we fetched to the class loader.
 */
private def updateDependencies(newFiles: Map[String, Long], newJars: Map[String, Long]) {
  lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
  synchronized {
    // Fetch missing dependencies
    for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
      logInfo("Fetching " + name + " with timestamp " + timestamp)
      // Fetch file with useCache mode, close cache for local mode.
      Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
        env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
      currentFiles(name) = timestamp
    }
    for ((name, timestamp) <- newJars) {
      val localName = new URI(name).getPath.split("/").last
      val currentTimeStamp = currentJars.get(name)
        .orElse(currentJars.get(localName))
        .getOrElse(-1L)
      if (currentTimeStamp < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        // Fetch file with useCache mode, close cache for local mode.
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
          env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
        currentJars(name) = timestamp
        // Add it to our class loader
        val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
        if (!urlClassLoader.getURLs().contains(url)) {
          logInfo("Adding " + url + " to class loader")
          urlClassLoader.addURL(url)
        }
      }
    }
  }
}
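updateDependencies is driven by what the driver registered. A small driver-side example of the public calls that make newFiles/newJars non-empty (the paths are placeholders):

import org.apache.spark.{SparkConf, SparkContext}

object DependencyDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("dep-demo").setMaster("local[2]"))
    sc.addFile("hdfs:///tmp/lookup.csv")  // placeholder path; appears as a newFiles entry on executors
    sc.addJar("hdfs:///tmp/udfs.jar")     // placeholder path; appears as a newJars entry on executors
    // Inside a task, the fetched copy of the file can be located with SparkFiles.get("lookup.csv").
    sc.stop()
  }
}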
==> org.apache.spark.scheduler.Task#run
final def run(
    taskAttemptId: Long,
    attemptNumber: Int,
    metricsSystem: MetricsSystem): T = {
  SparkEnv.get.blockManager.registerTask(taskAttemptId)
  val taskContext = new TaskContextImpl(
    stageId,
    stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
    partitionId,
    taskAttemptId,
    attemptNumber,
    taskMemoryManager,
    localProperties,
    metricsSystem,
    metrics)

  context = if (isBarrier) {
    new BarrierTaskContext(taskContext)
  } else {
    taskContext
  }

  TaskContext.setTaskContext(context)
  taskThread = Thread.currentThread()

  if (_reasonIfKilled != null) {
    kill(interruptThread = false, _reasonIfKilled)
  }

  new CallerContext(
    "TASK",
    SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
    appId,
    appAttemptId,
    jobId,
    Option(stageId),
    Option(stageAttemptId),
    Option(taskAttemptId),
    Option(attemptNumber)).setCurrentContext()

  try {
    // Task is only a template (abstract) class; the concrete implementations are
    // ResultTask and ShuffleMapTask.
    runTask(context)
  }
}
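Because Task#run installs the context via TaskContext.setTaskContext, user code running inside the task can read it back through the public API. A small illustration (the RDD `rdd` is assumed to already exist; the listener API is the same one JdbcRDD uses further below):

import org.apache.spark.TaskContext

val withListener = rdd.mapPartitions { iter =>
  val ctx = TaskContext.get()
  ctx.addTaskCompletionListener[Unit] { _ =>
    println(s"stage ${ctx.stageId()}, partition ${ctx.partitionId()} finished")
  }
  iter
}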
==> org.apache.spark.scheduler.ShuffleMapTask#runTask
ShuffleMapTask splits the RDD's elements into buckets based on the partitioner specified by the ShuffleDependency (HashPartitioner by default).
Its core call is RDD.iterator, which ultimately invokes compute (roughly fn(context, index, partition)).
Once the RDD has been computed, the resulting partition data is handed to the ShuffleWriter, which uses the partitioner to write each record into its corresponding shuffle partition. A minimal sketch of the hash-partitioning idea is shown below, before the runTask source.
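This sketch shows only the bucketing idea, not Spark's HashPartitioner class itself: a key's hashCode is mapped onto a non-negative bucket index in [0, numPartitions).

class SimpleHashPartitioner(numPartitions: Int) {
  def getPartition(key: Any): Int = {
    if (key == null) 0
    else {
      val mod = key.hashCode % numPartitions
      if (mod < 0) mod + numPartitions else mod  // keep the index non-negative
    }
  }
}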
// ShuffleMapTask splits the RDD's elements into buckets,
// based on the partitioner specified by the ShuffleDependency (HashPartitioner by default).
private[spark] class ShuffleMapTask(
  ...
  // ShuffleMapTask#runTask returns a MapStatus
  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    // Deserialize the task binary
    val ser = SparkEnv.get.closureSerializer.newInstance()
    // Obtain the RDD and the ShuffleDependency
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L

    var writer: ShuffleWriter[Any, Any] = null
    try {
      // Get the ShuffleManager
      val manager = SparkEnv.get.shuffleManager
      // Get the ShuffleWriter
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      // Core logic: call the RDD's iterator method with the partition to process.
      // Once the RDD is computed, the resulting partition data is written by the
      // ShuffleWriter, via the partitioner, into the corresponding shuffle partitions.
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      // Return a MapStatus, which records where the ShuffleMapTask's output is stored
      // (essentially the BlockManager information).
      writer.stop(success = true).get
    }
  }
  ...
}
==> org.apache.spark.scheduler.ResultTask#runTask
override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L

  // Directly invoke the user-defined function on the partition iterator
  func(context, rdd.iterator(partition, context))
}
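The `func` deserialized here is supplied by the action on the driver side. For example, collect() is implemented in this spirit: one function applied to every partition's iterator, with the per-partition results gathered by the driver. A sketch using the public runJob API (`sc` and `rdd`, an RDD[Int], are assumed to exist):

// The runJob overload used here hides the TaskContext parameter of the full signature.
val perPartition: Array[Array[Int]] = sc.runJob(rdd, (iter: Iterator[Int]) => iter.toArray)
val collected: Array[Int] = perPartition.flatten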
==> org.apache.spark.rdd.RDD#iterator
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    // A storage level is set (cache/persist): read from the cache, or compute and cache
    getOrCompute(split, context)
  } else {
    // No caching requested: compute directly, or read from a checkpoint
    computeOrReadCheckpoint(split, context)
  }
}
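Which branch iterator() takes depends on whether a storage level was set via cache()/persist(). A small driver-side illustration (`sc` is an assumed SparkContext, the data is a placeholder):

import org.apache.spark.storage.StorageLevel

// With a storage level set, iterator() takes the getOrCompute path;
// without one, every action recomputes (or reads a checkpoint).
val lengths = sc.parallelize(Seq("a", "bb", "ccc")).map(_.length)
  .persist(StorageLevel.MEMORY_ONLY)

lengths.count()  // first action: partitions are computed and stored via the BlockManager
lengths.count()  // second action: partitions are served from the cache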
==> org.apache.spark.rdd.RDD#computeOrReadCheckpoint
/**
 * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
 */
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = {
  if (isCheckpointedAndMaterialized) {
    firstParent[T].iterator(split, context)
  } else {
    // Core method. compute() is abstract here; each concrete RDD subclass
    // (e.g. MapPartitionsRDD, JdbcRDD) provides its own implementation.
    compute(split, context)
  }
}
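The isCheckpointedAndMaterialized branch is only taken after the RDD has actually been checkpointed, in which case the data is read back through the checkpointed parent instead of calling compute() again. A driver-side example (`sc` is an assumed SparkContext, the checkpoint directory is a placeholder):

sc.setCheckpointDir("hdfs:///tmp/checkpoints")
val doubled = sc.parallelize(1 to 100).map(_ * 2)
doubled.checkpoint()
doubled.count()  // runs the job and materializes the checkpoint files
doubled.count()  // this and later actions read back through the checkpointed parent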
Demo: two concrete compute implementations (MapPartitionsRDD and JdbcRDD):
class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false,
    isFromBarrier: Boolean = false,
    isOrderSensitive: Boolean = false)
  extends RDD[U](prev) {

  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))
}

class JdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    lowerBound: Long,
    upperBound: Long,
    numPartitions: Int,
    mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  override def getPartitions: Array[Partition] = {
    // bounds are inclusive, hence the + 1 here and - 1 on end
    val length = BigInt(1) + upperBound - lowerBound
    (0 until numPartitions).map { i =>
      val start = lowerBound + ((i * length) / numPartitions)
      val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
      new JdbcPartition(i, start.toLong, end.toLong)
    }.toArray
  }

  override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] {
    context.addTaskCompletionListener[Unit] { context => closeIfNeeded() }
    val part = thePart.asInstanceOf[JdbcPartition]
    val conn = getConnection()
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
    val url = conn.getMetaData.getURL
    val rs = stmt.executeQuery()

    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    override def close() { }
  }
}
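For reference, a hedged usage example of the JdbcRDD shown above (`sc`, the connection string, table, and credentials are placeholders; the SQL is expected to contain two `?` marks, which JdbcRDD binds to each partition's lower and upper bound):

import java.sql.{DriverManager, ResultSet}
import org.apache.spark.rdd.JdbcRDD

val users = new JdbcRDD(
  sc,
  () => DriverManager.getConnection("jdbc:mysql://db-host/app", "user", "password"),
  "SELECT id, name FROM users WHERE id >= ? AND id <= ?",
  lowerBound = 1L,
  upperBound = 10000L,
  numPartitions = 4,
  mapRow = (rs: ResultSet) => (rs.getLong("id"), rs.getString("name")))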