-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-35350][SQL] Add code-gen for left semi sort merge join #32528
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -354,7 +354,7 @@ case class SortMergeJoinExec( | |
| } | ||
|
|
||
| private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match { | ||
| case _: InnerLike | LeftOuter => ((left, leftKeys), (right, rightKeys)) | ||
| case _: InnerLike | LeftOuter | LeftSemi => ((left, leftKeys), (right, rightKeys)) | ||
| case RightOuter => ((right, rightKeys), (left, leftKeys)) | ||
| case x => | ||
| throw new IllegalArgumentException( | ||
|
|
@@ -365,7 +365,7 @@ case class SortMergeJoinExec( | |
| private lazy val bufferedOutput = bufferedPlan.output | ||
|
|
||
| override def supportCodegen: Boolean = joinType match { | ||
| case _: InnerLike | LeftOuter | RightOuter => true | ||
| case _: InnerLike | LeftOuter | RightOuter | LeftSemi => true | ||
| case _ => false | ||
| } | ||
|
|
||
|
|
@@ -424,8 +424,18 @@ case class SortMergeJoinExec( | |
| // A list to hold all matched rows from buffered side. | ||
| val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName | ||
|
|
||
| // Flag to only buffer first matched row, to avoid buffering unnecessary rows. | ||
| val onlyBufferFirstMatchedRow = (joinType, condition) match { | ||
| case (LeftSemi, None) => true | ||
| case _ => false | ||
| } | ||
| val inMemoryThreshold = | ||
| if (onlyBufferFirstMatchedRow) { | ||
| 1 | ||
| } else { | ||
| getInMemoryThreshold | ||
| } | ||
| val spillThreshold = getSpillThreshold | ||
| val inMemoryThreshold = getInMemoryThreshold | ||
|
|
||
| // Inline mutable state since not many join operations in a task | ||
| val matches = ctx.addMutableState(clsName, "matches", | ||
|
|
@@ -435,7 +445,7 @@ case class SortMergeJoinExec( | |
|
|
||
| // Handle the case when streamed rows has any NULL keys. | ||
| val handleStreamedAnyNull = joinType match { | ||
| case _: InnerLike => | ||
| case _: InnerLike | LeftSemi => | ||
| // Skip streamed row. | ||
| s""" | ||
| |$streamedRow = null; | ||
|
|
@@ -457,7 +467,7 @@ case class SortMergeJoinExec( | |
|
|
||
| // Handle the case when streamed keys has no match with buffered side. | ||
| val handleStreamedWithoutMatch = joinType match { | ||
| case _: InnerLike => | ||
| case _: InnerLike | LeftSemi => | ||
| // Skip streamed row. | ||
| s"$streamedRow = null;" | ||
| case LeftOuter | RightOuter => | ||
|
|
@@ -468,6 +478,17 @@ case class SortMergeJoinExec( | |
| s"SortMergeJoin.genScanner should not take $x as the JoinType") | ||
| } | ||
|
|
||
| val addRowToBuffer = | ||
| if (onlyBufferFirstMatchedRow) { | ||
| s""" | ||
| |if ($matches.isEmpty()) { | ||
| | $matches.add((UnsafeRow) $bufferedRow); | ||
| |} | ||
| """.stripMargin | ||
| } else { | ||
| s"$matches.add((UnsafeRow) $bufferedRow);" | ||
| } | ||
|
|
||
| // Generate a function to scan both streamed and buffered sides to find a match. | ||
| // Return whether a match is found. | ||
| // | ||
|
|
@@ -483,17 +504,18 @@ case class SortMergeJoinExec( | |
| // The function has the following step: | ||
| // - Step 1: Find the next `streamedRow` with non-null join keys. | ||
| // For `streamedRow` with null join keys (`handleStreamedAnyNull`): | ||
| // 1. Inner join: skip the row. `matches` will be cleared later when hitting the | ||
| // next `streamedRow` with non-null join keys. | ||
| // 1. Inner and Left Semi join: skip the row. `matches` will be cleared later when | ||
| // hitting the next `streamedRow` with non-null join | ||
| // keys. | ||
| // 2. Left/Right Outer join: clear the previous `matches` if needed, keep the row, | ||
| // and return false. | ||
| // | ||
| // - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`. | ||
| // Clear `matches` if we hit a new `streamedRow`, as we need to find new matches. | ||
| // Use `bufferedRow` to iterate buffered side to put all matched rows into | ||
| // `matches`. Return true when getting all matched rows. | ||
| // `matches` (`addRowToBuffer`). Return true when getting all matched rows. | ||
| // For `streamedRow` without `matches` (`handleStreamedWithoutMatch`): | ||
| // 1. Inner join: skip the row. | ||
| // 1. Inner and Left Semi join: skip the row. | ||
| // 2. Left/Right Outer join: keep the row and return false (with `matches` being | ||
| // empty). | ||
| ctx.addNewFunction("findNextJoinRows", | ||
|
|
@@ -543,7 +565,7 @@ case class SortMergeJoinExec( | |
| | $handleStreamedWithoutMatch | ||
| | } | ||
| | } else { | ||
| | $matches.add((UnsafeRow) $bufferedRow); | ||
| | $addRowToBuffer | ||
| | $bufferedRow = null; | ||
| | } | ||
| | } while ($streamedRow != null); | ||
|
|
@@ -639,6 +661,8 @@ case class SortMergeJoinExec( | |
| streamedVars ++ bufferedVars | ||
| case RightOuter => | ||
| bufferedVars ++ streamedVars | ||
| case LeftSemi => | ||
| streamedVars | ||
| case x => | ||
| throw new IllegalArgumentException( | ||
| s"SortMergeJoin.doProduce should not take $x as the JoinType") | ||
|
|
@@ -650,8 +674,9 @@ case class SortMergeJoinExec( | |
| val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars) | ||
| val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars) | ||
| // Generate code for condition | ||
| ctx.currentVars = resultVars | ||
| val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) | ||
| ctx.currentVars = streamedVars ++ bufferedVars | ||
| val cond = BindReferences.bindReference( | ||
| condition.get, streamedPlan.output ++ bufferedPlan.output).genCode(ctx) | ||
| // evaluate the columns those used by condition before loop | ||
| val before = | ||
| s""" | ||
|
|
@@ -724,9 +749,32 @@ case class SortMergeJoinExec( | |
| """.stripMargin | ||
| } | ||
|
|
||
| lazy val semiJoin = { | ||
|
||
| val hasOutputRow = ctx.freshName("hasOutputRow") | ||
| s""" | ||
| |while (findNextJoinRows($streamedInput, $bufferedInput)) { | ||
| | ${streamedVarDecl.mkString("\n")} | ||
| | ${beforeLoop.trim} | ||
| | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator(); | ||
| | boolean $hasOutputRow = false; | ||
| | | ||
| | while (!$hasOutputRow && $iterator.hasNext()) { | ||
| | InternalRow $bufferedRow = (InternalRow) $iterator.next(); | ||
| | ${condCheck.trim} | ||
| | $hasOutputRow = true; | ||
| | $numOutput.add(1); | ||
| | ${consume(ctx, resultVars)} | ||
| | } | ||
| | if (shouldStop()) return; | ||
| |} | ||
| |$eagerCleanup | ||
| """.stripMargin | ||
| } | ||
|
|
||
| joinType match { | ||
| case _: InnerLike => innerJoin | ||
| case LeftOuter | RightOuter => outerJoin | ||
| case LeftSemi => semiJoin | ||
| case x => | ||
| throw new IllegalArgumentException( | ||
| s"SortMergeJoin.doProduce should not take $x as the JoinType") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about moving this branch into the
getInMemoryThresholdside?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1,
lazy valcan probably bedefas the logic is super simpleThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call. Actually the non-code-gen path can also depend on this, so I make it just a
valnow.