Spark SQL (9) - A source-code summary of Spark SQL JOIN operations
This post summarizes how Spark SQL implements joins, based on the Spark SQL source code. The walkthrough roughly follows the usual pipeline: from the SQL statement to the logical operator tree, through the analyzed and optimized plans, and finally to the physical plan and its execution logic.
Join logical operator tree
Start with a SQL statement:
SELECT NAME FROM NAME LEFT JOIN NAME2 ON NAME = NAME JOIN NAME3 ON NAME = NAME
This SQL is parsed into the logical operator tree shown in the figure (omitted here).
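If you want to reproduce that tree yourself, the parsed (pre-analysis) plan can be printed directly. A minimal sketch, assuming a local SparkSession; the application name is arbitrary, and since only the parser runs, the tables do not even need to exist:

import org.apache.spark.sql.SparkSession

// Print the tree that AstBuilder produces for the example SQL, before any analysis.
object ShowJoinLogicalPlan {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("join-plan-demo")
      .getOrCreate()

    // parsePlan only runs the parser, so no catalog lookup happens here.
    val parsed = spark.sessionState.sqlParser.parsePlan(
      "SELECT NAME FROM NAME LEFT JOIN NAME2 ON NAME = NAME JOIN NAME3 ON NAME = NAME")

    // Two nested Join nodes (Inner on top of LeftOuter) over the three relations.
    println(parsed.treeString)

    spark.stop()
  }
}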
As for how that tree structure is generated, the join handling is the part that matters; the relevant source lives in AstBuilder:
override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
  val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
    val right = plan(relation.relationPrimary)
    val join = right.optionalMap(left)(Join(_, _, Inner, None))
    withJoinRelations(join, relation)
  }
  ctx.lateralView.asScala.foldLeft(from)(withGenerate)
}
private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
  val pp = ctx.joinRelation
  pp.asScala.foldLeft(base) { (left, join) =>
    withOrigin(join) {
      val baseJoinType = join.joinType match {
        case null => Inner
        case jt if jt.CROSS != null => Cross
        case jt if jt.FULL != null => FullOuter
        case jt if jt.SEMI != null => LeftSemi
        case jt if jt.ANTI != null => LeftAnti
        case jt if jt.LEFT != null => LeftOuter
        case jt if jt.RIGHT != null => RightOuter
        case _ => Inner
      }

      // Resolve the join type and join condition
      val (joinType, condition) = Option(join.joinCriteria) match {
        case Some(c) if c.USING != null =>
          (UsingJoin(baseJoinType, c.identifier.asScala.map(_.getText)), None)
        case Some(c) if c.booleanExpression != null =>
          (baseJoinType, Option(expression(c.booleanExpression)))
        case None if join.NATURAL != null =>
          if (baseJoinType == Cross) {
            throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
          }
          (NaturalJoin(baseJoinType), None)
        case None =>
          (baseJoinType, None)
      }
      Join(left, plan(join.right), joinType, condition)
    }
  }
}
As can be seen from the tree above, for a join the resulting structure keeps the join relations as a list of JoinRelation entries; each JoinRelation carries a JoinType, a relationPrimary and a joinCriteria, where the joinCriteria is essentially a booleanExpression.
Next come the analyzed and optimized phases for the Join. Analysis mainly adds subquery aliases and resolves references, and the optimizer then applies rules such as predicate pushdown and elimination of subquery aliases. Quite a few rules are involved; interested readers can dig further into the source.
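The effect of these two phases is easy to observe from a DataFrame's QueryExecution. A minimal sketch, reusing the `spark` session from the previous snippet; the tiny datasets and view names are made up just so the analyzer has something to resolve:

import spark.implicits._

// Register two throwaway views so analysis can succeed.
Seq(("a", 1)).toDF("NAME", "ID").createOrReplaceTempView("NAME")
Seq(("a", 2)).toDF("NAME", "ID").createOrReplaceTempView("NAME2")

val qe = spark.sql(
  "SELECT NAME.NAME FROM NAME JOIN NAME2 ON NAME.NAME = NAME2.NAME").queryExecution

println(qe.analyzed.treeString)      // after analysis: SubqueryAlias nodes, resolved attributes
println(qe.optimizedPlan.treeString) // after optimization: aliases eliminated, predicates pushed down
println(qe.sparkPlan.treeString)     // the physical join chosen by JoinSelection (next section)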
Physical planning stage
This phase involves the various strategies configured in SparkPlanner. For joins, the one to look at is JoinSelection; its apply method is as follows:
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {

  // --- BroadcastHashJoin --------------------------------------------------------------------

  // broadcast hints were specified
  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if canBroadcastByHints(joinType, left, right) =>
    val buildSide = broadcastSideByHints(joinType, left, right)
    Seq(joins.BroadcastHashJoinExec(
      leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))

  // broadcast hints were not specified, so need to infer it from size and configuration.
  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if canBroadcastBySizes(joinType, left, right) =>
    val buildSide = broadcastSideBySizes(joinType, left, right)
    Seq(joins.BroadcastHashJoinExec(
      leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))

  // --- ShuffledHashJoin ---------------------------------------------------------------------

  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
      && muchSmaller(right, left) ||
      !RowOrdering.isOrderable(leftKeys) =>
    Seq(joins.ShuffledHashJoinExec(
      leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))

  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
      && muchSmaller(left, right) ||
      !RowOrdering.isOrderable(leftKeys) =>
    Seq(joins.ShuffledHashJoinExec(
      leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))

  // --- SortMergeJoin ------------------------------------------------------------

  case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
    if RowOrdering.isOrderable(leftKeys) =>
    joins.SortMergeJoinExec(
      leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil

  // --- Without joining keys ------------------------------------------------------------

  // Pick BroadcastNestedLoopJoin if one side could be broadcast
  case j @ logical.Join(left, right, joinType, condition)
    if canBroadcastByHints(joinType, left, right) =>
    val buildSide = broadcastSideByHints(joinType, left, right)
    joins.BroadcastNestedLoopJoinExec(
      planLater(left), planLater(right), buildSide, joinType, condition) :: Nil

  case j @ logical.Join(left, right, joinType, condition)
    if canBroadcastBySizes(joinType, left, right) =>
    val buildSide = broadcastSideBySizes(joinType, left, right)
    joins.BroadcastNestedLoopJoinExec(
      planLater(left), planLater(right), buildSide, joinType, condition) :: Nil

  // Pick CartesianProduct for InnerJoin
  case logical.Join(left, right, _: InnerLike, condition) =>
    joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil

  case logical.Join(left, right, joinType, condition) =>
    val buildSide = broadcastSide(
      left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
    // This join could be very slow or OOM
    joins.BroadcastNestedLoopJoinExec(
      planLater(left), planLater(right), buildSide, joinType, condition) :: Nil

  // --- Cases where this strategy does not apply ---------------------------------------------

  case _ => Nil
}
As the code shows, different conditions lead to different physical join operators: BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec, and BroadcastNestedLoopJoinExec.
Before going through these four operators, here is the general idea behind how joins are implemented.
Suppose two tables are being joined in Spark. One side is the streamed table and the other is the build table; by default the larger table is streamed and the smaller one is used as the build side. The join iterates over the streamed table and matches each row against the build table to produce the joined rows. Consider an extreme case: the large table has millions of rows and the small table only ten. It is then enough to iterate over the streamed table and probe the small (build) table for matches; whenever a match is found, a joined output row is emitted.
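That streamed/build idea can be boiled down to a few lines of plain Scala. This is only a conceptual sketch over ordinary collections (the real code works on InternalRow and HashedRelation), but it is the shape all the hash-based variants share:

// Conceptual hash join: index the small (build) side by key, stream the large side through it.
case class Row(key: Int, value: String)

def hashJoin(streamed: Iterator[Row], build: Seq[Row]): Iterator[(Row, Row)] = {
  // Build phase: one pass over the small table to create key -> rows.
  val buildMap: Map[Int, Seq[Row]] = build.groupBy(_.key)
  // Probe phase: one pass over the streamed table, looking up matches per row.
  streamed.flatMap { s =>
    buildMap.getOrElse(s.key, Seq.empty).map(b => (s, b))
  }
}

val big   = (1 to 1000000).iterator.map(i => Row(i % 10, s"stream-$i"))
val small = (0 until 10).map(i => Row(i, s"build-$i"))
hashJoin(big, small).take(3).foreach(println)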
In Spark, the join implementation is organized around these two roles, the streamed table and the build table. The following is a brief overview of the join operators listed above:
1. BroadcastHashJoinExec implements the join via broadcasting. It is chosen either when a broadcast hint is given and the right or left side can act as the build side, or when the smaller table is below the size configured by spark.sql.autoBroadcastJoinThreshold. Spark first pulls the build table's data to the driver and then distributes it to the worker nodes, so if the build table is large this step can put pressure on the driver (a short usage sketch follows this list).
2. ShuffledHashJoinExec shuffles the data and then keeps the build side in memory as a hash map. It is chosen when the left or right side can be built, its size is smaller than the broadcast threshold multiplied by the number of shuffle partitions (so that a single partition fits in local memory), the spark.sql.join.preferSortMergeJoin flag is turned off, and the build side is small enough (build side * 3 smaller than the streamed side). The idea is again to iterate over the streamed table and match against the in-memory build-side map to produce joined rows.
3. SortMergeJoinExec shuffles, sorts, and then performs a sort-based merge join. If neither of the two cases above applies (and the join keys are sortable), this is the fallback. The data is first shuffled so that rows with the same key land in the same partition, then sorted within each partition, and finally merge-joined: reading the streamed and buffered sides at the same time, and since both are ordered, a single sequential pass over each side is enough to match keys and emit joined rows.
4. BroadcastNestedLoopJoinExec mainly handles joins without join keys; it is not covered further here.
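As a quick usage sketch for point 1 (the `bigDf`/`smallDf` DataFrames and the `big`/`small` view names below are hypothetical): the broadcast side can be forced with a hint, or left to the size-based check driven by spark.sql.autoBroadcastJoinThreshold:

import org.apache.spark.sql.functions.broadcast

// Size-based broadcasting: tables whose estimated size is below this threshold are broadcast
// automatically (default 10MB); setting it to -1 disables the size-based path entirely.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10L * 1024 * 1024)

// Hint-based broadcasting in the DataFrame API: mark the small side explicitly.
val joined = bigDf.join(broadcast(smallDf), Seq("id"))

// The same hint in SQL form.
spark.sql("SELECT /*+ BROADCAST(small) */ * FROM big JOIN small ON big.id = small.id")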
The rest of this post walks through the implementation of the hash join (ShuffledHashJoinExec) and SortMergeJoinExec.
ShuffledHashJoinExec
private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
  val buildDataSize = longMetric("buildDataSize")
  val buildTime = longMetric("buildTime")
  val start = System.nanoTime()
  val context = TaskContext.get()
  val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
  buildTime += (System.nanoTime() - start) / 1000000
  buildDataSize += relation.estimatedSize
  // This relation is usually used until the end of task.
  context.addTaskCompletionListener(_ => relation.close())
  relation
}

protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
  val avgHashProbe = longMetric("avgHashProbe")
  streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
    val hashed = buildHashedRelation(buildIter)
    join(streamIter, hashed, numOutputRows, avgHashProbe)
  }
}
Look at the doExecute method above first; physical plans are generally executed through this method. Its main logic is the call to buildHashedRelation, and inside that method the part to focus on is HashedRelation:
private[execution] object HashedRelation {

  /**
   * Create a HashedRelation from an Iterator of InternalRow.
   */
  def apply(
      input: Iterator[InternalRow],
      key: Seq[Expression],
      sizeEstimate: Int = 64,
      taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
    val mm = Option(taskMemoryManager).getOrElse {
      new TaskMemoryManager(
        new StaticMemoryManager(
          new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
          Long.MaxValue,
          Long.MaxValue,
          1),
        0)
    }

    if (key.length == 1 && key.head.dataType == LongType) {
      LongHashedRelation(input, key, sizeEstimate, mm)
    } else {
      UnsafeHashedRelation(input, key, sizeEstimate, mm)
    }
  }
}
Here, if the join key is a single expression of LongType, a LongHashedRelation (backed by LongToUnsafeRowMap) is created; otherwise an UnsafeHashedRelation (backed by BytesToBytesMap). The one to focus on is UnsafeHashedRelation:
private[joins] object UnsafeHashedRelation {

  def apply(
      input: Iterator[InternalRow],
      key: Seq[Expression],
      sizeEstimate: Int,
      taskMemoryManager: TaskMemoryManager): HashedRelation = {

    val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
      .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))

    val binaryMap = new BytesToBytesMap(
      taskMemoryManager,
      // Only 70% of the slots can be used before growing, more capacity help to reduce collision
      (sizeEstimate * 1.5 + 1).toInt,
      pageSizeBytes,
      true)

    // Create a mapping of buildKeys -> rows
    val keyGenerator = UnsafeProjection.create(key)
    var numFields = 0
    while (input.hasNext) {
      val row = input.next().asInstanceOf[UnsafeRow]
      numFields = row.numFields()
      val key = keyGenerator(row)
      if (!key.anyNull) {
        val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
        val success = loc.append(
          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
          row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
        if (!success) {
          binaryMap.free()
          throw new SparkException("There is no enough memory to build hash map")
        }
      }
    }

    new UnsafeHashedRelation(numFields, binaryMap)
  }
}
As the code shows, using the buildKeys passed down from ShuffledHashJoinExec, a map from build keys to rows is constructed; this is exactly the build table mentioned earlier. With the build table ready, back in ShuffledHashJoinExec.doExecute:
protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
  val avgHashProbe = longMetric("avgHashProbe")
  streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
    val hashed = buildHashedRelation(buildIter)
    join(streamIter, hashed, numOutputRows, avgHashProbe)
  }
}
the join is carried out from streamIter (the streamed side) and hashed (the build side):
protected def join(
    streamedIter: Iterator[InternalRow],
    hashed: HashedRelation,
    numOutputRows: SQLMetric,
    avgHashProbe: SQLMetric): Iterator[InternalRow] = {

  val joinedIter = joinType match {
    case _: InnerLike =>
      innerJoin(streamedIter, hashed)
    case LeftOuter | RightOuter =>
      outerJoin(streamedIter, hashed)
    case LeftSemi =>
      semiJoin(streamedIter, hashed)
    case LeftAnti =>
      antiJoin(streamedIter, hashed)
    case j: ExistenceJoin =>
      existenceJoin(streamedIter, hashed)
    case x =>
      throw new IllegalArgumentException(
        s"BroadcastHashJoin should not take $x as the JoinType")
  }

  // At the end of the task, we update the avg hash probe.
  TaskContext.get().addTaskCompletionListener(_ =>
    avgHashProbe.set(hashed.getAverageProbesPerLookup))

  val resultProj = createResultProjection
  joinedIter.map { r =>
    numOutputRows += 1
    resultProj(r)
  }
}
Take innerJoin as an example:
private def innerJoin(
    streamIter: Iterator[InternalRow],
    hashedRelation: HashedRelation): Iterator[InternalRow] = {
  val joinRow = new JoinedRow
  val joinKeys = streamSideKeyGenerator()
  streamIter.flatMap { srow =>
    joinRow.withLeft(srow)
    val matches = hashedRelation.get(joinKeys(srow))
    if (matches != null) {
      matches.map(joinRow.withRight(_)).filter(boundCondition)
    } else {
      Seq.empty
    }
  }
}
As shown, it iterates over the streamed table, probes the build table with the join key, and if the lookup is non-null builds a joinRow for each match and filters it with the join condition. That completes the hash join; the other join types can be traced through the source in the same way.
SortMergeJoinExec
Its doExecute method is as follows:
protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
  val spillThreshold = getSpillThreshold
  val inMemoryThreshold = getInMemoryThreshold

  left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
    val boundCondition: (InternalRow) => Boolean = {
      condition.map { cond =>
        newPredicate(cond, left.output ++ right.output).eval _
      }.getOrElse {
        (r: InternalRow) => true
      }
    }

    // An ordering that can be used to compare keys from both sides.
    val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
    val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)

    joinType match {
      case _: InnerLike =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
          private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
          private[this] val smjScanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftIter),
            RowIterator.fromScala(rightIter),
            inMemoryThreshold,
            spillThreshold
          )
          private[this] val joinRow = new JoinedRow

          if (smjScanner.findNextInnerJoinRows()) {
            currentRightMatches = smjScanner.getBufferedMatches
            currentLeftRow = smjScanner.getStreamedRow
            rightMatchesIterator = currentRightMatches.generateIterator()
          }

          override def advanceNext(): Boolean = {
            while (rightMatchesIterator != null) {
              if (!rightMatchesIterator.hasNext) {
                if (smjScanner.findNextInnerJoinRows()) {
                  currentRightMatches = smjScanner.getBufferedMatches
                  currentLeftRow = smjScanner.getStreamedRow
                  rightMatchesIterator = currentRightMatches.generateIterator()
                } else {
                  currentRightMatches = null
                  currentLeftRow = null
                  rightMatchesIterator = null
                  return false
                }
              }
              joinRow(currentLeftRow, rightMatchesIterator.next())
              if (boundCondition(joinRow)) {
                numOutputRows += 1
                return true
              }
            }
            false
          }

          override def getRow: InternalRow = resultProj(joinRow)
        }.toScala

      case LeftOuter =>
        val smjScanner = new SortMergeJoinScanner(
          streamedKeyGenerator = createLeftKeyGenerator(),
          bufferedKeyGenerator = createRightKeyGenerator(),
          keyOrdering,
          streamedIter = RowIterator.fromScala(leftIter),
          bufferedIter = RowIterator.fromScala(rightIter),
          inMemoryThreshold,
          spillThreshold
        )
        val rightNullRow = new GenericInternalRow(right.output.length)
        new LeftOuterIterator(
          smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala

      case RightOuter =>
        val smjScanner = new SortMergeJoinScanner(
          streamedKeyGenerator = createRightKeyGenerator(),
          bufferedKeyGenerator = createLeftKeyGenerator(),
          keyOrdering,
          streamedIter = RowIterator.fromScala(rightIter),
          bufferedIter = RowIterator.fromScala(leftIter),
          inMemoryThreshold,
          spillThreshold
        )
        val leftNullRow = new GenericInternalRow(left.output.length)
        new RightOuterIterator(
          smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala

      case FullOuter =>
        val leftNullRow = new GenericInternalRow(left.output.length)
        val rightNullRow = new GenericInternalRow(right.output.length)
        val smjScanner = new SortMergeFullOuterJoinScanner(
          leftKeyGenerator = createLeftKeyGenerator(),
          rightKeyGenerator = createRightKeyGenerator(),
          keyOrdering,
          leftIter = RowIterator.fromScala(leftIter),
          rightIter = RowIterator.fromScala(rightIter),
          boundCondition,
          leftNullRow,
          rightNullRow)
        new FullOuterIterator(
          smjScanner, resultProj, numOutputRows).toScala

      case LeftSemi =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] val smjScanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftIter),
            RowIterator.fromScala(rightIter),
            inMemoryThreshold,
            spillThreshold
          )
          private[this] val joinRow = new JoinedRow

          override def advanceNext(): Boolean = {
            while (smjScanner.findNextInnerJoinRows()) {
              val currentRightMatches = smjScanner.getBufferedMatches
              currentLeftRow = smjScanner.getStreamedRow
              if (currentRightMatches != null && currentRightMatches.length > 0) {
                val rightMatchesIterator = currentRightMatches.generateIterator()
                while (rightMatchesIterator.hasNext) {
                  joinRow(currentLeftRow, rightMatchesIterator.next())
                  if (boundCondition(joinRow)) {
                    numOutputRows += 1
                    return true
                  }
                }
              }
            }
            false
          }

          override def getRow: InternalRow = currentLeftRow
        }.toScala

      case LeftAnti =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] val smjScanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftIter),
            RowIterator.fromScala(rightIter),
            inMemoryThreshold,
            spillThreshold
          )
          private[this] val joinRow = new JoinedRow

          override def advanceNext(): Boolean = {
            while (smjScanner.findNextOuterJoinRows()) {
              currentLeftRow = smjScanner.getStreamedRow
              val currentRightMatches = smjScanner.getBufferedMatches
              if (currentRightMatches == null || currentRightMatches.length == 0) {
                numOutputRows += 1
                return true
              }
              var found = false
              val rightMatchesIterator = currentRightMatches.generateIterator()
              while (!found && rightMatchesIterator.hasNext) {
                joinRow(currentLeftRow, rightMatchesIterator.next())
                if (boundCondition(joinRow)) {
                  found = true
                }
              }
              if (!found) {
                numOutputRows += 1
                return true
              }
            }
            false
          }

          override def getRow: InternalRow = currentLeftRow
        }.toScala

      case j: ExistenceJoin =>
        new RowIterator {
          private[this] var currentLeftRow: InternalRow = _
          private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null))
          private[this] val smjScanner = new SortMergeJoinScanner(
            createLeftKeyGenerator(),
            createRightKeyGenerator(),
            keyOrdering,
            RowIterator.fromScala(leftIter),
            RowIterator.fromScala(rightIter),
            inMemoryThreshold,
            spillThreshold
          )
          private[this] val joinRow = new JoinedRow

          override def advanceNext(): Boolean = {
            while (smjScanner.findNextOuterJoinRows()) {
              currentLeftRow = smjScanner.getStreamedRow
              val currentRightMatches = smjScanner.getBufferedMatches
              var found = false
              if (currentRightMatches != null && currentRightMatches.length > 0) {
                val rightMatchesIterator = currentRightMatches.generateIterator()
                while (!found && rightMatchesIterator.hasNext) {
                  joinRow(currentLeftRow, rightMatchesIterator.next())
                  if (boundCondition(joinRow)) {
                    found = true
                  }
                }
              }
              result.setBoolean(0, found)
              numOutputRows += 1
              return true
            }
            false
          }

          override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result))
        }.toScala

      case x =>
        throw new IllegalArgumentException(
          s"SortMergeJoin should not take $x as the JoinType")
    }
  }
}
Look first at the implementation under the InnerLike branch.
The logic is fairly simple:
A SortMergeJoinScanner is instantiated; the interesting part is the advanceNext method, which calls findNextInnerJoinRows to find the next pair of joinable rows. In particular:
1. currentLeftRow is the streamed-side row, obtained via smjScanner.getStreamedRow.
2. currentRightMatches holds the buffered (build-side) matches, obtained via smjScanner.getBufferedMatches.
3. advanceNext is essentially driven by findNextInnerJoinRows: if it returns true there is a new match, so the two values above are refreshed, a joinRow is built, and the join condition is applied as a filter.
4. findNextInnerJoinRows:
final def findNextInnerJoinRows(): Boolean = {
  while (advancedStreamed() && streamedRowKey.anyNull) {
    // Advance the streamed side of the join until we find the next row whose join key contains
    // no nulls or we hit the end of the streamed iterator.
  }
  if (streamedRow == null) {
    // We have consumed the entire streamed iterator, so there can be no more matches.
    matchJoinKey = null
    bufferedMatches.clear()
    false
  } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
    // The new streamed row has the same join key as the previous row, so return the same matches.
    true
  } else if (bufferedRow == null) {
    // The streamed row's join key does not match the current batch of buffered rows and there are
    // no more rows to read from the buffered iterator, so there can be no more matches.
    matchJoinKey = null
    bufferedMatches.clear()
    false
  } else {
    // Advance both the streamed and buffered iterators to find the next pair of matching rows.
    var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
    do {
      if (streamedRowKey.anyNull) {
        advancedStreamed()
      } else {
        assert(!bufferedRowKey.anyNull)
        comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
        if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
        else if (comp < 0) advancedStreamed()
      }
    } while (streamedRow != null && bufferedRow != null && comp != 0)
    if (streamedRow == null || bufferedRow == null) {
      // We have either hit the end of one of the iterators, so there can be no more matches.
      matchJoinKey = null
      bufferedMatches.clear()
      false
    } else {
      // The streamed row's join key matches the current buffered row's join, so walk through the
      // buffered iterator to buffer the rest of the matching rows.
      assert(comp == 0)
      bufferMatchingRows()
      true
    }
  }
}
Its main logic:
If the streamed side has been exhausted, return false immediately.
If the current streamed row's key equals the cached matchJoinKey, return true right away and reuse the previous matches.
If the buffered side is exhausted, return false.
Otherwise the real work happens in the do-while loop: after the null-key checks, the streamed and buffered keys are compared; if the comparison is greater than zero the buffered side is advanced, if it is less than zero the streamed side is advanced, and this repeats until the two keys are equal or one side runs out. Once the keys match, all buffered rows sharing that key are appended into bufferedMatches (via bufferMatchingRows).
bufferMatchingRows also records matchJoinKey, so on the next call to findNextInnerJoinRows, a streamed row with the same key returns true immediately and the join proceeds against the already-buffered matches.
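The gist of findNextInnerJoinRows can likewise be captured in a small, self-contained sketch over two key-sorted sequences. It only mirrors the compare/advance/buffer-the-matching-run idea; the real scanner additionally handles null keys, spilling and the matchJoinKey cache:

// Conceptual sort-merge inner join over two sequences already sorted by key.
def sortMergeJoin(
    streamed: Seq[(Int, String)],
    buffered: Seq[(Int, String)]): Seq[(String, String)] = {
  val out = scala.collection.mutable.ArrayBuffer[(String, String)]()
  var i = 0
  var j = 0
  while (i < streamed.length && j < buffered.length) {
    val cmp = Integer.compare(streamed(i)._1, buffered(j)._1)
    if (cmp < 0) i += 1        // streamed key is smaller: advance the streamed side
    else if (cmp > 0) j += 1   // buffered key is smaller: advance the buffered side
    else {
      // Keys match: collect the whole run of buffered rows with this key
      // (the analogue of bufferMatchingRows) ...
      val key = streamed(i)._1
      val start = j
      while (j < buffered.length && buffered(j)._1 == key) j += 1
      val matches = buffered.slice(start, j)
      // ... and replay it for every streamed row sharing the key (the matchJoinKey reuse).
      while (i < streamed.length && streamed(i)._1 == key) {
        matches.foreach(m => out += ((streamed(i)._2, m._2)))
        i += 1
      }
    }
  }
  out.toSeq
}

// Joins on key 2 only: (s2,b2), (s2,b2b), (s2b,b2), (s2b,b2b)
println(sortMergeJoin(
  Seq(1 -> "s1", 2 -> "s2", 2 -> "s2b", 4 -> "s4"),
  Seq(2 -> "b2", 2 -> "b2b", 3 -> "b3")))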
LeftOuter and RightOuter are implemented by LeftOuterIterator and RightOuterIterator, both concrete subclasses of OneSideOuterIterator. They rely on SortMergeJoinScanner.findNextOuterJoinRows to line up the streamed and buffered keys and handle the result accordingly; each mainly implements setBufferedSideOutput and setStreamSideOutput, with the rest of the logic living in advanceStream.
FullOuter is implemented by FullOuterIterator:
private class FullOuterIterator(
    smjScanner: SortMergeFullOuterJoinScanner,
    resultProj: InternalRow => InternalRow,
    numRows: SQLMetric) extends RowIterator {
  private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()

  override def advanceNext(): Boolean = {
    val r = smjScanner.advanceNext()
    if (r) numRows += 1
    r
  }

  override def getRow: InternalRow = resultProj(joinedRow)
}
Seen this way, the FullOuter wrapper is actually the simplest of them all.
Because what each branch returns is an iterator, the thing to focus on when reading the source is the advanceNext implementation; from there the whole join process can be traced.
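To make that advanceNext/getRow contract concrete, here is a tiny self-contained imitation of the pattern (not Spark's RowIterator, just the same shape): advanceNext moves to the next valid row and reports whether one exists, getRow exposes it, and an ordinary Scala Iterator is layered on top, which is roughly what .toScala does in the real code.

// A toy "row iterator" in the same shape as Spark's RowIterator: advanceNext() + getRow.
trait ToyRowIterator[T] {
  def advanceNext(): Boolean
  def getRow: T
  // Adapter similar in spirit to RowIterator.toScala.
  def toScala: Iterator[T] = new Iterator[T] {
    private var advanced = false
    private var valid = false
    override def hasNext: Boolean = {
      if (!advanced) { valid = advanceNext(); advanced = true }
      valid
    }
    override def next(): T = {
      if (!hasNext) throw new NoSuchElementException
      advanced = false
      getRow
    }
  }
}

// Example: only emit even numbers from an underlying iterator.
class EvenIterator(underlying: Iterator[Int]) extends ToyRowIterator[Int] {
  private var current: Int = _
  override def advanceNext(): Boolean = {
    while (underlying.hasNext) {
      current = underlying.next()
      if (current % 2 == 0) return true
    }
    false
  }
  override def getRow: Int = current
}

println(new EvenIterator(Iterator(1, 2, 3, 4, 5, 6)).toScala.toList) // List(2, 4, 6)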
To sum up, this post has only sketched the ideas behind Spark's join implementations. The finer details still require reading the code. For example, in SortMergeJoinExec, what is spilling based on? The answer lies in the ExternalAppendOnlyUnsafeRowArray used by SortMergeJoinScanner, which relies on UnsafeExternalSorter to spill rows to disk.
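If you need to tune when that buffer spills, two SQL configs control it. The values below are only illustrative, and the exact keys come from the Spark version this post is based on, so double-check them against your release:

// Number of rows ExternalAppendOnlyUnsafeRowArray keeps in a plain in-memory buffer
// before switching to an UnsafeExternalSorter-backed buffer.
spark.conf.set("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold", 4096)

// Number of rows the UnsafeExternalSorter-backed buffer holds before spilling to disk.
spark.conf.set("spark.sql.sortMergeJoinExec.buffer.spill.threshold", 2 * 1024 * 1024)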