Spark记录(四):Dataset.count()方法源码剖析

因最近工作中涉及较多的Spark相关功能,所以趁周末闲来无事,研读一下Dataset的count方法。Spark版本3.2.0

1、方法入口:

  def count(): Long = withAction("count", groupBy().count().queryExecution) { plan =>
    plan.executeCollect().head.getLong(0)
  }

可以看到,count方法调用的是withAction方法,入参有三个:字符串count、调用方法获取到的QueryExecution、一个函数。注:此处就是对Scala函数式编程的应用,将函数作为参数来传递

2、第二个参数QueryExecution的获取流程

 2.1、首先看groupBy()方法:

1   def groupBy(cols: Column*): RelationalGroupedDataset = {
2     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType)
3   }

groupBy方法是用于分组聚合的,一般用法是groupBy之后加上agg聚合函数,对分组之后的每组数据进行聚合,入参为Column类型的可变长度参数。

但上面count方法中调用时未传任何入参,产生的效果就是****

groupBy方法只有一行代码,生成并返回了一个RelationalGroupedDataset的对象,而且此处是用伴生对象的简略写法创建出来的,该行代码其实质是调用了RelationalGroupedDataset的伴生对象中的apply方法,三个入参。

注:RelationalGroupedDataset 类是用于处理聚合操作的,内部封装了对agg方法的处理,以及一些统计函数sum、max等的实现。

2.1.1、逐一看下RelationalGroupedDataset的三个入参:

首先是toDF()方法,方法体如下,可见就是重新创建了一个Dataset[Row]对象,即DataFrame

  def toDF(): DataFrame = new Dataset[Row](queryExecution, RowEncoder(schema))

然后是cols.map(_.expr),即遍历执行每个Column的expr表达式,因为此处未传入cols,故可忽略。

最后传入的是 RelationalGroupedDataset.GroupByType,起了标识的作用。因为RelationalGroupedDataset类的方法除了groupBy调用之外,还有Cube、Rollup、Pivot等都会调用,为与其他几种区别开,故传入了GroupByType。

2.1.2、初探 RelationalGroupedDataset 类

apply方法:

1   def apply(
2       df: DataFrame,
3       groupingExprs: Seq[Expression],
4       groupType: GroupType): RelationalGroupedDataset = {
5     new RelationalGroupedDataset(df, groupingExprs, groupType: GroupType)
6   }

类的定义:

class RelationalGroupedDataset protected[sql](
    private[sql] val df: DataFrame,
    private[sql] val groupingExprs: Seq[Expression],
    groupType: RelationalGroupedDataset.GroupType) {
......
}

可见没有多余的逻辑,只是单纯的创建了一个对象。至于这个对象如何使用的,还需继续追溯它里面的count方法,即Dataset.count()中调用的groupBy().count()。

2.2、groupBy().count(),即 RelationalGroupedDataset.count():

  def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()))

2.2.1、其中Alias(Count(Literal(1)).toAggregateExpression(), "count")的作用,就是生成 count(1) as count 这样的一个统计函数的表达式。

2.2.2、然后toDF方法,如下所示:

 1 private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
 2     val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) { // 是否保留分组的主键列,默认true
 3       groupingExprs match { // 若保留,则将分组的主键列拼到聚合表达式的前面
 4         // call `toList` because `Stream` can't serialize in scala 2.13
 5         case s: Stream[Expression] => s.toList ++ aggExprs
 6         case other => other ++ aggExprs
 7       }
 8     } else {
 9       aggExprs
10     }
11 
12     val aliasedAgg = aggregates.map(alias) // 处理设置别名的表达式
13 
14     groupType match {
15       case RelationalGroupedDataset.GroupByType =>
16         Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) // ***
17       case RelationalGroupedDataset.RollupType =>
18         Dataset.ofRows(
19           df.sparkSession, Aggregate(Seq(Rollup(groupingExprs.map(Seq(_)))),
20             aliasedAgg, df.logicalPlan))
21       case RelationalGroupedDataset.CubeType =>
22         Dataset.ofRows(
23           df.sparkSession, Aggregate(Seq(Cube(groupingExprs.map(Seq(_)))),
24             aliasedAgg, df.logicalPlan))
25       case RelationalGroupedDataset.PivotType(pivotCol, values) =>
26         val aliasedGrps = groupingExprs.map(alias)
27         Dataset.ofRows(
28           df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan))
29     }
30   }

重点是第16行,进入ofRows方法中可以看到,其实就是又新建了一个Dataset[Row],并将加上count(1)表达式之后新生成的Aggregate执行计划传入。

至此,groupBy().count().queryExecution得到的就是一个count(1)的执行计划了。

3、第三个参数,也是一个函数式参数:

{ plan =>
    plan.executeCollect().head.getLong(0)
  }

该参数入参是一个plan,返回值long类型,推测是获取最终count值的,暂时放一放,后面调用到的时候再来研究。

4、看完三个参数,下面进入withAction方法:

1 private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
2     SQLExecution.withNewExecutionId(qe, Some(name)) {
3       qe.executedPlan.resetMetrics()
4       action(qe.executedPlan)
5     }
6   }

又是使用了科里化传参,第三个参数同样是一个函数,在里面调用了action这个函数参数。继续追踪withNewExecutionId方法:

 5、SQLExecution.withNewExecutionId

该方法代码较多,下面先看一下它的主体结构。里面省略的若干行代码,实际是作为一个函数参数传入了withActive方法。

def withNewExecutionId[T](
      queryExecution: QueryExecution,
      name: Option[String] = None)(body: => T): T = queryExecution.sparkSession.withActive {
... // 省略若干代码
}

而withActive方法如下,实际是将当前的SparkSession存入了本地线程变量中,方便后面的获取。然后执行了函数block,而返回值就是外层withNewExecutionId方法中函数体的返回值。

private[sql] def withActive[T](block: => T): T = {
    val old = SparkSession.activeThreadSession.get()
    SparkSession.setActiveSession(this)
    try block finally {
      SparkSession.setActiveSession(old)
    }
  }

 下面回到外层的函数体:

5.1、SQLExecution.withNewExecutionId函数体第一部分

1     val sparkSession = queryExecution.sparkSession
2     val sc = sparkSession.sparkContext
3     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
4     val executionId = SQLExecution.nextExecutionId
5     sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
6     executionIdToQueryExecution.put(executionId, queryExecution)

先设置了一下executionId,该ID是一个线程安全的自增序列,每次加1,。设置给SparkContext之后,又将id与QueryExecution的映射关系存入Map中。

5.2、SQLExecution.withNewExecutionId函数体第二部分

第二部分主要是判断若sql长度过长,需要进行截断处理,无甚要点。

5.3、SQLExecution.withNewExecutionId函数体第三部分,代码如下:

 1       withSQLConfPropagated(sparkSession) {
 2         var ex: Option[Throwable] = None
 3         val startTime = System.nanoTime()
 4         try {
 5           sc.listenerBus.post(SparkListenerSQLExecutionStart(
 6             executionId = executionId,
 7             description = desc,
 8             details = callSite.longForm,
 9             physicalPlanDescription = queryExecution.explainString(planDescriptionMode),
10             sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
11             time = System.currentTimeMillis()))
12           body
13         } catch {
14           case e: Throwable =>
15             ex = Some(e)
16             throw e
17         } finally {
18           val endTime = System.nanoTime()
19           val event = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis())
20           event.executionName = name
21           event.duration = endTime - startTime
22           event.qe = queryExecution
23           event.executionFailure = ex
24           sc.listenerBus.post(event)
25         }
26       }

起头的 withSQLConfPropagated 方法,同样还是科里化的方式传参,方法里面将配置参数替换为新的配置参数,执行完之后再将老参数存回去。

再然后是try里面的一个post方法,finally里面一个post方法,用于发送SQLExecution执行开始和结束的通知消息。

最后是核心函数调用,body。即前面一直引而未看的方法。

下面再返回头来好好研究一下此处的body函数,函数体是:

{
      qe.executedPlan.resetMetrics()
      action(qe.executedPlan)
    }

qe变量即上面2.2中返回的groupBy().count().queryExecution

而action的函数体是:

{ plan =>
    plan.executeCollect().head.getLong(0)
  }

那么内部具体是怎么实现的呢?今天时间不早了,改日再搞它。

posted on 2022-05-30 00:06  淡墨痕  阅读(759)  评论(0编辑  收藏  举报