Spark SQL (8) - Spark SQL Aggregation
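Earlier posts briefly walked through Spark's whole path from SQL to a physical plan; this post looks at how Spark SQL handles aggregation. When parsing, withQuerySpecification builds the logical plan of a SELECT, and when the query has a GROUP BY clause it hands off to withAggregation: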
```scala
private def withQuerySpecification(
    ctx: QuerySpecificationContext,
    relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
  import ctx._
  // ... other code omitted (lateral views, the `filter` helper, and `expressions`,
  // the visited SELECT list) ...

  // Add where.
  val withFilter = withLateralView.optionalMap(where)(filter)

  // Add aggregation or a project.
  val namedExpressions = expressions.map {
    case e: NamedExpression => e
    case e: Expression => UnresolvedAlias(e)
  }
  val withProject = if (aggregation != null) {
    withAggregation(aggregation, namedExpressions, withFilter)
  } else if (namedExpressions.nonEmpty) {
    Project(namedExpressions, withFilter)
  } else {
    withFilter
  }

  // Having
  val withHaving = withProject.optional(having) {
    // Note that we add a cast to non-predicate expressions. If the expression itself is
    // already boolean, the optimizer will get rid of the unnecessary cast.
    val predicate = expression(having) match {
      case p: Predicate => p
      case e => Cast(e, BooleanType)
    }
    Filter(predicate, withProject)
  }

  // Distinct
  val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) {
    Distinct(withHaving)
  } else {
    withHaving
  }

  // Window
  // Hint
  // ... (window and hint handling omitted) ...
}
```
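To make the wrapping order above concrete, here is a toy model in Scala. The plan nodes (`Relation`, `Filter`, `Project`, `Aggregate`, `Distinct`) and the `build` helper are hypothetical stand-ins, not Spark's Catalyst classes, and expressions are plain strings. It only reproduces the order in which withQuerySpecification stacks operators: WHERE first, then the aggregate (or a plain projection), then HAVING as a filter on top, then DISTINCT.

```scala
object QuerySpecSketch {
  // Hypothetical, minimal logical-plan nodes (NOT Spark's classes).
  sealed trait Plan
  final case class Relation(name: String) extends Plan
  final case class Filter(condition: String, child: Plan) extends Plan
  final case class Project(columns: Seq[String], child: Plan) extends Plan
  final case class Aggregate(groupBy: Seq[String], columns: Seq[String], child: Plan) extends Plan
  final case class Distinct(child: Plan) extends Plan

  // Wraps the relation in the same order as withQuerySpecification:
  // WHERE -> GROUP BY aggregate (or plain projection) -> HAVING -> DISTINCT.
  def build(
      relation: Plan,
      where: Option[String],
      select: Seq[String],
      groupBy: Option[Seq[String]],
      having: Option[String],
      distinct: Boolean): Plan = {
    val withFilter = where.map(Filter(_, relation)).getOrElse(relation)
    val withProject = groupBy match {
      case Some(keys) => Aggregate(keys, select, withFilter) // the select list lives in Aggregate
      case None if select.nonEmpty => Project(select, withFilter)
      case None => withFilter
    }
    // HAVING is just a Filter placed above the aggregate.
    val withHaving = having.map(Filter(_, withProject)).getOrElse(withProject)
    if (distinct) Distinct(withHaving) else withHaving
  }
}
```

Note that with GROUP BY there is no separate Project: the SELECT list is carried by the Aggregate node itself, just as withAggregation receives `namedExpressions`.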
```scala
private def withAggregation(
    ctx: AggregationContext,
    selectExpressions: Seq[NamedExpression],
    query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
  val groupByExpressions = expressionList(ctx.groupingExpressions)

  if (ctx.GROUPING != null) {
    // GROUP BY .... GROUPING SETS (...)
    val selectedGroupByExprs =
      ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)))
    GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions)
  } else {
    // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
    val mappedGroupByExpressions = if (ctx.CUBE != null) {
      Seq(Cube(groupByExpressions))
    } else if (ctx.ROLLUP != null) {
      Seq(Rollup(groupByExpressions))
    } else {
      groupByExpressions
    }
    Aggregate(mappedGroupByExpressions, selectExpressions, query)
  }
}
```
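withAggregation only wraps the grouping expressions in Cube or Rollup, and the analyzer later expands them into grouping sets. The sets follow standard SQL semantics, which the sketch below computes with plain Scala collections (`rollup` and `cube` are illustrative helpers, not Spark functions): ROLLUP produces every prefix of the list, CUBE every subset.

```scala
object GroupingSetsSketch {
  // ROLLUP(a, b, c) groups by every prefix: (a, b, c), (a, b), (a), ().
  def rollup[A](exprs: Seq[A]): Seq[Seq[A]] =
    (exprs.length to 0 by -1).map(n => exprs.take(n))

  // CUBE(a, b, ...) groups by every subset of the expressions, 2^n sets in total.
  // Each element is either added to or left out of every subset built from the rest.
  def cube[A](exprs: Seq[A]): Seq[Seq[A]] =
    exprs.foldRight(Seq(Seq.empty[A])) { (e, acc) =>
      acc.map(e +: _) ++ acc
    }
}
```

For n grouping columns ROLLUP yields n + 1 grouping sets while CUBE yields 2^n, which is why CUBE over many columns multiplies the data that the aggregation has to process.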
At planning time, the Aggregation strategy turns the logical Aggregate into physical operators:

```scala
object Aggregation extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case PhysicalAggregation(
        groupingExpressions, aggregateExpressions, resultExpressions, child) =>

      val (functionsWithDistinct, functionsWithoutDistinct) =
        aggregateExpressions.partition(_.isDistinct)
      if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) {
        // This is a sanity check. We should not reach here when we have multiple distinct
        // column sets. Our `RewriteDistinctAggregates` should take care this case.
        sys.error("You hit a query analyzer bug. Please report your query to " +
          "Spark user mailing list.")
      }

      val aggregateOperator =
        if (functionsWithDistinct.isEmpty) {
          aggregate.AggUtils.planAggregateWithoutDistinct(
            groupingExpressions,
            aggregateExpressions,
            resultExpressions,
            planLater(child))
        } else {
          aggregate.AggUtils.planAggregateWithOneDistinct(
            groupingExpressions,
            functionsWithDistinct,
            functionsWithoutDistinct,
            resultExpressions,
            planLater(child))
        }

      aggregateOperator

    case _ => Nil
  }
}
```
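The distinct handling at the top of the strategy can be modeled in a few lines. `AggExpr` is a hypothetical stand-in for AggregateExpression that carries only a function name, its input columns and the distinct flag. Several distinct functions over the same columns (for example count(DISTINCT x) and sum(DISTINCT x)) pass the check; distinct functions over different column sets are expected to have been rewritten earlier by RewriteDistinctAggregates, so hitting the error means something upstream went wrong.

```scala
object DistinctCheckSketch {
  // Hypothetical stand-in for Spark's AggregateExpression: only what the check needs.
  final case class AggExpr(function: String, children: Seq[String], isDistinct: Boolean)

  // Split into (distinct, non-distinct) functions and reject more than one distinct
  // column set, mirroring the sanity check in the Aggregation strategy.
  def split(aggs: Seq[AggExpr]): (Seq[AggExpr], Seq[AggExpr]) = {
    val (withDistinct, withoutDistinct) = aggs.partition(_.isDistinct)
    if (withDistinct.map(_.children.toSet).distinct.length > 1) {
      throw new IllegalStateException("multiple distinct column sets must be rewritten first")
    }
    (withDistinct, withoutDistinct)
  }
}
```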
Next, I'll look at how planAggregateWithoutDistinct and planAggregateWithOneDistinct differ, and at how Spark SQL actually executes aggregation. Before that, here are the aggregation modes (AggregateMode):
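Partial: partial (local) aggregation; input rows are folded into the aggregation buffer, and the buffer is returned.
Final: aggregation buffers are merged and the final result is returned.
Complete: no partial aggregation; the final result is computed directly from the input rows.
PartialMerge: aggregation buffers are merged, but the output is still an aggregation buffer; this mode is mainly used for DISTINCT queries.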
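The four modes can be illustrated with avg, whose buffer is a (sum, count) pair. This is a hypothetical model (`Buf`, `partial`, `partialMerge`, `finalEval` and `complete` are made-up names, and None plays the role of SQL NULL for empty input), not Spark's Average implementation; it only shows what each mode consumes and what it produces.

```scala
object AggModeSketch {
  // Aggregation buffer for avg: running sum and row count.
  final case class Buf(sum: Double, count: Long)
  val empty: Buf = Buf(0.0, 0L)

  // Partial: raw input rows -> aggregation buffer.
  def partial(rows: Seq[Double]): Buf =
    rows.foldLeft(empty)((b, v) => Buf(b.sum + v, b.count + 1))

  // PartialMerge: buffers -> merged buffer (still a buffer, not a result).
  def partialMerge(bufs: Seq[Buf]): Buf =
    bufs.foldLeft(empty)((a, b) => Buf(a.sum + b.sum, a.count + b.count))

  // Final: buffers -> merged buffer -> evaluated result.
  def finalEval(bufs: Seq[Buf]): Option[Double] = {
    val merged = partialMerge(bufs)
    if (merged.count == 0) None else Some(merged.sum / merged.count)
  }

  // Complete: raw input rows -> result, with no separate partial step.
  def complete(rows: Seq[Double]): Option[Double] = finalEval(Seq(partial(rows)))
}
```

The buffer, not the average, is what crosses partition boundaries: averaging two partial averages would be wrong when the partitions have different row counts.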
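Aggregate functions also come in three kinds:
1. DeclarativeAggregate: declarative aggregate functions, whose initial values and update, merge and evaluate steps are written as Catalyst expressions.
2. ImperativeAggregate: imperative aggregate functions, which implement initialize/update/merge/eval as code that manipulates the buffer row directly.
3. TypedImperativeAggregate: a subclass of ImperativeAggregate whose aggregation buffer can be an arbitrary Java object kept in memory.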
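A TypedImperativeAggregate keeps a plain JVM object as its buffer and only turns it into bytes when the buffer has to leave memory. The trait below is a simplified, hypothetical analogue loosely shaped after TypedImperativeAggregate's create/update/merge/eval/serialize/deserialize methods, with Spark's InternalRow inputs replaced by plain values; `CollectSet` is an illustrative collect_set-like function whose buffer is a mutable HashSet.

```scala
import java.nio.charset.StandardCharsets

object TypedAggSketch {
  // Hypothetical analogue of a typed imperative aggregate: the buffer T is any JVM object.
  trait TypedAgg[In, T, Out] {
    def createBuffer(): T
    def update(buf: T, input: In): T
    def merge(buf: T, other: T): T
    def eval(buf: T): Out
    def serialize(buf: T): Array[Byte]     // only needed when the buffer leaves memory
    def deserialize(bytes: Array[Byte]): T
  }

  // collect_set-like aggregate over strings, buffered in a mutable HashSet.
  object CollectSet extends TypedAgg[String, scala.collection.mutable.HashSet[String], Set[String]] {
    type Buf = scala.collection.mutable.HashSet[String]
    def createBuffer(): Buf = scala.collection.mutable.HashSet.empty[String]
    def update(buf: Buf, input: String): Buf = { buf += input; buf } // mutate in place
    def merge(buf: Buf, other: Buf): Buf = { buf ++= other; buf }
    def eval(buf: Buf): Set[String] = buf.toSet
    // Toy encoding: newline-separated UTF-8; assumes values are non-empty and contain no '\n'.
    def serialize(buf: Buf): Array[Byte] =
      buf.mkString("\n").getBytes(StandardCharsets.UTF_8)
    def deserialize(bytes: Array[Byte]): Buf = {
      val s = new String(bytes, StandardCharsets.UTF_8)
      val b = createBuffer()
      if (s.nonEmpty) s.split("\n").foreach(b += _)
      b
    }
  }
}
```

Because the buffer is an opaque object rather than fixed-width fields, a function like this cannot live inside UnsafeRow-based hash buffers, which is what steers it away from HashAggregateExec.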
planAggregateWithoutDistinct builds two operators, a partial aggregate and a final aggregate on top of it:

```scala
def planAggregateWithoutDistinct(
    groupingExpressions: Seq[NamedExpression],
    aggregateExpressions: Seq[AggregateExpression],
    resultExpressions: Seq[NamedExpression],
    child: SparkPlan): Seq[SparkPlan] = {
  // Check if we can use HashAggregate.

  // 1. Create an Aggregate Operator for partial aggregations.
  val groupingAttributes = groupingExpressions.map(_.toAttribute)
  val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
  val partialAggregateAttributes =
    partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
  val partialResultExpressions =
    groupingAttributes ++
      partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)

  val partialAggregate = createAggregate(
    requiredChildDistributionExpressions = None,
    groupingExpressions = groupingExpressions,
    aggregateExpressions = partialAggregateExpressions,
    aggregateAttributes = partialAggregateAttributes,
    initialInputBufferOffset = 0,
    resultExpressions = partialResultExpressions,
    child = child)

  // 2. Create an Aggregate Operator for final aggregations.
  val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
  // The attributes of the final aggregation buffer, which is presented as input to the result
  // projection:
  val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)

  val finalAggregate = createAggregate(
    requiredChildDistributionExpressions = Some(groupingAttributes),
    groupingExpressions = groupingAttributes,
    aggregateExpressions = finalAggregateExpressions,
    aggregateAttributes = finalAggregateAttributes,
    initialInputBufferOffset = groupingExpressions.length,
    resultExpressions = resultExpressions,
    child = partialAggregate)

  finalAggregate :: Nil
}
```
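The effect of these two operators can be simulated on plain collections: a partial sum inside each input partition, a hash shuffle on the grouping key (the exchange the planner adds to satisfy `requiredChildDistributionExpressions = Some(groupingAttributes)`), and a final merge per key. Partitions, `numReducers` and the helper names are hypothetical; the point is that all partial buffers of a key meet in exactly one reducer, so the final step only has to merge buffers.

```scala
object TwoPhaseAggSketch {
  // One input partition of (grouping key, value) rows.
  type Partition = Seq[(String, Long)]

  // Phase 1 (Partial): pre-aggregate inside each partition; no data movement yet.
  def partialSum(p: Partition): Map[String, Long] =
    p.groupBy(_._1).map { case (k, rows) => k -> rows.map(_._2).sum }

  // Exchange: send each partial buffer to a reducer chosen by the hash of its key.
  def shuffle(partials: Seq[Map[String, Long]], numReducers: Int): Seq[Seq[(String, Long)]] = {
    val all = partials.flatMap(_.toSeq)
    (0 until numReducers).map { r =>
      all.filter { case (k, _) => Math.floorMod(k.hashCode, numReducers) == r }
    }
  }

  // Phase 2 (Final): merge the partial buffers of each key within one reducer.
  def finalSum(reducerInput: Seq[(String, Long)]): Map[String, Long] =
    reducerInput.groupBy(_._1).map { case (k, bufs) => k -> bufs.map(_._2).sum }

  // Reducers own disjoint key sets, so their outputs can simply be concatenated.
  def run(input: Seq[Partition], numReducers: Int): Map[String, Long] =
    shuffle(input.map(partialSum), numReducers).map(finalSum).reduce(_ ++ _)
}
```

The partial step is what makes this cheaper than shuffling raw rows: each partition sends at most one buffer per key.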
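3. Create a partial aggregation plan; this is the step used for the distinct function (step 3 of planAggregateWithOneDistinct, which builds four operators instead of two).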
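A single distinct function needs more steps because duplicates must be removed before counting. The sketch below models the idea for SELECT k, count(DISTINCT x) ... GROUP BY k with plain collections: first aggregate on (k, x) so each pair survives exactly once, then count per k partially, then merge the partial counts. The step boundaries are a simplified reading of planAggregateWithOneDistinct, non-distinct functions are left out, and all names are hypothetical.

```scala
object OneDistinctSketch {
  // Models: SELECT k, count(DISTINCT x) FROM t GROUP BY k.
  // Each inner Seq is one input partition of (k, x) rows.
  def countDistinct(input: Seq[Seq[(String, Int)]], numReducers: Int): Map[String, Long] = {
    def bucket(key: Any): Int = Math.floorMod(key.hashCode, numReducers)

    // Steps 1-2: aggregate on (k, x). Each partition drops its own duplicates first,
    // then rows are shuffled on (k, x) and merged, so every (k, x) pair ends up
    // exactly once, in exactly one bucket.
    val partials: Seq[(String, Int)] = input.flatMap(_.distinct)
    val byKx: Seq[Seq[(String, Int)]] =
      (0 until numReducers).map(r => partials.filter(p => bucket(p) == r).distinct)

    // Step 3: partial count per k inside each bucket; x is already de-duplicated.
    val partialCounts: Seq[(String, Long)] =
      byKx.flatMap(rows => rows.groupBy(_._1).map { case (k, pairs) => k -> pairs.size.toLong })

    // Step 4: bring the partial counts of each k together (a shuffle on k) and add them up.
    partialCounts.groupBy(_._1).map { case (k, counts) => k -> counts.map(_._2).sum }
  }
}
```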
Both operators are created through createAggregate, which picks the physical operator:

```scala
private def createAggregate(
    requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
    groupingExpressions: Seq[NamedExpression] = Nil,
    aggregateExpressions: Seq[AggregateExpression] = Nil,
    aggregateAttributes: Seq[Attribute] = Nil,
    initialInputBufferOffset: Int = 0,
    resultExpressions: Seq[NamedExpression] = Nil,
    child: SparkPlan): SparkPlan = {
  val useHash = HashAggregateExec.supportsAggregate(
    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
  if (useHash) {
    HashAggregateExec(
      requiredChildDistributionExpressions = requiredChildDistributionExpressions,
      groupingExpressions = groupingExpressions,
      aggregateExpressions = aggregateExpressions,
      aggregateAttributes = aggregateAttributes,
      initialInputBufferOffset = initialInputBufferOffset,
      resultExpressions = resultExpressions,
      child = child)
  } else {
    val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
    val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)

    if (objectHashEnabled && useObjectHash) {
      ObjectHashAggregateExec(
        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
        groupingExpressions = groupingExpressions,
        aggregateExpressions = aggregateExpressions,
        aggregateAttributes = aggregateAttributes,
        initialInputBufferOffset = initialInputBufferOffset,
        resultExpressions = resultExpressions,
        child = child)
    } else {
      SortAggregateExec(
        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
        groupingExpressions = groupingExpressions,
        aggregateExpressions = aggregateExpressions,
        aggregateAttributes = aggregateAttributes,
        initialInputBufferOffset = initialInputBufferOffset,
        resultExpressions = resultExpressions,
        child = child)
    }
  }
}
```
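The logic above shows that HashAggregateExec is chosen whenever hash aggregation is possible. The condition is that every field of the aggregation buffer schema has one of the types below (UnsafeRow.mutableFieldTypes; UnsafeRow.isMutable additionally accepts DecimalType):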
```java
static {
  mutableFieldTypes = Collections.unmodifiableSet(
    new HashSet<>(
      Arrays.asList(new DataType[] {
        NullType,
        BooleanType,
        ByteType,
        ShortType,
        IntegerType,
        LongType,
        FloatType,
        DoubleType,
        DateType,
        TimestampType
      })));
}
```
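The choice made by createAggregate can be summarized as a small decision function. This is a hypothetical simplification: buffer field types are reduced to "fixed-width mutable" (the types listed above) versus "anything else", `objectHashEnabled` stands for the `spark.sql.execution.useObjectHashAggregateExec` setting, and ObjectHashAggregateExec.supportsAggregate is modeled as "at least one function is a TypedImperativeAggregate".

```scala
object AggOperatorChoiceSketch {
  sealed trait BufType
  case object Fixed extends BufType  // a type from mutableFieldTypes, updatable in place
  case object VarLen extends BufType // e.g. strings, binary, arrays

  // Hypothetical summary of one aggregate function: its buffer field types and
  // whether it is a TypedImperativeAggregate.
  final case class AggFn(bufferTypes: Seq[BufType], isTypedImperative: Boolean)

  def choose(fns: Seq[AggFn], objectHashEnabled: Boolean): String = {
    // HashAggregateExec needs every buffer field to be fixed-width and mutable.
    val useHash = fns.flatMap(_.bufferTypes).forall(_ == Fixed)
    val useObjectHash = fns.exists(_.isTypedImperative)
    if (useHash) "HashAggregateExec"
    else if (objectHashEnabled && useObjectHash) "ObjectHashAggregateExec"
    else "SortAggregateExec" // always possible, but needs sorted input
  }
}
```

For example, max over a string column has a string buffer and is not a TypedImperativeAggregate, so in this model it lands on SortAggregateExec.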
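1. hashMap = new UnsafeFixedWidthAggregationMap: this map stores each grouping key together with its aggregation buffer. Its most important member is map, a BytesToBytesMap, which is where the data actually lives. The input rows are consumed by TungstenAggregationIterator.processInputs: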
```scala
private def processInputs(fallbackStartsAt: (Int, Int)): Unit = {
  if (groupingExpressions.isEmpty) {
    // If there is no grouping expressions, we can just reuse the same buffer over and over again.
    // Note that it would be better to eliminate the hash map entirely in the future.
    val groupingKey = groupingProjection.apply(null)
    val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
    while (inputIter.hasNext) {
      val newInput = inputIter.next()
      processRow(buffer, newInput)
    }
  } else {
    var i = 0
    while (inputIter.hasNext) {
      val newInput = inputIter.next()
      val groupingKey = groupingProjection.apply(newInput)
      var buffer: UnsafeRow = null
      if (i < fallbackStartsAt._2) {
        buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
      }
      if (buffer == null) {
        // Out of memory for a new key: spill the hash map into a sorter.
        val sorter = hashMap.destructAndCreateExternalSorter()
        if (externalSorter == null) {
          externalSorter = sorter
        } else {
          externalSorter.merge(sorter)
        }
        i = 0
        buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
        if (buffer == null) {
          // failed to allocate the first page
          throw new SparkOutOfMemoryError("No enough memory for aggregation")
        }
      }
      processRow(buffer, newInput)
      i += 1
    }

    if (externalSorter != null) {
      val sorter = hashMap.destructAndCreateExternalSorter()
      externalSorter.merge(sorter)
      switchToSortBasedAggregation()
    }
  }
}
```
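The control flow of processInputs can be mimicked with Scala collections. In this hypothetical sketch, `capacity` (the maximum number of distinct keys held at once) stands in for memory: when a new key does not fit, the current buffers are sorted by key and set aside as a run, like destructAndCreateExternalSorter; if any run was produced, all buffers are finally read in key order and equal keys are merged, like switchToSortBasedAggregation. The aggregate is a plain sum, and real spill files, UnsafeRow and memory accounting are not modeled.

```scala
import scala.collection.mutable

object HashAggWithSpillSketch {
  type Row = (String, Long)

  // Sum per key. Returns the result in key order plus the number of sorted runs
  // produced (0 means the hash map never spilled).
  def aggregate(input: Iterator[Row], capacity: Int): (Seq[Row], Int) = {
    val hashMap = mutable.HashMap.empty[String, Long]
    val runs = mutable.ArrayBuffer.empty[Seq[Row]]

    // Like destructAndCreateExternalSorter: sort the current buffers by key, keep them
    // as one sorted run, and empty the map so aggregation can continue.
    def spill(): Unit = {
      runs += hashMap.toSeq.sortBy(_._1)
      hashMap.clear()
    }

    for ((key, value) <- input) {
      // Like getAggregationBufferFromUnsafeRow returning null: no room for a new key.
      if (!hashMap.contains(key) && hashMap.size >= capacity) spill()
      hashMap(key) = hashMap.getOrElse(key, 0L) + value // processRow: update the buffer
    }

    if (runs.isEmpty) {
      (hashMap.toSeq.sortBy(_._1), 0) // no spill: results come straight from the hash map
    } else {
      spill() // the remaining in-memory buffers become the last run
      // Like switchToSortBasedAggregation: read all buffers in key order and merge
      // neighbours with equal keys. (A real external sorter merges the sorted runs
      // instead of re-sorting their concatenation.)
      val merged = mutable.ArrayBuffer.empty[Row]
      for ((key, value) <- runs.flatten.sortBy(_._1)) {
        if (merged.nonEmpty && merged.last._1 == key) merged(merged.length - 1) = (key, merged.last._2 + value)
        else merged += ((key, value))
      }
      (merged.toList, runs.length)
    }
  }
}
```

The same key can appear in several runs (here "a" is in two of them), which is why the sort-based phase must merge buffers rather than just concatenate runs.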
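When spilling, hashMap.destructAndCreateExternalSorter() creates a new UnsafeKVExternalSorter, which is kept in externalSorter: the first one is assigned directly, later ones are merged into it. The data of the hash map's underlying BytesToBytesMap is passed to this sorter, which creates an UnsafeInMemorySorter and loads the BytesToBytesMap records into it; it then calls UnsafeExternalSorter.createWithExistingInMemorySorter, which sorts the data and spills it to disk.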
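Reading several sorted runs back in key order is a k-way merge. As a conceptual sketch (plain Scala, not the UnsafeExternalSorter code), a min-heap holds one iterator per run, ordered by the key at its head; each step pops the smallest head and re-inserts that iterator if it still has rows. The `Row` alias and `mergeRuns` are hypothetical names.

```scala
import scala.collection.mutable

object SortedRunMergeSketch {
  type Row = (String, Long)

  // Merges runs that are each sorted by key into one key-ordered iterator.
  def mergeRuns(runs: Seq[Seq[Row]]): Iterator[Row] = {
    // PriorityQueue is a max-heap, so reverse the ordering to pop the smallest head key.
    val byHeadKey: Ordering[BufferedIterator[Row]] =
      Ordering.by((it: BufferedIterator[Row]) => it.head._1).reverse
    val heap = mutable.PriorityQueue.empty[BufferedIterator[Row]](byHeadKey)
    runs.map(_.iterator.buffered).filter(_.hasNext).foreach(it => heap.enqueue(it))

    new Iterator[Row] {
      def hasNext: Boolean = heap.nonEmpty
      def next(): Row = {
        val it = heap.dequeue()          // run whose head has the smallest key
        val row = it.next()
        if (it.hasNext) heap.enqueue(it) // re-insert, now ordered by its new head
        row
      }
    }
  }
}
```

Each run is read strictly sequentially and only one row per run is held at a time, which is what makes this style of merge suitable for spill files.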
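externalSorter is an UnsafeKVExternalSorter. In switchToSortBasedAggregation, externalSorter.sortedIterator() is called to get an iterator over the sorted records, and then its next() is called; the value returned by next() becomes sortedInputHasNewGroup, which tells whether any input is left (this first call only initializes the variable).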
```scala
override final def hasNext: Boolean = {
  (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext)
}

override final def next(): UnsafeRow = {
  if (hasNext) {
    val res = if (sortBased) {
      // Process the current group.
      processCurrentSortedGroup()
      // Generate output row for the current group.
      val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
      // Initialize buffer values for the next group.
      sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)

      outputRow
    } else {
      // We did not fall back to sort-based aggregation.
      val result =
        generateOutput(
          aggregationBufferMapIterator.getKey,
          aggregationBufferMapIterator.getValue)

      // Pre-load next key-value pair from aggregationBufferMapIterator to make hasNext
      // idempotent.
      mapIteratorHasNext = aggregationBufferMapIterator.next()

      if (!mapIteratorHasNext) {
        // If there is no input from aggregationBufferMapIterator, we copy current result.
        val resultCopy = result.copy()
        // Then, we free the map.
        hashMap.free()

        resultCopy
      } else {
        result
      }
    }

    numOutputRows += 1
    res
  } else {
    // no more result
    throw new NoSuchElementException
  }
}
```
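For sort-based aggregation, next() checks sortedInputHasNewGroup. This value was initialized from a direct call to next() on the iterator of the external-sort-based KV store (UnsafeKVExternalSorter). After that, the main work of producing each output row happens in processCurrentSortedGroup: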
```scala
// Processes rows in the current group. It will stop when it find a new group.
private def processCurrentSortedGroup(): Unit = {
  // First, we need to copy nextGroupingKey to currentGroupingKey.
  currentGroupingKey.copyFrom(nextGroupingKey)
  // Now, we will start to find all rows belonging to this group.
  // We create a variable to track if we see the next group.
  var findNextPartition = false
  // firstRowInNextGroup is the first row of this group. We first process it.
  sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup)

  // The search will stop when we see the next group or there is no
  // input row left in the iter.
  // Pre-load the first key-value pair to make the condition of the while loop
  // has no action (we do not trigger loading a new key-value pair
  // when we evaluate the condition).
  var hasNext = sortedKVIterator.next()
  while (!findNextPartition && hasNext) {
    // Get the grouping key and value (aggregation buffer).
    val groupingKey = sortedKVIterator.getKey
    val inputAggregationBuffer = sortedKVIterator.getValue

    // Check if the current row belongs the current input row.
    if (currentGroupingKey.equals(groupingKey)) {
      sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer)

      hasNext = sortedKVIterator.next()
    } else {
      // We find a new group.
      findNextPartition = true
      // copyFrom will fail when
      nextGroupingKey.copyFrom(groupingKey)
      firstRowInNextGroup.copyFrom(inputAggregationBuffer)
    }
  }
  // ... (rest omitted: if no new group was found, the sorted input is exhausted and
  // sortedInputHasNewGroup is set to false)
}
```
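The look-ahead pattern of processCurrentSortedGroup is easier to see in isolation. In the hypothetical `GroupIterator` below the field names echo the Spark variables, but rows are (String, Long) pairs and merging is a plain sum: each call to next() starts from the row saved by the previous call, consumes rows with the same key, and saves the first row of the following group, so the iterator never needs to step back.

```scala
object SortedGroupIteratorSketch {
  type Row = (String, Long)

  // Turns key-sorted (key, buffer) pairs into one merged (key, buffer) per key.
  final class GroupIterator(sorted: Iterator[Row]) extends Iterator[Row] {
    private var nextGroupingKey: String = ""
    private var firstRowInNextGroup: Long = 0L
    // Plays the role of sortedInputHasNewGroup, initialized by reading the first row.
    private var sortedInputHasNewGroup: Boolean = sorted.hasNext
    if (sortedInputHasNewGroup) {
      val (k, v) = sorted.next()
      nextGroupingKey = k
      firstRowInNextGroup = v
    }

    def hasNext: Boolean = sortedInputHasNewGroup

    def next(): Row = {
      if (!sortedInputHasNewGroup) throw new NoSuchElementException
      val currentGroupingKey = nextGroupingKey
      var buffer = firstRowInNextGroup // the saved first row of this group is processed first
      var findNextPartition = false
      while (!findNextPartition && sorted.hasNext) {
        val (k, v) = sorted.next()
        if (k == currentGroupingKey) {
          buffer += v // same group: merge into the current buffer
        } else {
          findNextPartition = true // new group: remember its key and first row
          nextGroupingKey = k
          firstRowInNextGroup = v
        }
      }
      // No new group seen: the input is exhausted and this was the last group.
      if (!findNextPartition) sortedInputHasNewGroup = false
      (currentGroupingKey, buffer)
    }
  }
}
```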
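That concludes the overall flow of hash-based aggregation. A few points matter most: first, the hash buffer is an UnsafeFixedWidthAggregationMap built on BytesToBytesMap (Spark's own hash map implementation); second, spilling the hash map ultimately means UnsafeExternalSorter spilling the data held in an UnsafeInMemorySorter; third, the sort-based fallback relies on externalSorter, an UnsafeKVExternalSorter, which is itself built on UnsafeExternalSorter; fourth, how the intermediate buffers are computed during sort-based aggregation and how the final output rows are produced.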
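ObjectHashAggregateExec is similar to HashAggregateExec. The main difference is that it targets TypedImperativeAggregate functions: it can keep Java objects in memory as aggregation buffers while computing. Its aggregation buffers are defined as aggBufferIterator = Iterator[AggregationBufferEntry]. Its spilling also differs: HashAggregate may spill many times during the computation, whereas ObjectHashAggregate falls back to sort-based aggregation as soon as it spills once. Otherwise the overall logic is much the same as HashAggregate's.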
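The single fallback can be sketched as follows. The buffer here is an immutable Set[Int] (a collect_set-like aggregate) and `fallbackThreshold` caps the number of keys in the hash map; in Spark the corresponding limit is the `spark.sql.objectHashAggregate.sortBased.fallbackThreshold` setting. Once a new key would exceed the limit, the map's buffers, the current row and all remaining input go to a sort-based path, and the hash map is not used again. `collectSets` and `sortBased` are hypothetical names.

```scala
import scala.collection.mutable

object ObjectHashFallbackSketch {
  type Buf = Set[Int] // stand-in for an arbitrary JVM object buffer

  // Sort-based path: sort (key, buffer) pairs by key, then merge neighbours with equal keys.
  private def sortBased(pairs: Seq[(String, Buf)]): Seq[(String, Buf)] =
    pairs.sortBy(_._1).foldLeft(List.empty[(String, Buf)]) {
      case ((k0, b0) :: tail, (k, b)) if k0 == k => (k, b0 ++ b) :: tail
      case (acc, kb) => kb :: acc
    }.reverse

  // Returns the per-key result in key order and whether the one-time fallback happened.
  def collectSets(input: Iterator[(String, Int)], fallbackThreshold: Int): (Seq[(String, Buf)], Boolean) = {
    val hashMap = mutable.LinkedHashMap.empty[String, Buf]
    while (input.hasNext) {
      val (k, v) = input.next()
      if (!hashMap.contains(k) && hashMap.size >= fallbackThreshold) {
        // Too many keys: hand the map's buffers, this row and ALL remaining rows to the
        // sort-based path once, instead of spilling repeatedly.
        val rest = (k, Set(v)) :: input.map { case (k2, v2) => (k2, Set(v2)) }.toList
        return (sortBased(hashMap.toList ++ rest), true)
      }
      hashMap(k) = hashMap.getOrElse(k, Set.empty[Int]) + v
    }
    (hashMap.toList.sortBy(_._1), false)
  }
}
```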
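SortAggregateExec requires its child to be sorted in ascending order of the grouping expressions:

```scala
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
  groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}
```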
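SortAggregateExec does not sort by itself; it only declares the ordering it needs. In Spark, the EnsureRequirements rule compares requiredChildOrdering with the child's output ordering and inserts a SortExec when the requirement is not met. The sketch models that check with lists of column names (all ascending); `SortCols`, `satisfies` and `planSortAggregate` are hypothetical names.

```scala
object SortRequirementSketch {
  // Hypothetical representation: the columns a dataset is sorted by, all ascending.
  type SortCols = Seq[String]

  // Data sorted by (k, x) is also sorted by k, so a required ordering is satisfied
  // when it is a prefix of the child's actual output ordering.
  def satisfies(required: SortCols, actual: SortCols): Boolean = actual.startsWith(required)

  // Plans SortAggregate over a child (operator names listed top-down), adding a Sort
  // only when the child's ordering does not already satisfy the grouping-key ordering.
  def planSortAggregate(groupingKeys: Seq[String], childOrdering: SortCols): Seq[String] =
    if (satisfies(groupingKeys, childOrdering)) Seq("SortAggregate", "Child")
    else Seq("SortAggregate", s"Sort(${groupingKeys.mkString(", ")})", "Child")
}
```

Once the input is sorted, SortAggregateExec only keeps one group's buffer at a time, so it never runs out of hash-map memory; the price is the sort itself.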