@@ -354,7 +354,7 @@ case class SortMergeJoinExec(
}

private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
case _: InnerLike | LeftOuter => ((left, leftKeys), (right, rightKeys))
case _: InnerLike | LeftOuter | LeftSemi => ((left, leftKeys), (right, rightKeys))
case RightOuter => ((right, rightKeys), (left, leftKeys))
case x =>
throw new IllegalArgumentException(
@@ -365,7 +365,7 @@ case class SortMergeJoinExec(
private lazy val bufferedOutput = bufferedPlan.output

override def supportCodegen: Boolean = joinType match {
case _: InnerLike | LeftOuter | RightOuter => true
case _: InnerLike | LeftOuter | RightOuter | LeftSemi => true
case _ => false
}

@@ -424,8 +424,18 @@ case class SortMergeJoinExec(
// A list to hold all matched rows from buffered side.
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName

// Flag to only buffer first matched row, to avoid buffering unnecessary rows.
val onlyBufferFirstMatchedRow = (joinType, condition) match {
case (LeftSemi, None) => true
case _ => false
}
val inMemoryThreshold =
if (onlyBufferFirstMatchedRow) {
Member:
How about moving this branch into the getInMemoryThreshold side?


  // Flag to only buffer first matched row, to avoid buffering unnecessary rows.
  private lazy val onlyBufferFirstMatchedRow = (joinType, condition) match {
    case (LeftSemi, None) => true
    case _ => false
  }

  private def getInMemoryThreshold: Int = {
    if (onlyBufferFirstMatchedRow) {
      1
    } else {
      sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
    }
  }

Contributor:

+1, the lazy val can probably be a def, as the logic is super simple.

Contributor Author:
Good call. Actually the non-code-gen path can also depend on this, so I made it just a val now.

1
} else {
getInMemoryThreshold
}
val spillThreshold = getSpillThreshold
val inMemoryThreshold = getInMemoryThreshold

// Inline mutable state since not many join operations in a task
val matches = ctx.addMutableState(clsName, "matches",
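
Note on the onlyBufferFirstMatchedRow flag added above: for a left semi join with no extra condition, a streamed row only needs to know whether at least one buffered row shares its key, so keeping a single matched row (and hence an in-memory threshold of 1) loses nothing. A minimal standalone Scala sketch of that semantics, with illustrative names, not the Spark internals:

    // Left semi join without an extra condition: each streamed row is emitted at most once,
    // as soon as any buffered row shares its key. Only existence matters, so one buffered
    // match per key (or even a boolean flag) is all the state that has to be retained.
    def leftSemiNoCondition[K, A, B](streamed: Seq[(K, A)], buffered: Seq[(K, B)]): Seq[(K, A)] = {
      val bufferedKeys = buffered.map(_._1).toSet
      streamed.filter { case (k, _) => bufferedKeys.contains(k) }
    }

    // leftSemiNoCondition(Seq(1 -> "a", 2 -> "b"), Seq(1 -> "x", 1 -> "y")) == Seq(1 -> "a")
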
@@ -435,7 +445,7 @@ case class SortMergeJoinExec(

// Handle the case when streamed rows has any NULL keys.
val handleStreamedAnyNull = joinType match {
case _: InnerLike =>
case _: InnerLike | LeftSemi =>
// Skip streamed row.
s"""
|$streamedRow = null;
@@ -457,7 +467,7 @@ case class SortMergeJoinExec(

// Handle the case when streamed keys has no match with buffered side.
val handleStreamedWithoutMatch = joinType match {
case _: InnerLike =>
case _: InnerLike | LeftSemi =>
// Skip streamed row.
s"$streamedRow = null;"
case LeftOuter | RightOuter =>
@@ -468,6 +478,17 @@ case class SortMergeJoinExec(
s"SortMergeJoin.genScanner should not take $x as the JoinType")
}

val addRowToBuffer =
if (onlyBufferFirstMatchedRow) {
s"""
|if ($matches.isEmpty()) {
| $matches.add((UnsafeRow) $bufferedRow);
|}
""".stripMargin
} else {
s"$matches.add((UnsafeRow) $bufferedRow);"
}

// Generate a function to scan both streamed and buffered sides to find a match.
// Return whether a match is found.
//
@@ -483,17 +504,18 @@ case class SortMergeJoinExec(
// The function has the following step:
// - Step 1: Find the next `streamedRow` with non-null join keys.
// For `streamedRow` with null join keys (`handleStreamedAnyNull`):
// 1. Inner join: skip the row. `matches` will be cleared later when hitting the
// next `streamedRow` with non-null join keys.
// 1. Inner and Left Semi join: skip the row. `matches` will be cleared later when
// hitting the next `streamedRow` with non-null join
// keys.
// 2. Left/Right Outer join: clear the previous `matches` if needed, keep the row,
// and return false.
//
// - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`.
// Clear `matches` if we hit a new `streamedRow`, as we need to find new matches.
// Use `bufferedRow` to iterate buffered side to put all matched rows into
// `matches`. Return true when getting all matched rows.
// `matches` (`addRowToBuffer`). Return true when getting all matched rows.
// For `streamedRow` without `matches` (`handleStreamedWithoutMatch`):
// 1. Inner join: skip the row.
// 1. Inner and Left Semi join: skip the row.
// 2. Left/Right Outer join: keep the row and return false (with `matches` being
// empty).
ctx.addNewFunction("findNextJoinRows",
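
Note: to make the step-by-step comments above concrete, here is a hedged, non-codegen sketch of the merge scan for the left semi case over two key-sorted inputs. It ignores null keys, the matches buffer, and spilling, and the names are illustrative only:

    // Two-pointer scan over key-sorted inputs, mirroring findNextJoinRows for LeftSemi:
    // advance whichever side has the smaller key; on a key match, emit the streamed element
    // once and move on. The buffered pointer is not advanced on a match, so duplicate
    // streamed keys can still find their match.
    def sortedLeftSemi[A](streamed: Seq[(Int, A)], buffered: Seq[Int]): Seq[(Int, A)] = {
      val out = scala.collection.mutable.ArrayBuffer.empty[(Int, A)]
      var i = 0
      var j = 0
      while (i < streamed.length && j < buffered.length) {
        val streamedKey = streamed(i)._1
        val bufferedKey = buffered(j)
        if (streamedKey < bufferedKey) i += 1        // no buffered row can match this key anymore
        else if (streamedKey > bufferedKey) j += 1   // buffered side is behind; catch up
        else { out += streamed(i); i += 1 }          // match: emit the streamed element once
      }
      out.toSeq
    }
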
@@ -543,7 +565,7 @@ case class SortMergeJoinExec(
| $handleStreamedWithoutMatch
| }
| } else {
| $matches.add((UnsafeRow) $bufferedRow);
| $addRowToBuffer
| $bufferedRow = null;
| }
| } while ($streamedRow != null);
@@ -639,6 +661,8 @@ case class SortMergeJoinExec(
streamedVars ++ bufferedVars
case RightOuter =>
bufferedVars ++ streamedVars
case LeftSemi =>
streamedVars
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
@@ -650,8 +674,9 @@ case class SortMergeJoinExec(
val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
// Generate code for condition
ctx.currentVars = resultVars
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
ctx.currentVars = streamedVars ++ bufferedVars
val cond = BindReferences.bindReference(
condition.get, streamedPlan.output ++ bufferedPlan.output).genCode(ctx)
// evaluate the columns those used by condition before loop
val before =
s"""
@@ -724,9 +749,32 @@ case class SortMergeJoinExec(
""".stripMargin
}

lazy val semiJoin = {
Member:
How about extracting this block into a private method like codegenXXXX, just like HashJoin?

Contributor Author:
@maropu - yes, I was thinking that at first, but was worried the number of parameters would be too many. Refined the code a bit and updated now.

val hasOutputRow = ctx.freshName("hasOutputRow")
s"""
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
| ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| boolean $hasOutputRow = false;
|
| while (!$hasOutputRow && $iterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
| ${condCheck.trim}
| $hasOutputRow = true;
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}

joinType match {
case _: InnerLike => innerJoin
case LeftOuter | RightOuter => outerJoin
case LeftSemi => semiJoin
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
@@ -203,6 +203,28 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
Row(null, null, 6), Row(null, null, 7), Row(null, null, 8), Row(null, null, 9)))
}

test("Left Semi SortMergeJoin should be included in WholeStageCodegen") {
val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(4).select($"id".as("k2"))
val df3 = spark.range(6).select($"id".as("k3"))

// test one left semi sort merge join
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi")
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(0), Row(1), Row(2), Row(3)))

// test two left semi sort merge joins
val twoJoinsDF = df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2", "left_semi")
.join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "left_semi")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
WholeStageCodegenExec(_ : SortMergeJoinExec) => true
}.size === 2)
checkAnswer(twoJoinsDF, Seq(Row(0), Row(1), Row(2), Row(3)))
}
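
Note: to eyeball the generated code for a plan like the one asserted above, something along these lines works in a Spark shell or session (a sketch; assumes spark.implicits._ is in scope, and mirrors the hint and column names from the test):

    import org.apache.spark.sql.execution.debug._   // provides debugCodegen()

    val df1 = spark.range(10).select($"id".as("k1"))
    val df2 = spark.range(4).select($"id".as("k2"))
    val joined = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi")

    joined.explain()        // the starred operators indicate whole-stage code generation
    joined.debugCodegen()   // dumps the generated Java source for each codegen subtree
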

test("Inner/Cross BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") {
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
@@ -141,7 +141,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
}
}

test(s"$testName using SortMergeJoin") {
testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
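
Note: for readers unfamiliar with the helper now used above, testWithWholeStageCodegenOnAndOff runs the same test body twice, once with whole-stage codegen enabled and once disabled, so both the new code-generated semi-join path and the interpreted path are exercised. A rough, simplified sketch of such a helper (assumes the test and withSQLConf methods from the surrounding suite; not necessarily the exact Spark test-utility code):

    def testWithWholeStageCodegenOnAndOff(testName: String)(f: String => Unit): Unit = {
      Seq("false", "true").foreach { enabled =>
        val mode = if (enabled == "true") "on" else "off"
        test(s"$testName (whole-stage-codegen $mode)") {
          withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) {
            f(enabled)   // the body receives the current codegen setting as a string
          }
        }
      }
    }
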