Spark记录(四):Dataset.count()方法源码剖析
因最近工作中涉及较多的Spark相关功能,所以趁周末闲来无事,研读一下Dataset的count方法。Spark版本3.2.0
1、方法入口:
def count(): Long = withAction("count", groupBy().count().queryExecution) { plan => plan.executeCollect().head.getLong(0) }
可以看到,count方法调用的是withAction方法,入参有三个:字符串count、调用方法获取到的QueryExecution、一个函数。注:此处就是对Scala函数式编程的应用,将函数作为参数来传递。
2、第二个参数QueryExecution的获取流程
2.1、首先看groupBy()方法:
1 def groupBy(cols: Column*): RelationalGroupedDataset = { 2 RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) 3 }
groupBy方法是用于分组聚合的,一般用法是groupBy之后加上agg聚合函数,对分组之后的每组数据进行聚合,入参为Column类型的可变长度参数。
但上面count方法中调用时未传任何入参,产生的效果就是****
groupBy方法只有一行代码,生成并返回了一个RelationalGroupedDataset的对象,而且此处是用伴生对象的简略写法创建出来的,该行代码其实质是调用了RelationalGroupedDataset的伴生对象中的apply方法,三个入参。
注:RelationalGroupedDataset 类是用于处理聚合操作的,内部封装了对agg方法的处理,以及一些统计函数sum、max等的实现。
2.1.1、逐一看下RelationalGroupedDataset的三个入参:
首先是toDF()方法,方法体如下,可见就是重新创建了一个Dataset[Row]对象,即DataFrame
def toDF(): DataFrame = new Dataset[Row](queryExecution, RowEncoder(schema))
然后是cols.map(_.expr),即遍历执行每个Column的expr表达式,因为此处未传入cols,故可忽略。
最后传入的是 RelationalGroupedDataset.GroupByType,起了标识的作用。因为RelationalGroupedDataset类的方法除了groupBy调用之外,还有Cube、Rollup、Pivot等都会调用,为与其他几种区别开,故传入了GroupByType。
2.1.2、初探 RelationalGroupedDataset 类
apply方法:
1 def apply( 2 df: DataFrame, 3 groupingExprs: Seq[Expression], 4 groupType: GroupType): RelationalGroupedDataset = { 5 new RelationalGroupedDataset(df, groupingExprs, groupType: GroupType) 6 }
类的定义:
class RelationalGroupedDataset protected[sql]( private[sql] val df: DataFrame, private[sql] val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { ...... }
可见没有多余的逻辑,只是单纯的创建了一个对象。至于这个对象如何使用的,还需继续追溯它里面的count方法,即Dataset.count()中调用的groupBy().count()。
2.2、groupBy().count(),即 RelationalGroupedDataset.count():
def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()))
2.2.1、其中Alias(Count(Literal(1)).toAggregateExpression(), "count")的作用,就是生成 count(1) as count 这样的一个统计函数的表达式。
2.2.2、然后toDF方法,如下所示:
1 private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { 2 val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) { // 是否保留分组的主键列,默认true 3 groupingExprs match { // 若保留,则将分组的主键列拼到聚合表达式的前面 4 // call `toList` because `Stream` can't serialize in scala 2.13 5 case s: Stream[Expression] => s.toList ++ aggExprs 6 case other => other ++ aggExprs 7 } 8 } else { 9 aggExprs 10 } 11 12 val aliasedAgg = aggregates.map(alias) // 处理设置别名的表达式 13 14 groupType match { 15 case RelationalGroupedDataset.GroupByType => 16 Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) // *** 17 case RelationalGroupedDataset.RollupType => 18 Dataset.ofRows( 19 df.sparkSession, Aggregate(Seq(Rollup(groupingExprs.map(Seq(_)))), 20 aliasedAgg, df.logicalPlan)) 21 case RelationalGroupedDataset.CubeType => 22 Dataset.ofRows( 23 df.sparkSession, Aggregate(Seq(Cube(groupingExprs.map(Seq(_)))), 24 aliasedAgg, df.logicalPlan)) 25 case RelationalGroupedDataset.PivotType(pivotCol, values) => 26 val aliasedGrps = groupingExprs.map(alias) 27 Dataset.ofRows( 28 df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) 29 } 30 }
重点是第16行,进入ofRows方法中可以看到,其实就是又新建了一个Dataset[Row],并将加上count(1)表达式之后新生成的Aggregate执行计划传入。
至此,groupBy().count().queryExecution得到的就是一个count(1)的执行计划了。
3、第三个参数,也是一个函数式参数:
{ plan => plan.executeCollect().head.getLong(0) }
该参数入参是一个plan,返回值long类型,推测是获取最终count值的,暂时放一放,后面调用到的时候再来研究。
4、看完三个参数,下面进入withAction方法:
1 private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = { 2 SQLExecution.withNewExecutionId(qe, Some(name)) { 3 qe.executedPlan.resetMetrics() 4 action(qe.executedPlan) 5 } 6 }
又是使用了科里化传参,第三个参数同样是一个函数,在里面调用了action这个函数参数。继续追踪withNewExecutionId方法:
5、SQLExecution.withNewExecutionId
该方法代码较多,下面先看一下它的主体结构。里面省略的若干行代码,实际是作为一个函数参数传入了withActive方法。
def withNewExecutionId[T]( queryExecution: QueryExecution, name: Option[String] = None)(body: => T): T = queryExecution.sparkSession.withActive { ... // 省略若干代码 }
而withActive方法如下,实际是将当前的SparkSession存入了本地线程变量中,方便后面的获取。然后执行了函数block,而返回值就是外层withNewExecutionId方法中函数体的返回值。
private[sql] def withActive[T](block: => T): T = { val old = SparkSession.activeThreadSession.get() SparkSession.setActiveSession(this) try block finally { SparkSession.setActiveSession(old) } }
下面回到外层的函数体:
5.1、SQLExecution.withNewExecutionId函数体第一部分
1 val sparkSession = queryExecution.sparkSession 2 val sc = sparkSession.sparkContext 3 val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) 4 val executionId = SQLExecution.nextExecutionId 5 sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) 6 executionIdToQueryExecution.put(executionId, queryExecution)
先设置了一下executionId,该ID是一个线程安全的自增序列,每次加1,。设置给SparkContext之后,又将id与QueryExecution的映射关系存入Map中。
5.2、SQLExecution.withNewExecutionId函数体第二部分
第二部分主要是判断若sql长度过长,需要进行截断处理,无甚要点。
5.3、SQLExecution.withNewExecutionId函数体第三部分,代码如下:
1 withSQLConfPropagated(sparkSession) { 2 var ex: Option[Throwable] = None 3 val startTime = System.nanoTime() 4 try { 5 sc.listenerBus.post(SparkListenerSQLExecutionStart( 6 executionId = executionId, 7 description = desc, 8 details = callSite.longForm, 9 physicalPlanDescription = queryExecution.explainString(planDescriptionMode), 10 sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), 11 time = System.currentTimeMillis())) 12 body 13 } catch { 14 case e: Throwable => 15 ex = Some(e) 16 throw e 17 } finally { 18 val endTime = System.nanoTime() 19 val event = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis()) 20 event.executionName = name 21 event.duration = endTime - startTime 22 event.qe = queryExecution 23 event.executionFailure = ex 24 sc.listenerBus.post(event) 25 } 26 }
起头的 withSQLConfPropagated 方法,同样还是科里化的方式传参,方法里面将配置参数替换为新的配置参数,执行完之后再将老参数存回去。
再然后是try里面的一个post方法,finally里面一个post方法,用于发送SQLExecution执行开始和结束的通知消息。
最后是核心函数调用,body。即前面一直引而未看的方法。
下面再返回头来好好研究一下此处的body函数,函数体是:
{
qe.executedPlan.resetMetrics()
action(qe.executedPlan)
}
qe变量即上面2.2中返回的groupBy().count().queryExecution
而action的函数体是:
{ plan => plan.executeCollect().head.getLong(0) }
那么内部具体是怎么实现的呢?今天时间不早了,改日再搞它。