Spark SQL (8) - Aggregation in Spark SQL


In earlier posts I briefly walked through the whole flow from SQL text to physical plan in Spark; this post summarizes how Spark SQL handles aggregation.

Generating the physical plan for an aggregation

Let's start with a SQL statement:

SELECT NAME, COUNT(*) FROM PEOPLE GROUP BY NAME

The tree structure produced by parsing this statement with ANTLR4 is shown in the original post as a figure (not reproduced here).

In the parsed tree you can see that the querySpecification node now has an aggregation child. Here we only follow the aggregation-related handling. When the ANTLR tree is converted into an (unresolved) logical plan, aggregation is handled in AstBuilder.withQuerySpecification:

  private def withQuerySpecification(
      ctx: QuerySpecificationContext,
      relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
    import ctx._
// ... some unrelated handling elided here ...
        // Add where.
        val withFilter = withLateralView.optionalMap(where)(filter)
 
        // Add aggregation or a project.
        val namedExpressions = expressions.map {
          case e: NamedExpression => e
          case e: Expression => UnresolvedAlias(e)
        }
        val withProject = if (aggregation != null) {
          withAggregation(aggregation, namedExpressions, withFilter)
        } else if (namedExpressions.nonEmpty) {
          Project(namedExpressions, withFilter)
        } else {
          withFilter
        }
 
        // Having
        val withHaving = withProject.optional(having) {
          // Note that we add a cast to non-predicate expressions. If the expression itself is
          // already boolean, the optimizer will get rid of the unnecessary cast.
          val predicate = expression(having) match {
            case p: Predicate => p
            case e => Cast(e, BooleanType)
          }
          Filter(predicate, withProject)
        }
 
        // Distinct
        val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) {
          Distinct(withHaving)
        } else {
          withHaving
        }
 
        // Window handling elided ...

        // Hint handling elided ...
    }
  }

The withAggregation method looks like this:

private def withAggregation(
     ctx: AggregationContext,
     selectExpressions: Seq[NamedExpression],
     query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
   val groupByExpressions = expressionList(ctx.groupingExpressions)
 
   if (ctx.GROUPING != null) {
     // GROUP BY .... GROUPING SETS (...)
     val selectedGroupByExprs =
       ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)))
     GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions)
   } else {
     // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
     val mappedGroupByExpressions = if (ctx.CUBE != null) {
       Seq(Cube(groupByExpressions))
     } else if (ctx.ROLLUP != null) {
       Seq(Rollup(groupByExpressions))
     } else {
       groupByExpressions
     }
     Aggregate(mappedGroupByExpressions, selectExpressions, query)
   }
 }

As you can see, an Aggregate node ends up being added to the logical plan. Skipping over the optimizer, the next stage is physical planning; the planning strategy that matters for aggregation is:

object Aggregation extends Strategy {
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalAggregation(
         groupingExpressions, aggregateExpressions, resultExpressions, child) =>
 
       val (functionsWithDistinct, functionsWithoutDistinct) =
         aggregateExpressions.partition(_.isDistinct)
       if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) {
         // This is a sanity check. We should not reach here when we have multiple distinct
         // column sets. Our `RewriteDistinctAggregates` should take care this case.
         sys.error("You hit a query analyzer bug. Please report your query to " +
             "Spark user mailing list.")
       }
 
       val aggregateOperator =
         if (functionsWithDistinct.isEmpty) {
           aggregate.AggUtils.planAggregateWithoutDistinct(
             groupingExpressions,
             aggregateExpressions,
             resultExpressions,
             planLater(child))
         } else {
           aggregate.AggUtils.planAggregateWithOneDistinct(
             groupingExpressions,
             functionsWithDistinct,
             functionsWithoutDistinct,
             resultExpressions,
             planLater(child))
         }
 
       aggregateOperator
 
     case _ => Nil
   }
 }

From the logic above you can see that, depending on whether any aggregate function contains DISTINCT, either planAggregateWithoutDistinct or planAggregateWithOneDistinct is called to build the physical plan. After that, once the preparation phase has run, physical-plan generation for the aggregation is complete.
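
As a quick illustration (a hedged sketch: it assumes an existing SparkSession named spark and a registered people table), these two queries take the two different planner paths:

// No aggregate contains DISTINCT: handled by planAggregateWithoutDistinct.
spark.sql("SELECT name, COUNT(*) FROM people GROUP BY name")

// Exactly one distinct column set: handled by planAggregateWithOneDistinct.
spark.sql("SELECT name, COUNT(DISTINCT age) FROM people GROUP BY name")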

Next we will look at how planAggregateWithoutDistinct and planAggregateWithOneDistinct differ, and at how Spark SQL actually executes aggregations. Before getting into those two methods, let's first introduce the aggregation modes (AggregateMode).

Aggregation modes (AggregateMode) and aggregate functions

Partial: partial (map-side) aggregation; input rows are folded into the aggregation buffer, and the buffer contents are what gets returned.

Final: merges aggregation buffers and returns the final result.

Complete: no partial aggregation is possible; the final result is computed directly from the input.

PartialMerge: merges aggregation buffer contents; it is mainly used for DISTINCT queries, and what it returns is still buffer data.
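
To make Partial and Final concrete, here is a minimal plain-Scala sketch (not Spark internals) of how a grouped COUNT(*) splits into a per-partition partial phase and a final merge phase:

// Two input "partitions" of grouping keys.
val partitions = Seq(Seq("a", "a", "b"), Seq("b", "b"))

// Partial: each partition folds its rows into its own aggregation buffers.
val partialBuffers: Seq[Map[String, Long]] =
  partitions.map(_.groupBy(identity).map { case (k, rows) => k -> rows.size.toLong })

// Final: buffers for the same key are merged into the final result.
val finalResult: Map[String, Long] =
  partialBuffers.flatten.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }
// finalResult == Map("a" -> 2L, "b" -> 3L)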

Next, a quick note on how aggregate functions are classified:

1. DeclarativeAggregate: declarative aggregate functions.

2. ImperativeAggregate: imperative aggregate functions.

3. TypedImperativeAggregate: a subclass of ImperativeAggregate whose aggregation buffer can hold an arbitrary Java object in memory.

The main difference between the declarative and imperative flavors shows up in the update and merge operations: a DeclarativeAggregate expresses them by overriding expression-valued members (updateExpressions, mergeExpressions), while an ImperativeAggregate overrides the update and merge methods themselves.
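
To give a feel for the declarative style, below is a heavily trimmed sketch of a COUNT written against Catalyst's internal DeclarativeAggregate API. It is modeled on the built-in Count but simplified, and the exact set of abstract members differs between Spark versions, so treat it as illustrative rather than as the real implementation. An ImperativeAggregate, by contrast, overrides initialize, update, merge and eval as ordinary methods that mutate the buffer row directly.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.types.{DataType, LongType}

// A simplified declarative COUNT over one child expression (sketch only).
case class SimpleCount(child: Expression) extends DeclarativeAggregate {
  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = false
  override def dataType: DataType = LongType

  // The aggregation buffer is a single long-typed slot.
  private lazy val count = AttributeReference("count", LongType, nullable = false)()
  override lazy val aggBufferAttributes: Seq[AttributeReference] = count :: Nil
  override lazy val initialValues: Seq[Expression] = Seq(Literal(0L))

  // "update" is described as an expression over (buffer, input row).
  override lazy val updateExpressions: Seq[Expression] =
    Seq(If(IsNull(child), count, Add(count, Literal(1L))))

  // "merge" combines two buffers, again purely as an expression.
  override lazy val mergeExpressions: Seq[Expression] =
    Seq(Add(count.left, count.right))

  override lazy val evaluateExpression: Expression = count
}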

Now let's go through the differences between planAggregateWithoutDistinct and planAggregateWithOneDistinct.

First, planAggregateWithoutDistinct:

def planAggregateWithoutDistinct(
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan): Seq[SparkPlan] = {
   // Check if we can use HashAggregate.
 
   // 1. Create an Aggregate Operator for partial aggregations.
 
   val groupingAttributes = groupingExpressions.map(_.toAttribute)
   val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
   val partialAggregateAttributes =
     partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
   val partialResultExpressions =
     groupingAttributes ++
       partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
 
   val partialAggregate = createAggregate(
       requiredChildDistributionExpressions = None,
       groupingExpressions = groupingExpressions,
       aggregateExpressions = partialAggregateExpressions,
       aggregateAttributes = partialAggregateAttributes,
       initialInputBufferOffset = 0,
       resultExpressions = partialResultExpressions,
       child = child)
 
   // 2. Create an Aggregate Operator for final aggregations.
   val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
   // The attributes of the final aggregation buffer, which is presented as input to the result
   // projection:
   val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
 
   val finalAggregate = createAggregate(
       requiredChildDistributionExpressions = Some(groupingAttributes),
       groupingExpressions = groupingAttributes,
       aggregateExpressions = finalAggregateExpressions,
       aggregateAttributes = finalAggregateAttributes,
       initialInputBufferOffset = groupingExpressions.length,
       resultExpressions = resultExpressions,
       child = partialAggregate)
 
   finalAggregate :: Nil
 }

The method above really boils down to two steps: first create an aggregate operator for the partial (map-side) phase, then create a final aggregate on top of it.
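
For the example query at the top of this post, the resulting plan is shaped roughly as in the comment below (a schematic only, not literal explain output; the Exchange is added later, during the preparation phase, by the EnsureRequirements rule). As before, spark is assumed to be an existing SparkSession with a PEOPLE table registered:

// Schematic shape of the physical plan (not literal explain output):
//
//   HashAggregate(keys=[NAME], functions=[count(1)], mode=Final)
//   +- Exchange hashpartitioning(NAME)
//      +- HashAggregate(keys=[NAME], functions=[partial_count(1)], mode=Partial)
//         +- Scan PEOPLE
spark.sql("SELECT NAME, COUNT(*) FROM PEOPLE GROUP BY NAME").explain()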

Next, planAggregateWithOneDistinct:

This is quite similar to planAggregateWithoutDistinct, except that the plan is built in four steps (a rough sketch of the idea follows below):

1. Create a partial aggregate for the map-side phase;

2. create a partialMerge aggregate;

3. create another partial aggregate, this one used for the distinct function;

4. create a final aggregate.
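
As a rough plain-Scala sketch (not Spark internals) of the idea for a query like SELECT name, COUNT(DISTINCT age) FROM people GROUP BY name: the distinct column is first folded into the grouping key so that it gets de-duplicated, and only then is the distinct function itself evaluated:

val rows = Seq(("a", 1), ("a", 1), ("a", 2), ("b", 3))   // (name, age) pairs

// Steps 1-2 (partial + partialMerge): aggregate with (name, age) as the grouping
// key, which de-duplicates the values of the distinct column.
val dedup = rows.distinct                                 // Seq((a,1), (a,2), (b,3))

// Steps 3-4 (partial + final for the distinct function): group by name again and
// count what survived the de-duplication.
val result = dedup.groupBy(_._1).map { case (name, group) => name -> group.size }
// result == Map("a" -> 2, "b" -> 1)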

 

Both methods rely on createAggregate, which is where the concrete physical implementation of the aggregation is chosen:

private def createAggregate(
    requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
    groupingExpressions: Seq[NamedExpression] = Nil,
    aggregateExpressions: Seq[AggregateExpression] = Nil,
    aggregateAttributes: Seq[Attribute] = Nil,
    initialInputBufferOffset: Int = 0,
    resultExpressions: Seq[NamedExpression] = Nil,
    child: SparkPlan): SparkPlan = {
  val useHash = HashAggregateExec.supportsAggregate(
    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
  if (useHash) {
    HashAggregateExec(
      requiredChildDistributionExpressions = requiredChildDistributionExpressions,
      groupingExpressions = groupingExpressions,
      aggregateExpressions = aggregateExpressions,
      aggregateAttributes = aggregateAttributes,
      initialInputBufferOffset = initialInputBufferOffset,
      resultExpressions = resultExpressions,
      child = child)
  } else {
    val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
    val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
 
    if (objectHashEnabled && useObjectHash) {
      ObjectHashAggregateExec(
        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
        groupingExpressions = groupingExpressions,
        aggregateExpressions = aggregateExpressions,
        aggregateAttributes = aggregateAttributes,
        initialInputBufferOffset = initialInputBufferOffset,
        resultExpressions = resultExpressions,
        child = child)
    } else {
      SortAggregateExec(
        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
        groupingExpressions = groupingExpressions,
        aggregateExpressions = aggregateExpressions,
        aggregateAttributes = aggregateAttributes,
        initialInputBufferOffset = initialInputBufferOffset,
        resultExpressions = resultExpressions,
        child = child)
    }
  }
}

The logic above shows that hash aggregation is chosen whenever it is possible; concretely, HashAggregateExec can be used when every field of the aggregation buffer schema is one of the mutable (fixed-width) types below:

static {
    mutableFieldTypes = Collections.unmodifiableSet(
      new HashSet<>(
        Arrays.asList(new DataType[] {
          NullType,
          BooleanType,
          ByteType,
          ShortType,
          IntegerType,
          LongType,
          FloatType,
          DoubleType,
          DateType,
          TimestampType
        })));
  }

Otherwise, if the object-hash switch is enabled and the aggregate functions are TypedImperativeAggregate implementations, ObjectHashAggregateExec is chosen; and if neither of the first two applies, Spark falls back to SortAggregateExec. The three aggregation implementations are described below.
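
For reference, the switch read via useObjectHashAggregation corresponds to the SQLConf flag spark.sql.execution.useObjectHashAggregateExec, and functions such as collect_list are backed by a TypedImperativeAggregate. A small usage sketch (it assumes an existing SparkSession named spark and a registered people table):

// Toggle ObjectHashAggregateExec (it is enabled by default in recent versions).
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "true")

// collect_list aggregates into a Java object buffer, so with the flag on this
// query is typically planned with ObjectHashAggregateExec.
spark.sql("SELECT name, collect_list(age) FROM people GROUP BY name").explain()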

HashAggregateExec

The idea behind hash aggregation is to build a hash map keyed by the grouping columns and to keep each group's aggregation buffer in that map, which lives in memory. When memory runs out, the map is spilled, and the hash aggregation falls back to a sort-based aggregation.
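
Setting spilling aside, the core of the strategy is just a map from grouping key to aggregation buffer. A minimal plain-Scala sketch of that loop (the comments mark the roles that getAggregationBufferFromUnsafeRow and processRow play in the real implementation):

val rows = Seq(("a", 1L), ("b", 2L), ("a", 3L))                 // (key, value) pairs

val buffers = scala.collection.mutable.Map.empty[String, Long]  // key -> SUM buffer
for ((key, value) <- rows) {
  val buf = buffers.getOrElseUpdate(key, 0L)   // ~ getAggregationBufferFromUnsafeRow
  buffers(key) = buf + value                   // ~ processRow: update the buffer
}
// buffers == Map("a" -> 4L, "b" -> 2L)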

In doExecute, a TungstenAggregationIterator is instantiated; that class is where the actual aggregation happens:

1. hashMap = new UnsafeFixedWidthAggregationMap: this map stores the grouping keys together with their aggregation buffers. Inside UnsafeFixedWidthAggregationMap the key member is map = BytesToBytesMap, which is where the data actually lives.

2. The main logic is in the processInputs method:

private def processInputs(fallbackStartsAt: (Int, Int)): Unit = {
   if (groupingExpressions.isEmpty) {
     // If there is no grouping expressions, we can just reuse the same buffer over and over again.
     // Note that it would be better to eliminate the hash map entirely in the future.
     val groupingKey = groupingProjection.apply(null)
     val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
     while (inputIter.hasNext) {
       val newInput = inputIter.next()
       processRow(buffer, newInput)
     }
   } else {
     var i = 0
     while (inputIter.hasNext) {
       val newInput = inputIter.next()
       val groupingKey = groupingProjection.apply(newInput)
       var buffer: UnsafeRow = null
       if (i < fallbackStartsAt._2) {
         buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
       }
       if (buffer == null) {
         val sorter = hashMap.destructAndCreateExternalSorter()
         if (externalSorter == null) {
           externalSorter = sorter
         } else {
           externalSorter.merge(sorter)
         }
         i = 0
         buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
         if (buffer == null) {
           // failed to allocate the first page
           throw new SparkOutOfMemoryError("No enough memory for aggregation")
         }
       }
       processRow(buffer, newInput)
       i += 1
     }
 
     if (externalSorter != null) {
       val sorter = hashMap.destructAndCreateExternalSorter()
       externalSorter.merge(sorter)
       hashMap.free()
 
       switchToSortBasedAggregation()
     }
   }
 }

The logic here is roughly: pull rows from inputIter and, for each row, insert into or update the aggregation buffer for its grouping key via hashMap.getAggregationBufferFromUnsafeRow(groupingKey). If that call returns null, the map could not get enough memory, and a spill is triggered:

The spill creates a new UnsafeKVExternalSorter and keeps it in externalSorter (assigned directly the first time, merged into the existing sorter on later spills). The hash map's underlying BytesToBytesMap is handed over, its entries are loaded into an UnsafeInMemorySorter, and UnsafeExternalSorter.createWithExistingInMemorySorter is then called to sort and spill that data.

After the spill, the BytesToBytesMap is reset and the hash map goes back to acquiring memory and aggregating; if memory runs out again it spills again, until inputIter is exhausted.

Whether to switch to sort-based aggregation is then decided by whether externalSorter is null.

If no switch is needed, aggregationBufferMapIterator and mapIteratorHasNext are initialized;

if a switch is needed, switchToSortBasedAggregation is called to initialize the sort-based state, which next and hasNext will later use:

The state needed by sort-based aggregation includes:

externalSorter = UnsafeKVExternalSorter: inside switchToSortBasedAggregation, UnsafeKVExternalSorter.sortedIterator is first called to obtain an iterator over the sorted records, and its next() is invoked once; the return value becomes sortedInputHasNewGroup, which indicates whether there is any input at all (this first call just initializes the variable).

       

override final def hasNext: Boolean = {
    (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext)
  }
 
  override final def next(): UnsafeRow = {
    if (hasNext) {
      val res = if (sortBased) {
        // Process the current group.
        processCurrentSortedGroup()
        // Generate output row for the current group.
        val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
        // Initialize buffer values for the next group.
        sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
 
        outputRow
      } else {
        // We did not fall back to sort-based aggregation.
        val result =
          generateOutput(
            aggregationBufferMapIterator.getKey,
            aggregationBufferMapIterator.getValue)
 
        // Pre-load next key-value pair form aggregationBufferMapIterator to make hasNext
        // idempotent.
        mapIteratorHasNext = aggregationBufferMapIterator.next()
 
        if (!mapIteratorHasNext) {
          // If there is no input from aggregationBufferMapIterator, we copy current result.
          val resultCopy = result.copy()
          // Then, we free the map.
          hashMap.free()
 
          resultCopy
        } else {
          result
        }
      }
 
      numOutputRows += 1
      res
    } else {
      // no more result
      throw new NoSuchElementException
    }
  }

Looking at next and hasNext: for hash-based aggregation, hasNext simply checks whether the map iterator still has elements; if it does, next builds the output row directly from the key and value held in the hash map.

For sort-based aggregation, hasNext checks sortedInputHasNewGroup, which was initialized by calling next on the external KV sorter (UnsafeKVExternalSorter); when a row is actually produced, the main work happens in processCurrentSortedGroup:

// Processes rows in the current group. It will stop when it find a new group.
 private def processCurrentSortedGroup(): Unit = {
   // First, we need to copy nextGroupingKey to currentGroupingKey.
   currentGroupingKey.copyFrom(nextGroupingKey)
   // Now, we will start to find all rows belonging to this group.
   // We create a variable to track if we see the next group.
   var findNextPartition = false
   // firstRowInNextGroup is the first row of this group. We first process it.
   sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup)
 
   // The search will stop when we see the next group or there is no
   // input row left in the iter.
   // Pre-load the first key-value pair to make the condition of the while loop
   // has no action (we do not trigger loading a new key-value pair
   // when we evaluate the condition).
   var hasNext = sortedKVIterator.next()
   while (!findNextPartition && hasNext) {
     // Get the grouping key and value (aggregation buffer).
     val groupingKey = sortedKVIterator.getKey
     val inputAggregationBuffer = sortedKVIterator.getValue
 
     // Check if the current row belongs the current input row.
     if (currentGroupingKey.equals(groupingKey)) {
       sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer)
 
       hasNext = sortedKVIterator.next()
     } else {
       // We find a new group.
       findNextPartition = true
       // copyFrom will fail when
       nextGroupingKey.copyFrom(groupingKey)
       firstRowInNextGroup.copyFrom(inputAggregationBuffer)
     }
   }
   // The rest of the method (omitted from this excerpt) marks the end of input by
   // setting sortedInputHasNewGroup = false once no new group remains, so that
   // hasNext stops returning true.
 }

The rough logic of this method is to pull records from sortedKVIterator. Because the data is sorted by key, rows with the same key keep getting folded into the aggregation buffer; as soon as a different key appears, a new group has begun, and the current key together with its aggregated result is returned as the value of this next() call.

That is, in broad strokes, the overall flow of hash aggregation. One more detail: the row-level aggregation of intermediate data (including the sort-based path) goes through generateProcessRow, which decides how to evaluate each aggregate function based on the current aggregate mode and the function type: declarative functions use updateExpressions or mergeExpressions, imperative ones call their update and merge methods. At this point the computation only updates the aggregation buffers.

After that, results are produced by generateResultProjection, which likewise decides how to compute the output based on the aggregate mode and function type; it returns an UnsafeProjection. (This description jumps around quite a bit; it is best read alongside the source code.)

That wraps up the overall flow of hash-based aggregation. A few points are worth highlighting: first, the hash cache is an UnsafeFixedWidthAggregationMap built on BytesToBytesMap (Spark's own hash map); second, spilling ultimately happens in UnsafeExternalSorter, which spills the data held in an UnsafeInMemorySorter; third, externalSorter is an UnsafeKVExternalSorter, so the sort-based fallback is really built on UnsafeKVExternalSorter (which in turn relies on UnsafeExternalSorter); and fourth, there is the handling of the sorted intermediate buffers and of the final result projection.

ObjectHashAggregateExec

This is similar to hash aggregation; the main difference is that it targets aggregate functions of the TypedImperativeAggregate kind, which can keep Java objects in memory as part of the aggregation. Its aggregation buffers are exposed as aggBufferIterator = Iterator[AggregationBufferEntry]. Its spill behavior also differs from HashAggregateExec, which may spill many times during the computation: here, a single spill makes it fall back to sort-based aggregation. Apart from that, the overall logic is much the same as hash aggregation.

SortAggregateExec

Sort-based aggregation works by first sorting the data by the grouping key and then reading it sequentially: while the key stays the same, the aggregate functions keep accumulating; when the key changes, a new group has started and a fresh aggregation result is computed.
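
A minimal plain-Scala sketch of that idea: one sequential pass over key-sorted input with a single running buffer, emitting a result whenever the key changes.

val sorted = Seq(("a", 1L), ("a", 2L), ("b", 5L), ("b", 7L))  // already sorted by key

val results = scala.collection.mutable.ListBuffer.empty[(String, Long)]
var currentKey: Option[String] = None
var buffer = 0L
for ((key, value) <- sorted) {
  if (currentKey.contains(key)) {
    buffer += value                                    // same group: keep aggregating
  } else {
    currentKey.foreach(k => results.append(k -> buffer))  // new group: emit the old one
    currentKey = Some(key)
    buffer = value
  }
}
currentKey.foreach(k => results.append(k -> buffer))       // emit the last group
// results.toList == List(("a", 3L), ("b", 12L))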

The implementation is very close to the sort-based fallback inside hash aggregation; the core idea is unchanged, and even the routine that computes the intermediate aggregate values is the same one. One thing worth noting is:

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
  groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}

This declares an ordering requirement on the child, so during the preparation phase a sort operator is inserted below SortAggregateExec. If you are interested, the description of the sort-based fallback in the HashAggregateExec section above applies to this aggregation process as well.

That concludes this summary of the aggregation pipeline. There is plenty more that could be expanded on, but this post only covers aggregation; other topics can be summarized separately later.
