Spark 源码分析 -- Task
Task 是介于 DAGScheduler 和 TaskScheduler 之间的接口。
在 DAGScheduler 中，需要把 DAG 中每个 stage 的每个 partition 封装成一个 task，
最终把这些 task 组成的 TaskSet 提交给 TaskScheduler。
/**
 * A task to execute on a worker node.
 *
 * Task is the interface between the DAGScheduler and the TaskScheduler: each
 * partition of each stage is wrapped in a Task. Concrete subclasses implement
 * [[run]].
 *
 * @param stageId id of the stage this task belongs to
 * @tparam T type of the value returned by [[run]]
 */
private[spark] abstract class Task[T](val stageId: Int) extends Serializable {

  /** Metrics for this task; the subclasses shown here set it at the start of `run`. */
  var metrics: Option[TaskMetrics] = None

  // Map output tracker epoch. Will be set by TaskScheduler.
  var epoch: Long = -1

  /**
   * Where this task would preferably run. Spark is locality-aware, so subclasses
   * may override this; by default there is no preference.
   */
  def preferredLocations: Seq[TaskLocation] = Nil

  /** The core of the task: execute it for the given attempt and return its result. */
  def run(attemptId: Long): T
}
TaskContext
用于记录 TaskMetrics，以及保存在 Task 执行过程中用到的 callback（回调函数）。
比如对于 HadoopRDD，task 完成时需要通过注册的 callback 关闭 input stream。
package org.apache.spark
/**
 * Per-task context: carries the task's identifiers and metrics, and collects
 * callbacks to run when the task completes (for example, HadoopRDD registers
 * one to close its input stream).
 *
 * @param stageId        id of the stage the task belongs to
 * @param splitId        index of the partition the task processes
 * @param attemptId      attempt id the task was run with
 * @param runningLocally whether the task is running locally (the tasks shown here pass false)
 * @param taskMetrics    metrics and data recorded while the task executes
 */
class TaskContext(
    val stageId: Int,
    val splitId: Int,
    val attemptId: Long,
    val runningLocally: Boolean = false,
    val taskMetrics: TaskMetrics = TaskMetrics.empty()
  ) extends Serializable {

  // NOTE(review): because this field is @transient, it would be null on a
  // deserialized copy of this context (the val initializer does not re-run on
  // deserialization). Presumably contexts are only used where they were
  // created -- confirm before serializing a TaskContext.
  @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]

  /**
   * Add a callback function to be executed on task completion. An example use
   * is for HadoopRDD to register a callback to close the input stream.
   */
  def addOnCompleteCallback(f: () => Unit): Unit = {
    onCompleteCallbacks += f
  }

  /**
   * Run every registered completion callback, in registration order.
   *
   * A callback that throws does not stop the remaining callbacks from running,
   * since skipping them could leak resources such as input streams. Once all
   * callbacks have been attempted, the first exception raised (if any) is
   * rethrown.
   */
  def executeOnCompleteCallbacks(): Unit = {
    var firstError: Option[Exception] = None
    onCompleteCallbacks.foreach { callback =>
      try {
        callback()
      } catch {
        case e: Exception =>
          if (firstError.isEmpty) firstError = Some(e)
      }
    }
    firstError.foreach(e => throw e)
  }
}
ResultTask
对应于 Result Stage，直接对 RDD 的 partition 计算并返回结果（例如 count 的值）。
package org.apache.spark.scheduler
/**
 * A task of a result stage: it applies `func` to one partition of `rdd` and
 * returns the value directly (for example, the count of that partition).
 *
 * NOTE(review): `split` and the `Externalizable` read/write methods are
 * referenced or required here but are not shown in this excerpt.
 */
private[spark] class ResultTask[T, U](
    stageId: Int,
    var rdd: RDD[T],
    var func: (TaskContext, Iterator[T]) => U,
    var partition: Int,
    @transient locs: Seq[TaskLocation],
    var outputId: Int)
  extends Task[U](stageId) with Externalizable {

  /**
   * Evaluate `func` over this task's partition and return its value.
   * Completion callbacks registered on the task context always run, even
   * when `func` throws.
   */
  override def run(attemptId: Long): U = {
    val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
    metrics = Some(taskContext.taskMetrics)
    try {
      val records = rdd.iterator(split, taskContext)
      func(taskContext, records)
    } finally {
      taskContext.executeOnCompleteCallbacks()
    }
  }
}
ShuffleMapTask
对应于 ShuffleMap Stage，其产生的结果会作为其他 stage 的输入。
package org.apache.spark.scheduler
/**
 * A task of a shuffle-map stage. It splits the records of one partition of
 * `rdd` into buckets, one per output partition of `dep.partitioner`, and writes
 * them through the block manager, so that other stages can read them as input.
 *
 * NOTE(review): `split` and the `Externalizable` read/write methods are used
 * or required here but are not shown in this excerpt.
 */
private[spark] class ShuffleMapTask(
    stageId: Int,
    var rdd: RDD[_],
    var dep: ShuffleDependency[_,_],
    var partition: Int,
    @transient private var locs: Seq[TaskLocation])
  extends Task[MapStatus](stageId)
  with Externalizable
  with Logging {

  /**
   * Write this task's partition into the shuffle buckets and commit them.
   *
   * @return a [[MapStatus]] holding this block manager's id and the compressed
   *         size of every bucket written
   * @throws Exception any exception from the write path is rethrown after the
   *         partial writes have been reverted
   */
  override def run(attemptId: Long): MapStatus = {
    // Number of target partitions (buckets) of the shuffle, from the dependency's partitioner.
    val numOutputSplits = dep.partitioner.numPartitions

    val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
    metrics = Some(taskContext.taskMetrics)

    // Shuffle output is written through the block manager.
    val blockManager = SparkEnv.get.blockManager
    // Declared outside `try` so that the catch/finally clauses can clean up.
    var shuffle: ShuffleBlocks = null
    var buckets: ShuffleWriterGroup = null

    try {
      // Obtain all the block writers for shuffle blocks: get this shuffle's
      // handle (by shuffle id and number of target partitions), then the
      // bucket writers for this map partition.
      val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
      shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
      buckets = shuffle.acquireWriters(partition)

      // Write the map output to its associated buckets: each record goes to the
      // bucket the partitioner assigns to its key.
      for (elem <- rdd.iterator(split, taskContext)) {
        val pair = elem.asInstanceOf[Product2[Any, Any]]
        val bucketId = dep.partitioner.getPartition(pair._1)
        buckets.writers(bucketId).write(pair)
      }

      // Commit the writes so other stages can find these blocks via the shuffle
      // id, and record the size of each bucket block.
      val sizes: Array[Long] = buckets.writers.map { writer: BlockObjectWriter =>
        writer.commit()
        writer.close()
        writer.size()
      }
      val compressedSizes: Array[Byte] = sizes.map(size => MapOutputTracker.compressSize(size))

      // Update shuffle metrics (total uncompressed bytes across all buckets).
      val shuffleMetrics = new ShuffleWriteMetrics
      shuffleMetrics.shuffleBytesWritten = sizes.sum
      taskContext.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)

      // Result: the block manager id plus the per-bucket sizes, which (per the
      // original notes) is registered with the MapOutputTracker.
      new MapStatus(blockManager.blockManagerId, compressedSizes)
    } catch {
      case e: Exception =>
        // If there is an exception from running the task, revert the partial writes
        // and throw the exception upstream to Spark.
        if (buckets != null) {
          buckets.writers.foreach(_.revertPartialWrites())
        }
        throw e
    } finally {
      // Release the writers back to the shuffle block manager.
      if (shuffle != null && buckets != null) {
        shuffle.releaseWriters(buckets)
      }
      // Execute the callbacks on task completion.
      taskContext.executeOnCompleteCallbacks()
    }
  }
}
TaskSet
用于封装一个 stage 中需要运行的 task（通常对应该 stage 尚缺失的 partition），以便一起提交给 TaskScheduler。
package org.apache.spark.scheduler

/**
 * A set of tasks submitted together to the low-level TaskScheduler, usually representing
 * missing partitions of a particular stage.
 *
 * @param tasks      the tasks in this set
 * @param stageId    id of the stage the tasks belong to
 * @param attempt    attempt number of this submission of the stage
 * @param priority   priority value of the set (presumably used for scheduling order -- confirm)
 * @param properties properties carried along with the set
 */
private[spark] class TaskSet(
    val tasks: Array[Task[_]],
    val stageId: Int,
    val attempt: Int,
    val priority: Int,
    val properties: Properties) {

  /** Identifier of the form `<stageId>.<attempt>`, e.g. "3.0". */
  val id: String = List(stageId, attempt).mkString(".")

  override def toString: String = "TaskSet " + id
}