Spark 1.1.0 源码阅读：Executor

1. Executor 上执行 launchTask

/**
 * Launches one task on this executor.
 *
 * The serialized task is wrapped in a [[TaskRunner]], recorded in `runningTasks` under its
 * task ID, and then handed to the executor's thread pool for execution.
 *
 * @param context        backend that the runner reports task status to
 * @param taskId         ID used as the key in `runningTasks`
 * @param taskName       human-readable task name (used in log messages)
 * @param serializedTask the task bytes the runner will deserialize
 */
def launchTask(
    context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer): Unit = {
  val runner = new TaskRunner(context, taskId, taskName, serializedTask)
  // Register before submitting, so the runner is already tracked when a pool thread picks it up.
  runningTasks.put(taskId, runner)
  threadPool.execute(runner)
}

2. Executor 上执行 TaskRunner 的 run 方法

 1  class TaskRunner(  // Runnable that deserializes and runs one serialized task; submitted to the thread pool by launchTask
 2       execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
 3     extends Runnable {
 4 
 5     @volatile private var killed = false  // set by kill(); checked right after deserialization so an early-killed task never runs
 6     @volatile var task: Task[Any] = _  // null until the task is deserialized inside run()
 7     @volatile var attemptedTask: Option[Task[Any]] = None  // becomes Some(task) once deserialization and the early kill check pass
 8 
 9     def kill(interruptThread: Boolean) {  // marks this runner killed and forwards the kill to the task if it is already deserialized
10       logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
11       killed = true
12       if (task != null) {
13         task.kill(interruptThread)
14       }
15     }
16 
17     override def run() {  // executed on a thread-pool thread (see threadPool.execute in launchTask)
18       val startTime = System.currentTimeMillis()
19       SparkEnv.set(env)
20       Thread.currentThread.setContextClassLoader(replClassLoader)
21       val ser = SparkEnv.get.closureSerializer.newInstance()  // used below to deserialize the Task object
22       logInfo(s"Running $taskName (TID $taskId)")
23       execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)  // report RUNNING to the backend before doing any work
24       var taskStart: Long = 0
25       def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum  // total collection time over all GC beans; a def, so re-read on every use
26       val startGCTime = gcTime  // GC-time baseline; presumably subtracted later (that code is not shown in this excerpt)
27 
28       try {
29         SparkEnv.set(env)  // NOTE(review): env was already set at the top of run(); this repeat looks redundant
30         Accumulators.clear()
31         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)  // deserialize the task's files, jars and task bytes
32         updateDependencies(taskFiles, taskJars)
33         task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)  // deserialize the Task object itself
34 
35         // If this task has been killed before we deserialized it, let's quit now. Otherwise,
36         // continue executing the task.
37         if (killed) {
38           // Throw an exception rather than returning, because returning within a try{} block
39           // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
40           // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
41           // for the task.
42           throw new TaskKilledException
43         }
44 
45         attemptedTask = Some(task)
46         logDebug("Task " + taskId + "'s epoch is " + task.epoch)
47         env.mapOutputTracker.updateEpoch(task.epoch)  // bring the map output tracker up to this task's epoch
48 
49         // Run the actual task and measure its runtime.
50         taskStart = System.currentTimeMillis()
51         val value = task.run(taskId.toInt)  // NOTE(review): Task.run takes a Long attemptId, so .toInt silently truncates task IDs above Int.MaxValue
52         val taskFinish = System.currentTimeMillis()
53 
54         // If the task has been killed, let's fail it.
55         if (task.killed) {
56           throw new TaskKilledException
57         }

3. Task.run

 1 private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {  // base class of the concrete tasks; T is the result type returned by run()
 2 
 3   final def run(attemptId: Long): T = {  // called by TaskRunner: sets up per-attempt state, then delegates to the subclass's runTask
 4     context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)  // fresh TaskContext for this attempt
 5     context.taskMetrics.hostname = Utils.localHostName()
 6     taskThread = Thread.currentThread()  // remember the executing thread (presumably so a later kill can interrupt it; kill() is not shown)
 7     if (_killed) {  // NOTE(review): presumably re-applies a kill that arrived before the context existed; _killed/kill() are not shown here
 8       kill(interruptThread = false)
 9     }
10     runTask(context)  // implemented by each concrete subclass
11   }

4. Task 是抽象类，其具体子类（ResultTask 和 ShuffleMapTask）各自实现了 runTask；Task.run 最终会调用相应子类的 runTask。

a. ResultTask

 /**
  * Runs this result task.
  *
  * The (RDD, function) pair is deserialized from the broadcast task binary. The function is then
  * applied to this task's partition iterator. The task context is always marked completed,
  * whether the function returns normally or throws.
  *
  * @param context the context of the current task attempt
  * @return the value computed by the user function for this partition
  */
 override def runTask(context: TaskContext): U = {
   val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
   val binary = ByteBuffer.wrap(taskBinary.value)
   val loader = Thread.currentThread.getContextClassLoader
   val (rdd, func) =
     closureSerializer.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](binary, loader)

   metrics = Some(context.taskMetrics)
   try {
     func(context, rdd.iterator(partition, context))
   } finally {
     // Runs on both the success and the failure path.
     context.markTaskCompleted()
   }
 }

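b. ShuffleMapTask（它的 runTask 通过 ShuffleWriter 写出 map 输出；下面依次附上 runTask、writer 的 write 方法，以及为每个 bucket 提供写入器的 forMapTask 方法）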

 /**
  * Runs this shuffle map task.
  *
  * The (RDD, ShuffleDependency) pair is deserialized from the broadcast task binary. This
  * partition's records are then written through a writer obtained from the shuffle manager.
  * The task returns the MapStatus produced when that writer is stopped successfully.
  *
  * If writing fails with an Exception, the writer (if it was created) is stopped with
  * `success = false` and the original exception is rethrown. An Exception raised by that
  * best-effort cleanup does not replace the original failure. The task context is marked
  * completed in every case.
  *
  * @param context the context of the current task attempt
  * @return the map output status reported by the writer
  */
 override def runTask(context: TaskContext): MapStatus = {
   val ser = SparkEnv.get.closureSerializer.newInstance()
   val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
     ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

   metrics = Some(context.taskMetrics)
   var writer: ShuffleWriter[Any, Any] = null
   try {
     val manager = SparkEnv.get.shuffleManager
     writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
     writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
     // The try expression itself is the result, so no `return` is needed here.
     writer.stop(success = true).get
   } catch {
     case e: Exception =>
       if (writer != null) {
         try {
           writer.stop(success = false)
         } catch {
           // Deliberately ignored: `e` is the root cause and must be the failure that is reported.
           // Without this guard, an error from the cleanup would mask it.
           case _: Exception =>
         }
       }
       throw e
   } finally {
     context.markTaskCompleted()
   }
 }
 /**
  * Writes a bunch of records to this task's output.
  *
  * How records are prepared depends on the dependency:
  *  - aggregator present and map-side combine on: values are first combined by key;
  *  - aggregator present and map-side combine off: records are written unchanged;
  *  - no aggregator and map-side combine on: an IllegalStateException is thrown;
  *  - no aggregator and map-side combine off: records are written unchanged.
  *
  * Each resulting element goes to the bucket writer chosen by the dependency's partitioner
  * for the element's key.
  *
  * @param records the key/value records produced for this map task
  * @throws IllegalStateException if map-side combine is requested without an aggregator
  */
 override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
   val toWrite = dep.aggregator match {
     case Some(aggregator) if dep.mapSideCombine =>
       aggregator.combineValuesByKey(records, context)
     case Some(_) =>
       records
     case None if dep.mapSideCombine =>
       throw new IllegalStateException("Aggregator is empty for map-side combine")
     case None =>
       records
   }

   toWrite.foreach { element =>
     val bucket = dep.partitioner.getPartition(element._1)
     shuffle.writers(bucket).write(element)
   }
 }

 

 1   /**
 2    * Get a ShuffleWriterGroup for the given map task, which will register it as complete
 3    * when the writers are closed successfully
 4    */
 5   def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
 6       writeMetrics: ShuffleWriteMetrics) = {
 7     new ShuffleWriterGroup {
 8       shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))  // create this shuffle's state only if no map task has created it yet
 9       private val shuffleState = shuffleStates(shuffleId)
10       private var fileGroup: ShuffleFileGroup = null  // only assigned when consolidateShuffleFiles is enabled
11 
12       val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {  // one disk writer per bucket, indexed by bucketId
13         fileGroup = getUnusedFileGroup()  // consolidated mode: bucket writers target files of a shared file group (getUnusedFileGroup is not shown here)
14         Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
15           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
16           blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
17             writeMetrics)
18         }
19       } else {
20         Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>  // non-consolidated mode: one dedicated file per (shuffleId, mapId, bucketId) block
21           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
22           val blockFile = blockManager.diskBlockManager.getFile(blockId)
23           // Because of previous failures, the shuffle file may already exist on this machine.
24           // If so, remove it.
25           if (blockFile.exists) {
26             if (blockFile.delete()) {
27               logInfo(s"Removed existing shuffle file $blockFile")
28             } else {
29               logWarning(s"Failed to remove existing shuffle file $blockFile")
30             }
31           }
32           blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
33         }
34       }

 

posted on 2014-12-11 22:36  Torstan  阅读(302)  评论(0)  编辑  收藏  举报

导航