pyspark对应的scala代码PythonRDD对象

pyspark jvm端的scala代码PythonRDD

代码版本为 spark 2.2.0

1.PythonRDD.object

这个静态类是pyspark的一些基础入口

// 这里不会把这个类全部内容都介绍,因为大部分都是静态接口,被pyspark的同名代码调用
// 这里介绍几个主要函数
// 在pyspark的RDD中作为所有action的基础的collect方法调用的collectAndServer方法也在这个对象中被定义
private[spark] object PythonRDD extends Logging {
  //被pyspark.SparkContext.runJob调用
  //提供rdd.collect的功能,提交job
  def runJob(
      sc: SparkContext,
      rdd: JavaRDD[Array[Byte]],
      partitions: JArrayList[Int]): Int = {
    type ByteArray = Array[Byte]
    type UnrolledPartition = Array[ByteArray]
    val allPartitions: Array[UnrolledPartition] =
      sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions.asScala)
    val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
    serveIterator(flattenedPartition.iterator,
      s"serve RDD ${rdd.id} with partitions ${partitions.asScala.mkString(",")}")
  }

  // 整个pyspark.RDD 的action都是在这个函数中被触发的
  // pyspark.RDD 的collect通过调用这个方法被触发rdd的执行和任务提交
  def collectAndServe[T](rdd: RDD[T]): Int = {
    //参数rdd 即pyspark中RDD里的_jrdd, 对应的是scala里数据源rdd或pythonRDD
    // 这里rdd.collect() 触发了任务开始运行
    serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
  }
  
  //这个函数的作用是将计算结果写入到本地socket当中,再在pyspark里读取本地socket获取结果
  def serveIterator[T](items: Iterator[T], threadName: String): Int = {
    // 可以看见socket在本地随机端口和localhost上建立出来的
    val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
    // Close the socket if no connection in 3 seconds
    serverSocket.setSoTimeout(3000)

    // 这里启动一个线程负责将结果写入到socket中
    new Thread(threadName) {
      setDaemon(true)
      override def run() {
        try {
          val sock = serverSocket.accept()
          val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
          Utils.tryWithSafeFinally {
            //具体负责写的是此函数,此函数主要做一些类型和序列化工作
            writeIteratorToStream(items, out)
          } {
            out.close()
          }
        } catch {
          case NonFatal(e) =>
            logError(s"Error while sending iterator", e)
        } finally {
          serverSocket.close()
        }
      }
    }.start()

    // 最后返回此socket的网络端口, 这样pyspark里就可以通过此端口读取数据
    serverSocket.getLocalPort
  }

  // 此函数负责写入数据结果
  // 做一些类型检查和对应的序列化工作
  // PythonRunner中WriterThread写入数据时使用的也是此函数
  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {

    def write(obj: Any): Unit = obj match {
      case null =>
        dataOut.writeInt(SpecialLengths.NULL)
      case arr: Array[Byte] =>
        dataOut.writeInt(arr.length)
        dataOut.write(arr)
      case str: String =>
        writeUTF(str, dataOut)
      case stream: PortableDataStream =>
        write(stream.toArray())
      case (key, value) =>
        write(key)
        write(value)
      case other =>
        throw new SparkException("Unexpected element type " + other.getClass)
    }

    iter.foreach(write)
  }


}

posted @ 2018-05-16 20:44  vv.past  阅读(703)  评论(0编辑  收藏  举报