Spark SQL (9): A Source-Level Summary of Spark SQL JOIN Operations

This post summarizes how Spark SQL implements join operations, following the source code. The walkthrough covers the usual pipeline: from the SQL statement to the logical operator tree, through the analyzed and optimized plans, and finally to the physical plan and its execution logic.

The Join logical operator tree

Start with a sample SQL statement:

SELECT NAME.NAME FROM NAME
LEFT JOIN NAME2 ON NAME.NAME = NAME2.NAME
JOIN NAME3 ON NAME.NAME = NAME3.NAME

This SQL parses into the following logical operator tree:
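Since the original figure did not survive, here is a rough textual sketch of the parsed (unresolved) plan for the SQL above, in the style Spark uses when printing plans; treat it as a reconstruction, not the author's figure:

'Project ['NAME.NAME]
+- 'Join Inner, ('NAME.NAME = 'NAME3.NAME)
   :- 'Join LeftOuter, ('NAME.NAME = 'NAME2.NAME)
   :  :- 'UnresolvedRelation `NAME`
   :  +- 'UnresolvedRelation `NAME2`
   +- 'UnresolvedRelation `NAME3`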

As for how this tree is built, we only need to focus on the join-related handling; the source is in AstBuilder:

override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
   val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
     val right = plan(relation.relationPrimary)
     val join = right.optionalMap(left)(Join(_, _, Inner, None))
     withJoinRelations(join, relation)
   }
   ctx.lateralView.asScala.foldLeft(from)(withGenerate)
 }

  

private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
  ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
    withOrigin(join) {
      val baseJoinType = join.joinType match {
        case null => Inner
        case jt if jt.CROSS != null => Cross
        case jt if jt.FULL != null => FullOuter
        case jt if jt.SEMI != null => LeftSemi
        case jt if jt.ANTI != null => LeftAnti
        case jt if jt.LEFT != null => LeftOuter
        case jt if jt.RIGHT != null => RightOuter
        case _ => Inner
      }
 
      // Resolve the join type and join condition
      val (joinType, condition) = Option(join.joinCriteria) match {
        case Some(c) if c.USING != null =>
          (UsingJoin(baseJoinType, c.identifier.asScala.map(_.getText)), None)
        case Some(c) if c.booleanExpression != null =>
          (baseJoinType, Option(expression(c.booleanExpression)))
        case None if join.NATURAL != null =>
          if (baseJoinType == Cross) {
            throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
          }
          (NaturalJoin(baseJoinType), None)
        case None =>
          (baseJoinType, None)
      }
      Join(left, plan(join.right), joinType, condition)
    }
  }
}

From the tree structure above we can see that, for joins, the parsed result stores the join relations as a List[JoinRelation]; each JoinRelation carries a JoinType, a relationPrimary, and a joinCriteria, where the joinCriteria is essentially a booleanExpression.

Next come the Analyzed and Optimized phases. For joins, analysis mainly adds subquery aliases and similar bookkeeping; optimization then pushes operators down, eliminates subquery aliases, and so on. Many rules are involved here; interested readers can dig into the source.

The physical planning phase

This phase applies the strategies configured in SparkPlanner. Among them, JoinSelection is the one to focus on; its apply method is as follows:

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
    // --- BroadcastHashJoin --------------------------------------------------------------------
 
    // broadcast hints were specified
    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
      if canBroadcastByHints(joinType, left, right) =>
      val buildSide = broadcastSideByHints(joinType, left, right)
      Seq(joins.BroadcastHashJoinExec(
        leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
 
    // broadcast hints were not specified, so need to infer it from size and configuration.
    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
      if canBroadcastBySizes(joinType, left, right) =>
      val buildSide = broadcastSideBySizes(joinType, left, right)
      Seq(joins.BroadcastHashJoinExec(
        leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
 
    // --- ShuffledHashJoin ---------------------------------------------------------------------
 
    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
       if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
         && muchSmaller(right, left) ||
         !RowOrdering.isOrderable(leftKeys) =>
      Seq(joins.ShuffledHashJoinExec(
        leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
 
    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
       if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
         && muchSmaller(left, right) ||
         !RowOrdering.isOrderable(leftKeys) =>
      Seq(joins.ShuffledHashJoinExec(
        leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
 
    // --- SortMergeJoin ------------------------------------------------------------
 
    case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
      if RowOrdering.isOrderable(leftKeys) =>
      joins.SortMergeJoinExec(
        leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
 
    // --- Without joining keys ------------------------------------------------------------
 
    // Pick BroadcastNestedLoopJoin if one side could be broadcast
    case j @ logical.Join(left, right, joinType, condition)
        if canBroadcastByHints(joinType, left, right) =>
      val buildSide = broadcastSideByHints(joinType, left, right)
      joins.BroadcastNestedLoopJoinExec(
        planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
 
    case j @ logical.Join(left, right, joinType, condition)
        if canBroadcastBySizes(joinType, left, right) =>
      val buildSide = broadcastSideBySizes(joinType, left, right)
      joins.BroadcastNestedLoopJoinExec(
        planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
 
    // Pick CartesianProduct for InnerJoin
    case logical.Join(left, right, _: InnerLike, condition) =>
      joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
 
    case logical.Join(left, right, joinType, condition) =>
      val buildSide = broadcastSide(
        left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
      // This join could be very slow or OOM
      joins.BroadcastNestedLoopJoinExec(
        planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
 
    // --- Cases where this strategy does not apply ---------------------------------------------
 
    case _ => Nil
  }

From this code we can see that different physical join operators are generated depending on the conditions: BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec, and BroadcastNestedLoopJoinExec (plus CartesianProductExec for inner joins without broadcastable sides or equi-join keys).
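A quick way to observe which operator JoinSelection picks is to inspect the physical plan with explain(). A minimal sketch (the app name, table sizes, and threshold value are illustrative assumptions, not from the original post):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

val spark = SparkSession.builder()
  .appName("join-selection-demo")   // assumed app name
  .master("local[*]")
  .config("spark.sql.autoBroadcastJoinThreshold", 10L * 1024 * 1024) // 10 MB
  .getOrCreate()

val big   = spark.range(1000000L).toDF("id") // plays the streamed (large) side
val small = spark.range(10L).toDF("id")      // plays the build (small) side

// An explicit hint steers the planner toward BroadcastHashJoinExec.
big.join(broadcast(small), "id").explain()

// With the hint removed and the threshold disabled (-1), the planner falls
// back to SortMergeJoinExec for this equi-join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
big.join(small, "id").explain()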

Before describing these four operators, here is the general idea behind the join implementation.

Suppose we join two tables in Spark: one serves as the streamed table, the other as the build table. By default the large table is streamed and the small table is built. Spark iterates over the streamed table and, for each row, probes the build table for matching rows, producing the joined output. Picture an extreme case: the large table has millions of rows while the small one has only 10. Then a single pass over the streamed table suffices; for each row we look up matches in the small (build) table and emit the joined rows.
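To make the streamed/build split concrete, here is a minimal plain-Scala sketch of the probe idea (the Row type and hashJoin helper are illustrative, not Spark internals):

// Illustrative only: a toy row type and hash join, not Spark's implementation.
case class Row(key: Int, value: String)

def hashJoin(streamed: Iterator[Row], build: Seq[Row]): Iterator[(Row, Row)] = {
  // Build phase: index the small side by join key.
  val buildMap: Map[Int, Seq[Row]] = build.groupBy(_.key)
  // Probe phase: one pass over the large side, emitting a pair per match.
  streamed.flatMap(s => buildMap.getOrElse(s.key, Seq.empty).map(b => (s, b)))
}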

Spark's join implementations are organized around these two roles, the streamed side and the build side. Now a brief look at each of the operators above:

1. BroadcastHashJoinExec implements the join via broadcast. It is selected either when a broadcast hint is present and the corresponding side (left or right) can serve as the build side, or when one side's estimated size is below the spark.sql.autoBroadcastJoinThreshold setting. Spark first pulls the build side's data to the driver and then distributes it to the executors, so if the build side is large this step can put real pressure on the driver.

2. ShuffledHashJoinExec shuffles both sides and keeps the build side in an in-memory hash map. It is selected when: one side can serve as the build side; that side's estimated size is below the number of shuffle partitions times the broadcast threshold (so each partition's map fits in local memory); spark.sql.join.preferSortMergeJoin is set to false; and the build side is sufficiently small (three times its size is still no larger than the streamed side; see the helper predicates quoted after this list). It is also chosen when the join keys are not orderable, so sort-merge is impossible. Execution then iterates the streamed side and matches against the in-memory build map to produce joined rows.

3. SortMergeJoinExec shuffles, sorts, and then performs a merge-based join. It is the fallback when neither hash-based join applies, provided the join keys are orderable. Shuffling partitions the data so equal keys land in the same partition; sorting puts each partition in key order; the merge phase then reads the streamed and buffered sides in lockstep. Because both sides are sorted, a single sequential traversal of each suffices to match rows and emit the joined output.

4. BroadcastNestedLoopJoinExec mainly covers joins without equi-join keys; it is not examined further here.
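For reference, the size-based predicates that the ShuffledHashJoin cases above rely on look like this in the Spark 2.x JoinSelection strategy (quoted from memory; verify against your Spark version):

// Can one partition's worth of this plan fit in a local hash map?
private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = {
  plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
}

// "Much smaller": the build side is at most one third of the streamed side.
private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
  a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes
}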

Next, the implementation logic of the hash join and of SortMergeJoinExec.

      ShuffledHashJoinExec

      

private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
  val buildDataSize = longMetric("buildDataSize")
  val buildTime = longMetric("buildTime")
  val start = System.nanoTime()
  val context = TaskContext.get()
  val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
  buildTime += (System.nanoTime() - start) / 1000000
  buildDataSize += relation.estimatedSize
  // This relation is usually used until the end of task.
  context.addTaskCompletionListener(_ => relation.close())
  relation
}
 
protected override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
  val avgHashProbe = longMetric("avgHashProbe")
  streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
    val hashed = buildHashedRelation(buildIter)
    join(streamIter, hashed, numOutputRows, avgHashProbe)
  }
}

Look at the doExecute method above first; a physical plan is generally executed by triggering this method. Its main logic calls buildHashedRelation, and within that method the piece to focus on is HashedRelation:

private[execution] object HashedRelation {
 
  /**
   * Create a HashedRelation from an Iterator of InternalRow.
   */
  def apply(
      input: Iterator[InternalRow],
      key: Seq[Expression],
      sizeEstimate: Int = 64,
      taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
    val mm = Option(taskMemoryManager).getOrElse {
      new TaskMemoryManager(
        new StaticMemoryManager(
          new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
          Long.MaxValue,
          Long.MaxValue,
          1),
        0)
    }
 
    if (key.length == 1 && key.head.dataType == LongType) {
      LongHashedRelation(input, key, sizeEstimate, mm)
    } else {
      UnsafeHashedRelation(input, key, sizeEstimate, mm)
    }
  }
}

Here, if the join key is a single expression of LongType, a LongHashedRelation is created (backed by LongToUnsafeRowMap); otherwise an UnsafeHashedRelation (backed by BytesToBytesMap). We focus on UnsafeHashedRelation:

private[joins] object UnsafeHashedRelation {
 
  def apply(
      input: Iterator[InternalRow],
      key: Seq[Expression],
      sizeEstimate: Int,
      taskMemoryManager: TaskMemoryManager): HashedRelation = {
 
    val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
      .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
 
    val binaryMap = new BytesToBytesMap(
      taskMemoryManager,
      // Only 70% of the slots can be used before growing, more capacity help to reduce collision
      (sizeEstimate * 1.5 + 1).toInt,
      pageSizeBytes,
      true)
 
    // Create a mapping of buildKeys -> rows
    val keyGenerator = UnsafeProjection.create(key)
    var numFields = 0
    while (input.hasNext) {
      val row = input.next().asInstanceOf[UnsafeRow]
      numFields = row.numFields()
      val key = keyGenerator(row)
      if (!key.anyNull) {
        val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
        val success = loc.append(
          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
          row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
        if (!success) {
          binaryMap.free()
          throw new SparkException("There is no enough memory to build hash map")
        }
      }
    }
 
    new UnsafeHashedRelation(numFields, binaryMap)
  }
}

As this code shows, a mapping from buildKeys to rows is constructed using the buildKeys passed in from ShuffledHashJoinExec; this map is exactly the build table described earlier. With the build table ready, back in ShuffledHashJoinExec.doExecute:

protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    val avgHashProbe = longMetric("avgHashProbe")
    streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
      val hashed = buildHashedRelation(buildIter)
      join(streamIter, hashed, numOutputRows, avgHashProbe)
    }
  }

Here streamIter (the streamed side) and hashed (the build table) are handed to the join method:

protected def join(
    streamedIter: Iterator[InternalRow],
    hashed: HashedRelation,
    numOutputRows: SQLMetric,
    avgHashProbe: SQLMetric): Iterator[InternalRow] = {
 
  val joinedIter = joinType match {
    case _: InnerLike =>
      innerJoin(streamedIter, hashed)
    case LeftOuter | RightOuter =>
      outerJoin(streamedIter, hashed)
    case LeftSemi =>
      semiJoin(streamedIter, hashed)
    case LeftAnti =>
      antiJoin(streamedIter, hashed)
    case j: ExistenceJoin =>
      existenceJoin(streamedIter, hashed)
    case x =>
      throw new IllegalArgumentException(
        s"BroadcastHashJoin should not take $x as the JoinType")
  }
 
  // At the end of the task, we update the avg hash probe.
  TaskContext.get().addTaskCompletionListener(_ =>
    avgHashProbe.set(hashed.getAverageProbesPerLookup))
 
  val resultProj = createResultProjection
  joinedIter.map { r =>
    numOutputRows += 1
    resultProj(r)
  }
}

Now look at the innerJoin case:

private def innerJoin(
     streamIter: Iterator[InternalRow],
     hashedRelation: HashedRelation): Iterator[InternalRow] = {
   val joinRow = new JoinedRow
   val joinKeys = streamSideKeyGenerator()
   streamIter.flatMap { srow =>
     joinRow.withLeft(srow)
     val matches = hashedRelation.get(joinKeys(srow))
     if (matches != null) {
       matches.map(joinRow.withRight(_)).filter(boundCondition)
     } else {
       Seq.empty
     }
   }
 }

As shown, the streamed side is traversed; for each row we look up matching rows by key in the build table, and if any exist we build a joinRow and filter it with the bound join condition. That completes the hash join. The other join types can be traced through the code in the same way.

     SortMergeJoinExec

Its doExecute method is as follows:

protected override def doExecute(): RDD[InternalRow] = {
   val numOutputRows = longMetric("numOutputRows")
   val spillThreshold = getSpillThreshold
   val inMemoryThreshold = getInMemoryThreshold
   left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
     val boundCondition: (InternalRow) => Boolean = {
       condition.map { cond =>
         newPredicate(cond, left.output ++ right.output).eval _
       }.getOrElse {
         (r: InternalRow) => true
       }
     }
 
     // An ordering that can be used to compare keys from both sides.
     val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
     val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
 
     joinType match {
       case _: InnerLike =>
         new RowIterator {
           private[this] var currentLeftRow: InternalRow = _
           private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
           private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
           private[this] val smjScanner = new SortMergeJoinScanner(
             createLeftKeyGenerator(),
             createRightKeyGenerator(),
             keyOrdering,
             RowIterator.fromScala(leftIter),
             RowIterator.fromScala(rightIter),
             inMemoryThreshold,
             spillThreshold
           )
           private[this] val joinRow = new JoinedRow
 
           if (smjScanner.findNextInnerJoinRows()) {
             currentRightMatches = smjScanner.getBufferedMatches
             currentLeftRow = smjScanner.getStreamedRow
             rightMatchesIterator = currentRightMatches.generateIterator()
           }
 
           override def advanceNext(): Boolean = {
             while (rightMatchesIterator != null) {
               if (!rightMatchesIterator.hasNext) {
                 if (smjScanner.findNextInnerJoinRows()) {
                   currentRightMatches = smjScanner.getBufferedMatches
                   currentLeftRow = smjScanner.getStreamedRow
                   rightMatchesIterator = currentRightMatches.generateIterator()
                 } else {
                   currentRightMatches = null
                   currentLeftRow = null
                   rightMatchesIterator = null
                   return false
                 }
               }
               joinRow(currentLeftRow, rightMatchesIterator.next())
               if (boundCondition(joinRow)) {
                 numOutputRows += 1
                 return true
               }
             }
             false
           }
 
           override def getRow: InternalRow = resultProj(joinRow)
         }.toScala
 
       case LeftOuter =>
         val smjScanner = new SortMergeJoinScanner(
           streamedKeyGenerator = createLeftKeyGenerator(),
           bufferedKeyGenerator = createRightKeyGenerator(),
           keyOrdering,
           streamedIter = RowIterator.fromScala(leftIter),
           bufferedIter = RowIterator.fromScala(rightIter),
           inMemoryThreshold,
           spillThreshold
         )
         val rightNullRow = new GenericInternalRow(right.output.length)
         new LeftOuterIterator(
           smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala
 
       case RightOuter =>
         val smjScanner = new SortMergeJoinScanner(
           streamedKeyGenerator = createRightKeyGenerator(),
           bufferedKeyGenerator = createLeftKeyGenerator(),
           keyOrdering,
           streamedIter = RowIterator.fromScala(rightIter),
           bufferedIter = RowIterator.fromScala(leftIter),
           inMemoryThreshold,
           spillThreshold
         )
         val leftNullRow = new GenericInternalRow(left.output.length)
         new RightOuterIterator(
           smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
 
       case FullOuter =>
         val leftNullRow = new GenericInternalRow(left.output.length)
         val rightNullRow = new GenericInternalRow(right.output.length)
         val smjScanner = new SortMergeFullOuterJoinScanner(
           leftKeyGenerator = createLeftKeyGenerator(),
           rightKeyGenerator = createRightKeyGenerator(),
           keyOrdering,
           leftIter = RowIterator.fromScala(leftIter),
           rightIter = RowIterator.fromScala(rightIter),
           boundCondition,
           leftNullRow,
           rightNullRow)
 
         new FullOuterIterator(
           smjScanner,
           resultProj,
           numOutputRows).toScala
 
       case LeftSemi =>
         new RowIterator {
           private[this] var currentLeftRow: InternalRow = _
           private[this] val smjScanner = new SortMergeJoinScanner(
             createLeftKeyGenerator(),
             createRightKeyGenerator(),
             keyOrdering,
             RowIterator.fromScala(leftIter),
             RowIterator.fromScala(rightIter),
             inMemoryThreshold,
             spillThreshold
           )
           private[this] val joinRow = new JoinedRow
 
           override def advanceNext(): Boolean = {
             while (smjScanner.findNextInnerJoinRows()) {
               val currentRightMatches = smjScanner.getBufferedMatches
               currentLeftRow = smjScanner.getStreamedRow
               if (currentRightMatches != null && currentRightMatches.length > 0) {
                 val rightMatchesIterator = currentRightMatches.generateIterator()
                 while (rightMatchesIterator.hasNext) {
                   joinRow(currentLeftRow, rightMatchesIterator.next())
                   if (boundCondition(joinRow)) {
                     numOutputRows += 1
                     return true
                   }
                 }
               }
             }
             false
           }
 
           override def getRow: InternalRow = currentLeftRow
         }.toScala
 
       case LeftAnti =>
         new RowIterator {
           private[this] var currentLeftRow: InternalRow = _
           private[this] val smjScanner = new SortMergeJoinScanner(
             createLeftKeyGenerator(),
             createRightKeyGenerator(),
             keyOrdering,
             RowIterator.fromScala(leftIter),
             RowIterator.fromScala(rightIter),
             inMemoryThreshold,
             spillThreshold
           )
           private[this] val joinRow = new JoinedRow
 
           override def advanceNext(): Boolean = {
             while (smjScanner.findNextOuterJoinRows()) {
               currentLeftRow = smjScanner.getStreamedRow
               val currentRightMatches = smjScanner.getBufferedMatches
               if (currentRightMatches == null || currentRightMatches.length == 0) {
                 numOutputRows += 1
                 return true
               }
               var found = false
               val rightMatchesIterator = currentRightMatches.generateIterator()
               while (!found && rightMatchesIterator.hasNext) {
                 joinRow(currentLeftRow, rightMatchesIterator.next())
                 if (boundCondition(joinRow)) {
                   found = true
                 }
               }
               if (!found) {
                 numOutputRows += 1
                 return true
               }
             }
             false
           }
 
           override def getRow: InternalRow = currentLeftRow
         }.toScala
 
       case j: ExistenceJoin =>
         new RowIterator {
           private[this] var currentLeftRow: InternalRow = _
           private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null))
           private[this] val smjScanner = new SortMergeJoinScanner(
             createLeftKeyGenerator(),
             createRightKeyGenerator(),
             keyOrdering,
             RowIterator.fromScala(leftIter),
             RowIterator.fromScala(rightIter),
             inMemoryThreshold,
             spillThreshold
           )
           private[this] val joinRow = new JoinedRow
 
           override def advanceNext(): Boolean = {
             while (smjScanner.findNextOuterJoinRows()) {
               currentLeftRow = smjScanner.getStreamedRow
               val currentRightMatches = smjScanner.getBufferedMatches
               var found = false
               if (currentRightMatches != null && currentRightMatches.length > 0) {
                 val rightMatchesIterator = currentRightMatches.generateIterator()
                 while (!found && rightMatchesIterator.hasNext) {
                   joinRow(currentLeftRow, rightMatchesIterator.next())
                   if (boundCondition(joinRow)) {
                     found = true
                   }
                 }
               }
               result.setBoolean(0, found)
               numOutputRows += 1
               return true
             }
             false
           }
 
           override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result))
         }.toScala
 
       case x =>
         throw new IllegalArgumentException(
           s"SortMergeJoin should not take $x as the JoinType")
     }
 
   }
 }

First, look at the InnerLike branch. The logic is straightforward: a SortMergeJoinScanner is instantiated, and the interesting part is the advanceNext implementation, which calls findNextInnerJoinRows to locate the next batch of joinable rows. In this branch:

1. currentLeftRow holds the streamed-side row, obtained via smjScanner.getStreamedRow.

2. currentRightMatches holds the buffered-side matches, obtained via smjScanner.getBufferedMatches.

3. advanceNext mostly delegates to findNextInnerJoinRows; when it returns true there are new rows, so the values from 1 and 2 are refreshed, a joinRow is built, and the filter condition is applied.

4. findNextInnerJoinRows:

final def findNextInnerJoinRows(): Boolean = {
   while (advancedStreamed() && streamedRowKey.anyNull) {
     // Advance the streamed side of the join until we find the next row whose join key contains
     // no nulls or we hit the end of the streamed iterator.
   }
   if (streamedRow == null) {
     // We have consumed the entire streamed iterator, so there can be no more matches.
     matchJoinKey = null
     bufferedMatches.clear()
     false
   } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
     // The new streamed row has the same join key as the previous row, so return the same matches.
     true
   } else if (bufferedRow == null) {
     // The streamed row's join key does not match the current batch of buffered rows and there are
     // no more rows to read from the buffered iterator, so there can be no more matches.
     matchJoinKey = null
     bufferedMatches.clear()
     false
   } else {
     // Advance both the streamed and buffered iterators to find the next pair of matching rows.
     var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
     do {
       if (streamedRowKey.anyNull) {
         advancedStreamed()
       } else {
         assert(!bufferedRowKey.anyNull)
         comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
         if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
         else if (comp < 0) advancedStreamed()
       }
     } while (streamedRow != null && bufferedRow != null && comp != 0)
     if (streamedRow == null || bufferedRow == null) {
       // We have either hit the end of one of the iterators, so there can be no more matches.
       matchJoinKey = null
       bufferedMatches.clear()
       false
     } else {
       // The streamed row's join key matches the current buffered row's join, so walk through the
       // buffered iterator to buffer the rest of the matching rows.
       assert(comp == 0)
       bufferMatchingRows()
       true
     }
   }
 }

The main logic:

- If the streamed side is exhausted, return false.
- If the current streamed row's key equals the cached matchJoinKey, return true immediately and reuse the buffered matches.
- If the buffered side is exhausted, return false.
- Otherwise, the do-while loop compares the streamed key with the buffered key: if the comparison is greater than zero the buffered side is advanced, if less than zero the streamed side is advanced, looping until the two keys are equal or one side runs out. Once the keys match, bufferMatchingRows appends every buffered row sharing that key into bufferedMatches.
- bufferMatchingRows also records matchJoinKey, so a later call to findNextInnerJoinRows that sees the same streamed key can return true directly and join against the already-buffered rows.
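The core of this key-comparison dance can be summarized as a toy plain-Scala merge join over two already key-sorted lists (illustrative only; it ignores null keys, spilling, and the matchJoinKey cache):

// Illustrative merge join over sorted (key, value) lists; not Spark's code.
def mergeJoin(left: List[(Int, String)],
              right: List[(Int, String)]): List[(String, String)] =
  (left, right) match {
    case (l :: ls, r :: _) if l._1 < r._1 => mergeJoin(ls, right)  // advance streamed side
    case (l :: _, r :: rs) if l._1 > r._1 => mergeJoin(left, rs)   // advance buffered side
    case (l :: ls, _ :: _) =>
      // Keys equal: pair the streamed row with every buffered row sharing the
      // key (the analogue of bufferMatchingRows), then advance only the
      // streamed side so a duplicate streamed key re-matches the same rows.
      right.takeWhile(_._1 == l._1).map(m => (l._2, m._2)) ::: mergeJoin(ls, right)
    case _ => Nil
  }

// Example: mergeJoin(List(1 -> "a", 2 -> "b"), List(1 -> "x", 1 -> "y"))
// returns List(("a", "x"), ("a", "y")).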

LeftOuter and RightOuter are implemented by LeftOuterIterator and RightOuterIterator, both concrete subclasses of OneSideOuterIterator. They rely on SortMergeJoinScanner.findNextOuterJoinRows to line up the streamed and buffered keys, and each mainly implements the setBufferedSideOutput and setStreamSideOutput methods; the remaining logic lives in advanceStream.
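Extending the toy merge join above with the outer-side padding idea (again illustrative; the None stands in for the rightNullRow that LeftOuterIterator emits):

// Illustrative left outer merge join: unmatched streamed rows are padded with None.
def leftOuterMergeJoin(left: List[(Int, String)],
                       right: List[(Int, String)]): List[(String, Option[String])] =
  (left, right) match {
    case (l :: ls, r :: _) if l._1 < r._1 =>
      (l._2, None) :: leftOuterMergeJoin(ls, right)   // no match: pad with the "null row"
    case (l :: _, r :: rs) if l._1 > r._1 =>
      leftOuterMergeJoin(left, rs)                    // advance buffered side
    case (l :: ls, _ :: _) =>
      right.takeWhile(_._1 == l._1).map(m => (l._2, Some(m._2))) :::
        leftOuterMergeJoin(ls, right)
    case (l :: ls, Nil) =>
      (l._2, None) :: leftOuterMergeJoin(ls, Nil)     // buffered side exhausted
    case (Nil, _) => Nil
  }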

FullOuter is implemented mainly by FullOuterIterator:

private class FullOuterIterator(
    smjScanner: SortMergeFullOuterJoinScanner,
    resultProj: InternalRow => InternalRow,
    numRows: SQLMetric) extends RowIterator {
  private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
 
  override def advanceNext(): Boolean = {
    val r = smjScanner.advanceNext()
    if (r) numRows += 1
    r
  }
 
  override def getRow: InternalRow = resultProj(joinedRow)
}

Seen this way, the FullOuter wrapper is actually the simplest, since all the real work happens inside SortMergeFullOuterJoinScanner.advanceNext.

Because these operators all return an iterator, the key when reading the source is the advanceNext implementation of each case; from it you can trace the entire join.

To summarize: this post has sketched the ideas behind Spark's join implementation. The concrete details still require reading the code. For example, what drives spilling in SortMergeJoinExec? It is the ExternalAppendOnlyUnsafeRowArray used by SortMergeJoinScanner, which relies on UnsafeExternalSorter to implement the spill-to-disk behavior.
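Relatedly, the inMemoryThreshold and spillThreshold values read at the top of doExecute come from two SQL configs. To the best of my recollection the Spark 2.3 keys are the ones below (the values shown are illustrative, and the keys should be verified against SQLConf in your version), reusing the SparkSession from the earlier sketch:

// Assumed Spark 2.3 config keys; the values here are illustrative, not defaults.
spark.conf.set("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold", 2048)
spark.conf.set("spark.sql.sortMergeJoinExec.buffer.spill.threshold", 1024 * 1024)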
