pyspark对应的scala代码PythonRDD对象
pyspark jvm端的scala代码PythonRDD
代码版本为 spark 2.2.0
1.PythonRDD.object
这个静态类是pyspark的一些基础入口
// 这里不会把这个类全部内容都介绍,因为大部分都是静态接口,被pyspark的同名代码调用
// 这里介绍几个主要函数
// 在pyspark的RDD中作为所有action的基础的collect方法调用的collectAndServer方法也在这个对象中被定义
private[spark] object PythonRDD extends Logging {
//被pyspark.SparkContext.runJob调用
//提供rdd.collect的功能,提交job
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int]): Int = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions.asScala)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
serveIterator(flattenedPartition.iterator,
s"serve RDD ${rdd.id} with partitions ${partitions.asScala.mkString(",")}")
}
// 整个pyspark.RDD 的action都是在这个函数中被触发的
// pyspark.RDD 的collect通过调用这个方法被触发rdd的执行和任务提交
def collectAndServe[T](rdd: RDD[T]): Int = {
//参数rdd 即pyspark中RDD里的_jrdd, 对应的是scala里数据源rdd或pythonRDD
// 这里rdd.collect() 触发了任务开始运行
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
//这个函数的作用是将计算结果写入到本地socket当中,再在pyspark里读取本地socket获取结果
def serveIterator[T](items: Iterator[T], threadName: String): Int = {
// 可以看见socket在本地随机端口和localhost上建立出来的
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 3 seconds
serverSocket.setSoTimeout(3000)
// 这里启动一个线程负责将结果写入到socket中
new Thread(threadName) {
setDaemon(true)
override def run() {
try {
val sock = serverSocket.accept()
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
Utils.tryWithSafeFinally {
//具体负责写的是此函数,此函数主要做一些类型和序列化工作
writeIteratorToStream(items, out)
} {
out.close()
}
} catch {
case NonFatal(e) =>
logError(s"Error while sending iterator", e)
} finally {
serverSocket.close()
}
}
}.start()
// 最后返回此socket的网络端口, 这样pyspark里就可以通过此端口读取数据
serverSocket.getLocalPort
}
// 此函数负责写入数据结果
// 做一些类型检查和对应的序列化工作
// PythonRunner中WriterThread写入数据时使用的也是此函数
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
def write(obj: Any): Unit = obj match {
case null =>
dataOut.writeInt(SpecialLengths.NULL)
case arr: Array[Byte] =>
dataOut.writeInt(arr.length)
dataOut.write(arr)
case str: String =>
writeUTF(str, dataOut)
case stream: PortableDataStream =>
write(stream.toArray())
case (key, value) =>
write(key)
write(value)
case other =>
throw new SparkException("Unexpected element type " + other.getClass)
}
iter.foreach(write)
}
}