Spark 1.1.0 源码阅读 - Executor
1. executor上执行launchTask
/**
 * Executor.launchTask: wraps a serialized task in a [[TaskRunner]], records the runner in
 * `runningTasks` under its task id, and submits it to `threadPool` for execution.
 *
 * @param context        backend the runner uses to report task status updates
 * @param taskId         id of the task being launched
 * @param taskName       human-readable task name (used in log messages)
 * @param serializedTask serialized task, i.e. its dependencies plus the serialized Task object
 */
def launchTask(
    context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
  val tr = new TaskRunner(context, taskId, taskName, serializedTask)
  // Keep a handle on the runner, keyed by task id, before it starts executing.
  runningTasks.put(taskId, tr)
  threadPool.execute(tr)
}
2. executor上执行TaskRunner的run
/**
 * Executor.TaskRunner: the Runnable that executes one serialized task.
 *
 * `run()` reports RUNNING through `execBackend`, deserializes the task's file/jar
 * dependencies and the [[Task]] object, and then calls `task.run`. The mutable fields are
 * `@volatile`, presumably because `kill` is invoked from a different thread than `run`
 * (confirm against the caller of `kill`).
 */
class TaskRunner(
    execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
  extends Runnable {

  // Set by kill(); checked in run() right after deserialization, so a kill that arrives
  // before `task` exists is not lost.
  @volatile private var killed = false
  @volatile var task: Task[Any] = _
  @volatile var attemptedTask: Option[Task[Any]] = None

  /**
   * Marks this runner as killed and, if the Task has already been deserialized, forwards
   * the kill request to it.
   *
   * @param interruptThread passed through unchanged to `Task.kill`
   */
  def kill(interruptThread: Boolean) {
    logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
    killed = true
    if (task != null) {
      task.kill(interruptThread)
    }
  }

  override def run() {
    val startTime = System.currentTimeMillis()
    SparkEnv.set(env)
    Thread.currentThread.setContextClassLoader(replClassLoader)
    val ser = SparkEnv.get.closureSerializer.newInstance()
    logInfo(s"Running $taskName (TID $taskId)")
    execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    var taskStart: Long = 0
    // Total GC time so far, summed over all collectors; re-evaluated on every call.
    def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
    val startGCTime = gcTime

    try {
      SparkEnv.set(env)
      Accumulators.clear()
      // Deserialize the task's files, jars and the serialized Task bytes.
      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
      updateDependencies(taskFiles, taskJars)
      // Deserialize the Task object itself.
      task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)

      // If this task has been killed before we deserialized it, let's quit now. Otherwise,
      // continue executing the task.
      if (killed) {
        // Throw an exception rather than returning, because returning within a try{} block
        // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
        // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
        // for the task.
        throw new TaskKilledException
      }

      attemptedTask = Some(task)
      logDebug("Task " + taskId + "'s epoch is " + task.epoch)
      env.mapOutputTracker.updateEpoch(task.epoch)

      // Run the actual task and measure its runtime.
      taskStart = System.currentTimeMillis()
      val value = task.run(taskId.toInt)
      val taskFinish = System.currentTimeMillis()

      // If the task has been killed, let's fail it.
      if (task.killed) {
        throw new TaskKilledException
      }
      // ... (excerpt ends here: the rest of run(), including the remainder of this try
      //      block and the closing braces of run() and TaskRunner, is not shown)
3. task.run
/**
 * Base class of the units of work run on an executor. Concrete subclasses (such as
 * ResultTask and ShuffleMapTask) implement `runTask`; only `run` is shown in this excerpt.
 */
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {

  /**
   * Runs one attempt of this task.
   *
   * It first builds the TaskContext and records the local host name in the task metrics.
   * Next it remembers the executing thread and re-applies a kill that was flagged before
   * execution started. Finally it delegates to the subclass's `runTask`.
   *
   * @param attemptId attempt id stored in the TaskContext for this run
   * @return whatever the concrete subclass's `runTask` returns
   */
  final def run(attemptId: Long): T = {
    context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
    context.taskMetrics.hostname = Utils.localHostName()
    // Presumably kept so that kill(interruptThread = true) can interrupt this thread — confirm.
    taskThread = Thread.currentThread()
    if (_killed) {
      kill(interruptThread = false)
    }
    runTask(context)
  }
  // ... (excerpt ends here: the other members of Task and its closing brace are not shown)
4. Task 是抽象类，具体子类（ResultTask 和 ShuffleMapTask）各自实现 runTask；Task.run 最后调用的就是对应子类的 runTask。
a. ResultTask
/**
 * ResultTask.runTask: deserializes the (rdd, func) pair from the broadcast `taskBinary` and
 * applies `func` to the iterator over this task's partition.
 *
 * `context.markTaskCompleted()` runs in a `finally`, so it also runs when `func` throws.
 *
 * @param context context of the current task attempt
 * @return the value produced by `func` for this partition
 */
override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

  metrics = Some(context.taskMetrics)
  try {
    func(context, rdd.iterator(partition, context))
  } finally {
    context.markTaskCompleted()
  }
}
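b. ShuffleMapTask（runTask 通过 SparkEnv.get.shuffleManager 取得 writer，并把本分区的数据交给 writer.write。其后两段代码依次是：Spark 1.1.0 默认的哈希 shuffle 中 HashShuffleWriter.write 的实现，以及为每个 bucket 创建磁盘 writer 的 ShuffleBlockManager.forMapTask。）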
/**
 * ShuffleMapTask.runTask: deserializes the (rdd, dep) pair from the broadcast `taskBinary`.
 * It then obtains a ShuffleWriter for `dep.shuffleHandle` from the shuffle manager and
 * writes every record of this task's partition through it.
 *
 * On an exception, it calls `writer.stop(success = false)` (if a writer was created) and
 * rethrows. `context.markTaskCompleted()` runs in every case.
 *
 * @param context context of the current task attempt
 * @return the MapStatus held in the Option returned by `writer.stop(success = true)`
 *         (`.get` assumes a successful stop always yields a value)
 */
override def runTask(context: TaskContext): MapStatus = {
  // Deserialize the RDD using the broadcast variable.
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

  metrics = Some(context.taskMetrics)
  var writer: ShuffleWriter[Any, Any] = null
  try {
    val manager = SparkEnv.get.shuffleManager
    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
    writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
    return writer.stop(success = true).get
  } catch {
    case e: Exception =>
      // Tell the writer the map output failed (presumably so it can discard partial output).
      if (writer != null) {
        writer.stop(success = false)
      }
      throw e
  } finally {
    context.markTaskCompleted()
  }
}
/**
 * Write a bunch of records to this task's output.
 *
 * If the dependency defines an aggregator and `mapSideCombine` is set, values are first
 * combined by key on the map side. Requesting `mapSideCombine` without an aggregator throws
 * IllegalStateException. Each resulting record goes to the writer of the bucket chosen by
 * `dep.partitioner.getPartition(key)`.
 *
 * @param records key/value records of this map task's partition
 * @throws IllegalStateException if map-side combine is requested but no aggregator is defined
 */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
  val iter = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      dep.aggregator.get.combineValuesByKey(records, context)
    } else {
      records
    }
  } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
    throw new IllegalStateException("Aggregator is empty for map-side combine")
  } else {
    records
  }

  // Route every record to the writer of its target (reduce-side) bucket.
  for (elem <- iter) {
    val bucketId = dep.partitioner.getPartition(elem._1)
    shuffle.writers(bucketId).write(elem)
  }
}
1 /** 2 * Get a ShuffleWriterGroup for the given map task, which will register it as complete 3 * when the writers are closed successfully 4 */ 5 def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, 6 writeMetrics: ShuffleWriteMetrics) = { 7 new ShuffleWriterGroup { 8 shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) 9 private val shuffleState = shuffleStates(shuffleId) 10 private var fileGroup: ShuffleFileGroup = null 11 12 val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { 13 fileGroup = getUnusedFileGroup() 14 Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => 15 val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) 16 blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize, 17 writeMetrics) 18 } 19 } else { 20 Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => 21 val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) 22 val blockFile = blockManager.diskBlockManager.getFile(blockId) 23 // Because of previous failures, the shuffle file may already exist on this machine. 24 // If so, remove it. 25 if (blockFile.exists) { 26 if (blockFile.delete()) { 27 logInfo(s"Removed existing shuffle file $blockFile") 28 } else { 29 logWarning(s"Failed to remove existing shuffle file $blockFile") 30 } 31 } 32 blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) 33 } 34 }