pyspark对应的scala代码PythonRDD类
pyspark jvm端的scala代码PythonRDD
代码版本为 spark 2.2.0
1.PythonRDD.class
这个rdd类型是python能接入spark的关键
/**
 * RDD of serialized byte arrays obtained by running a Python function over each
 * partition of `parent`.
 *
 * It implements the usual RDD hooks (`compute`, `partitioner`, `getPartitions`).
 * According to the original notes, this is the object returned by the `_jrdd`
 * property of PySpark's `PipelinedRDD`, and `parent` is the `_prev_jrdd` passed in
 * from there (the data-source RDD) — not verifiable from this file alone.
 *
 * @param parent              upstream RDD whose partitions supply the input data
 * @param func                the user's Python computation
 * @param preservePartitoning when true, the parent's partitioner is reported as this
 *                            RDD's partitioner; otherwise no partitioner is reported
 */
private[spark] class PythonRDD(
    parent: RDD[_],
    func: PythonFunction,
    preservePartitoning: Boolean)
  extends RDD[Array[Byte]](parent) {

  /** Value of `spark.buffer.size` (default 65536), handed to the runner. */
  val bufferSize: Int = conf.getInt("spark.buffer.size", 65536)

  /** Value of `spark.python.worker.reuse` (default true), handed to the runner. */
  val reuse_worker: Boolean = conf.getBoolean("spark.python.worker.reuse", true)

  /** Partitions are exactly those of the parent RDD. */
  override def getPartitions: Array[Partition] = firstParent.partitions

  override val partitioner: Option[Partitioner] =
    if (preservePartitoning) {
      firstParent.partitioner
    } else {
      None
    }

  /** Java-API view of this RDD. */
  val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

  /**
   * Computes one partition by delegating to a [[PythonRunner]].
   *
   * Per the original notes, this runner is unrelated to the `PythonRunner` that is
   * launched by spark-submit.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
    val pythonRunner = PythonRunner(func, bufferSize, reuse_worker)
    // Asking the parent for its iterator is what supplies this partition's input data.
    val parentRows: Iterator[_] = firstParent.iterator(split, context)
    pythonRunner.compute(parentRows, split.index, context)
  }
}
2.PythonRunner.class
这个类是rdd内部执行计算时的实体计算类,并不是代码提交时那个启动py4j的PythonRunner
/*
* 这个类做了三件事
* 1.启动pyspark.daemon,由它接收task并启动worker进程来执行接收到的task
* 2.启动writerThread,将数据源的计算结果写到pyspark.worker中
* 3.从pyspark.worker中拉取执行结果
*
* writerThread写的数据就是pyspark中_jrdd计算出来的结果,也就是数据源rdd的数据
*/
private[spark] class PythonRunner(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
isUDF: Boolean,
argOffsets: Array[Array[Int]])
extends Logging {
require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
//python执行的环境和命令
private val envVars = funcs.head.funcs.head.envVars
private val pythonExec = funcs.head.funcs.head.pythonExec
private val pythonVer = funcs.head.funcs.head.pythonVer
private val accumulator = funcs.head.funcs.head.accumulator
/**
 * Runs the Python function(s) over one partition and returns the worker's output.
 *
 * Obtains a Python worker socket, starts a [[WriterThread]] that feeds the worker and a
 * [[MonitorThread]] that watches the task, then returns an iterator that decodes the
 * length-prefixed frames read back from the worker socket.
 *
 * @param inputIterator  input elements of the partition; handed to the writer thread
 * @param partitionIndex index of the partition; handed to the writer thread
 * @param context        task context used for completion callbacks, interruption checks
 *                       and spill metrics
 * @return an [[InterruptibleIterator]] over the byte arrays produced by the worker
 */
def compute(
inputIterator: Iterator[_],
partitionIndex: Int,
context: TaskContext): Iterator[Array[Byte]] = {
// Reference point for the "boot" and "total" timings logged on TIMING_DATA.
val startTime = System.currentTimeMillis
val env = SparkEnv.get
// Comma-separated paths of the disk block manager's local directories.
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
if (reuse_worker) {
envVars.put("SPARK_REUSE_WORKER", "1")
}
// Obtain the socket used to talk to the Python worker process.
// Original note: the worker is backed by pyspark.daemon and this call is said to start
// at most one pyspark.daemon per task — not verifiable from this file.
val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
// Becomes true once the worker has been handed back for reuse (see END_OF_DATA_SECTION below);
// the completion listener reads it to decide whether the socket must be closed.
@volatile var released = false
// Create the writer thread that sends the input data over the socket to the Python worker.
val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)
// On task completion: interrupt the writer thread, and close the socket unless the worker
// was released for reuse.
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
if (!reuse_worker || !released) {
try {
worker.close()
} catch {
case e: Exception =>
logWarning("Failed to close worker socket", e)
}
}
}
writerThread.start()
new MonitorThread(env, worker, context).start()
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
// Iterator that pulls the worker's results off the socket, one frame ahead of the consumer.
val stdoutIterator = new Iterator[Array[Byte]] {
// Returns the prefetched element and, if there was one, prefetches the next.
// NOTE(review): when the iterator is already exhausted this returns null instead of
// throwing NoSuchElementException — confirm callers always check hasNext first.
override def next(): Array[Byte] = {
val obj = _nextObj
if (hasNext) {
_nextObj = read()
}
obj
}
// Reads one frame from the worker. Returns the payload, or null when the data section has
// ended (or the SparkEnv is stopped); throws if the writer thread or the worker failed.
private def read(): Array[Byte] = {
// Surface a failure of the writer thread before blocking on the socket.
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
try {
// Each frame starts with an Int: a positive payload length, 0, or a SpecialLengths marker.
stream.readInt() match {
case length if length > 0 =>
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
case 0 => Array.empty[Byte]
case SpecialLengths.TIMING_DATA =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
val finishTime = stream.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
// Timing data is not an element; continue with the next frame.
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj, StandardCharsets.UTF_8),
writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
// read some accumulator updates:
val numAccumulatorUpdates = stream.readInt()
(1 to numAccumulatorUpdates).foreach { _ =>
val updateLen = stream.readInt()
val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator.add(update)
}
// Check whether the worker is ready to be re-used.
if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
if (reuse_worker) {
env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
released = true
}
}
// null marks the end of the iterator (see hasNext).
null
}
} catch {
// The guards are evaluated in order: interruption first, then a stopped SparkEnv,
// then a recorded writer-thread failure, and finally a plain EOF from the worker.
case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
case e: Exception if env.isStopped =>
logDebug("Exception thrown after context is stopped", e)
null // exit silently
case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
logError("This may have been caused by a prior exception:", writerThread.exception.get)
throw writerThread.exception.get
case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
}
// Initialised when the iterator object is constructed, so the first frame is read
// from the socket before compute() returns.
var _nextObj = read()
override def hasNext: Boolean = _nextObj != null
}
// Return the result-pulling iterator, wrapped so that it honours task interruption.
new InterruptibleIterator(context, stdoutIterator)
}
/**
 * Daemon thread that writes everything the Python worker needs for one partition to the
 * worker socket: task metadata, Python includes, broadcast-variable changes, the
 * serialized command(s) and finally the input data.
 *
 * A failure while writing is not rethrown from the thread; it is stored and exposed via
 * `exception` so that the reading side can report it.
 */
class WriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[_],
partitionIndex: Int,
context: TaskContext)
extends Thread(s"stdout writer for $pythonExec") {
// Set by run() when writing fails while the task is neither completed nor interrupted.
@volatile private var _exception: Exception = null
// Python includes and broadcast variables collected from every function of every chain.
private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
setDaemon(true)
/** Contains the exception thrown while writing the parent iterator to the Python process. */
def exception: Option[Exception] = Option(_exception)
/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
def shutdownOnTaskCompletion() {
assert(context.isCompleted)
this.interrupt()
}
// The main logic lives in run(). In order, it writes to the worker socket:
// the task metadata, includes and broadcast-variable changes, the Python command(s),
// and then the input data that has to be processed.
override def run(): Unit = Utils.logUncaughtExceptions {
try {
// Make the task context visible to code running on this thread.
TaskContext.setTaskContext(context)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
dataOut.writeInt(partitionIndex)
// Python version of driver
PythonRDD.writeUTF(pythonVer, dataOut)
// Write out the TaskContextInfo
dataOut.writeInt(context.stageId())
dataOut.writeInt(context.partitionId())
dataOut.writeInt(context.attemptNumber())
dataOut.writeLong(context.taskAttemptId())
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.size)
for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
// oldBids: broadcast ids recorded for this worker socket; newBids: ids this task needs.
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
// The count is computed before oldBids is mutated by the two loops below.
val cnt = toRemove.size + newBids.diff(oldBids).size
dataOut.writeInt(cnt)
for (bid <- toRemove) {
// remove the broadcast from worker
// A removal is encoded as the negative value -bid - 1.
dataOut.writeLong(- bid - 1) // bid >= 0
oldBids.remove(bid)
}
for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
PythonRDD.writeUTF(broadcast.value.path, dataOut)
// Recording the id here also keeps a repeated broadcast from being sent twice.
oldBids.add(broadcast.id)
}
}
dataOut.flush()
// Serialized command:
// A leading Int flag distinguishes the two layouts: 1 when isUDF, 0 otherwise.
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(funcs.length)
// Per function chain: its argument offsets, then each chained command (length + bytes).
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
offsets.foreach { offset =>
dataOut.writeInt(offset)
}
dataOut.writeInt(chained.funcs.length)
chained.funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
}
}
} else {
dataOut.writeInt(0)
// Only the first function of the first chain is sent in this layout.
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
}
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
} catch {
// Failure after the task completed or was interrupted: only logged, not recorded.
case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
if (!worker.isClosed) {
Utils.tryLog(worker.shutdownOutput())
}
case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
if (!worker.isClosed) {
Utils.tryLog(worker.shutdownOutput())
}
}
}
}
/**
 * Daemon thread that polls the task's state and destroys the Python worker if the
 * task is interrupted before it completes.
 */
class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
  extends Thread(s"Worker Monitor for $pythonExec") {

  setDaemon(true)

  override def run(): Unit = {
    val pollIntervalMs = 2000L

    // Kill the worker if it is interrupted, checking until task completion.
    // TODO: This has a race condition if interruption occurs, as completed may still become true.
    while (!(context.isInterrupted || context.isCompleted)) {
      Thread.sleep(pollIntervalMs)
    }

    // Leaving the loop while the task is still incomplete means it was interrupted.
    if (!context.isCompleted) {
      try {
        logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
        env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
      } catch {
        case e: Exception =>
          logError("Exception when trying to kill worker", e)
      }
    }
  }
}
}