一文读懂 超简单的spark structured stream 源码解读
为了让大家理解structured stream的运行流程,我将根据一个代码例子,讲述structured stream的基本运行流程和原理。
下面是一段简单的代码:
1 val spark = SparkSession 2 .builder 3 .appName("StructuredNetworkWordCount") 4 .master("local[4]") 5 6 .getOrCreate() 7 spark.conf.set("spark.sql.shuffle.partitions", 4) 8 9 import spark.implicits._ 10 val words = spark.readStream 11 .format("socket") 12 .option("host", "localhost") 13 .option("port", 9999) 14 .load() 15 16 val df1 = words.as[String] 17 .flatMap(_.split(" ")) 18 .toDF("word") 19 .groupBy("word") 20 .count() 21 22 df1.writeStream 23 .outputMode("complete") 24 .format("console") 25 .trigger(ProcessingTime(10)) 26 .start() 27 28 spark.streams.awaitAnyTermination()
这段代码就是单词计数。先从一个socket数据源读入数据,然后以" " 为分隔符把一行文本转换成单词的DataSet,然后转换成有标签("word")的DataFrame,接着按word列进行分组,聚合计算每个word的个数。最后输出到控制台,以10秒为批处理执行周期。
现在来分析它的原理。spark的逻辑里面有一个惰性计算的概念,以上面的例子来说,在第22行代码以前,程序都不会对数据进行真正的计算,而是将计算的公式(或者函数)保存在DataFrame里面,在22行开始的writeStream.start调用后才开始真正的计算。为什么?
因为:
这可以让spark内核做一些优化。
例如:
数据库中存放着人的名字和年龄,我想要在控制台打印出前十个年龄大于20岁的人的名字,那么我的spark代码会这么写:
1 df.fileter{row=> 2 row._2>20} 3 .show(10)
假如说我每执行一行代码就进行一次计算,那么在第二行的时候,我就会把df里面所有的数据进行过滤,筛选出其中年龄大于20的,然后在第3行执行的时候,从第2行里面的结果中选前面10个进行打印。
看出问题了么?这里的输出仅仅只需要10个年龄大于20的人,但是我却把所有人都筛选了一遍,其实我只需要筛选出10个,后面的就不必要筛选了。这就是spark的惰性计算进行优化的地方。
在spark的计算中,在真正的输出函数之前,都不会进行真正的计算,而会在输出函数之前进行优化后再进行计算。我们来看源代码。
这里我贴的是structured stream每次批处理周期到达时会运行的代码:
1 private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { 2 // Request unprocessed data from all sources. 3 newData = reportTimeTaken("getBatch") { 4 availableOffsets.flatMap { 5 case (source, available) 6 if committedOffsets.get(source).map(_ != available).getOrElse(true) => 7 val current = committedOffsets.get(source) 8 val batch = source.getBatch(current, available) 9 logDebug(s"Retrieving data from $source: $current -> $available") 10 Some(source -> batch) 11 case _ => None 12 } 13 } 14 15 // A list of attributes that will need to be updated. 16 var replacements = new ArrayBuffer[(Attribute, Attribute)] 17 // Replace sources in the logical plan with data that has arrived since the last batch. 18 val withNewSources = logicalPlan transform { 19 case StreamingExecutionRelation(source, output) => 20 newData.get(source).map { data => 21 val newPlan = data.logicalPlan 22 assert(output.size == newPlan.output.size, 23 s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + 24 s"${Utils.truncatedString(newPlan.output, ",")}") 25 replacements ++= output.zip(newPlan.output) 26 newPlan 27 }.getOrElse { 28 LocalRelation(output) 29 } 30 } 31 32 // Rewire the plan to use the new attributes that were returned by the source. 33 val replacementMap = AttributeMap(replacements) 34 val triggerLogicalPlan = withNewSources transformAllExpressions { 35 case a: Attribute if replacementMap.contains(a) => replacementMap(a) 36 case ct: CurrentTimestamp => 37 CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, 38 ct.dataType) 39 case cd: CurrentDate => 40 CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, 41 cd.dataType, cd.timeZoneId) 42 } 43 44 reportTimeTaken("queryPlanning") { 45 lastExecution = new IncrementalExecution( 46 sparkSessionToRunBatch, 47 triggerLogicalPlan, 48 outputMode, 49 checkpointFile("state"), 50 currentBatchId, 51 offsetSeqMetadata) 52 lastExecution.executedPlan // Force the lazy generation of execution plan 53 } 54 55 val nextBatch = 56 new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) 57 58 reportTimeTaken("addBatch") { 59 sink.addBatch(currentBatchId, nextBatch) 60 } 61 62 awaitBatchLock.lock() 63 try { 64 // Wake up any threads that are waiting for the stream to progress. 65 awaitBatchLockCondition.signalAll() 66 } finally { 67 awaitBatchLock.unlock() 68 } 69 }
其实很简单,在第58以前都是在解析用户代码,生成logicPlan,优化logicPlan,生成批处理类。第47行的triggerLogicalPlan就是最终优化后的用户逻辑,它被封装在了一个IncrementalExecution类中,这个类连同sparkSessionToRunBatch(运行环境)和RowEncoder(序列化类)一起构成一个新的DataSet,这个DataSet就是最终要发送到worker节点进行执行的代码。第59行代码就是在将它加入到准备发送代码的队列中。我们继续看一段代码,由于我们使用console作为数据下游(sink)所以看看console的addBatch代码:
1 override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { 2 val batchIdStr = if (batchId <= lastBatchId) { 3 s"Rerun batch: $batchId" 4 } else { 5 lastBatchId = batchId 6 s"Batch: $batchId" 7 } 8 9 // scalastyle:off println 10 println("-------------------------------------------") 11 println(batchIdStr) 12 println("-------------------------------------------") 13 // scalastyle:off println 14 data.sparkSession.createDataFrame( 15 data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) 16 .show(numRowsToShow, isTruncated) 17 }
关键代码在16行.show函数,show函数是一个真正的action,在这之前都是一些算子的封装,我们看show的代码:
1 private[sql] def showString(_numRows: Int, truncate: Int = 20): String = { 2 val numRows = _numRows.max(0) 3 val takeResult = toDF().take(numRows + 1) 4 val hasMoreData = takeResult.length > numRows 5 val data = takeResult.take(numRows)
第3行进入take:
def take(n: Int): Array[T] = head(n)
def head(n: Int): Array[T] = withAction("head", limit(n).queryExecution)(collectFromPlan)
1 private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = { 2 try { 3 qe.executedPlan.foreach { plan => 4 plan.resetMetrics() 5 } 6 val start = System.nanoTime() 7 val result = SQLExecution.withNewExecutionId(sparkSession, qe) { 8 action(qe.executedPlan) 9 } 10 val end = System.nanoTime() 11 sparkSession.listenerManager.onSuccess(name, qe, end - start) 12 result 13 } catch { 14 case e: Exception => 15 sparkSession.listenerManager.onFailure(name, qe, e) 16 throw e 17 } 18 }
这个函数名就告诉我们,这是真正计算要开始了,第7行代码一看就是准备发送代码序列了:
1 def withNewExecutionId[T]( 2 sparkSession: SparkSession, 3 queryExecution: QueryExecution)(body: => T): T = { 4 val sc = sparkSession.sparkContext 5 val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) 6 if (oldExecutionId == null) { 7 val executionId = SQLExecution.nextExecutionId 8 sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) 9 executionIdToQueryExecution.put(executionId, queryExecution) 10 val r = try { 11 // sparkContext.getCallSite() would first try to pick up any call site that was previously 12 // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on 13 // streaming queries would give us call site like "run at <unknown>:0" 14 val callSite = sparkSession.sparkContext.getCallSite() 15 16 sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( 17 executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, 18 SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) 19 try { 20 body 21 } finally { 22 sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( 23 executionId, System.currentTimeMillis())) 24 } 25 } finally { 26 executionIdToQueryExecution.remove(executionId) 27 sc.setLocalProperty(EXECUTION_ID_KEY, null) 28 } 29 r 30 } else { 31 // Don't support nested `withNewExecutionId`. This is an example of the nested 32 // `withNewExecutionId`: 33 // 34 // class DataFrame { 35 // def foo: T = withNewExecutionId { something.createNewDataFrame().collect() } 36 // } 37 // 38 // Note: `collect` will call withNewExecutionId 39 // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan" 40 // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution 41 // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run, 42 // all accumulator metrics will be 0. It will confuse people if we show them in Web UI. 43 // 44 // A real case is the `DataFrame.count` method. 45 throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set") 46 } 47 }
你看第16行,就是在发送数据,包括用户优化后的逻辑,批处理的id,时间戳等等。worker接收到这个事件后根据logicalPlan里面的逻辑就开始干活了。这就是一个很基本很简单的流程,对于spark入门还是挺有帮助的吧。
posted on 2018-03-02 18:11 skyer1992 阅读(1834) 评论(0) 编辑 收藏 举报