[SPARK-31869][SQL] BroadcastHashJoinExec can utilize the build side for its output partitioning #28676
Changes from 11 commits
BroadcastHashJoinExec.scala

@@ -17,6 +17,8 @@
 package org.apache.spark.sql.execution.joins
 
+import scala.collection.mutable
+
 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -26,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.{BooleanType, LongType}
|
@@ -51,7 +53,7 @@ case class BroadcastHashJoinExec(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = HashedRelationBroadcastMode(buildKeys)
+    val mode = HashedRelationBroadcastMode(buildBoundKeys)
     buildSide match {
       case BuildLeft =>
         BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -60,6 +62,66 @@
     }
   }
 
+  override def outputPartitioning: Partitioning = {
+    joinType match {
+      case _: InnerLike =>
+        streamedPlan.outputPartitioning match {
+          case h: HashPartitioning => PartitioningCollection(expandOutputPartitioning(h))
+          case c: PartitioningCollection =>
+            def expand(partitioning: PartitioningCollection): Partitioning = {
+              PartitioningCollection(partitioning.partitionings.flatMap {
+                case h: HashPartitioning => expandOutputPartitioning(h)
+                case c: PartitioningCollection => Seq(expand(c))
+                case other => Seq(other)
+              })
+            }
+            expand(c)
+          case other => other
+        }
+      case _ => streamedPlan.outputPartitioning
+    }
+  }
+
+  // A one-to-many mapping from a streamed key to build keys.
+  private lazy val streamedKeyToBuildKeyMapping = {
+    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+    streamedKeys.zip(buildKeys).foreach {
+      case (streamedKey, buildKey) =>
+        val key = streamedKey.canonicalized
+        mapping.get(key) match {
+          case Some(v) => mapping.put(key, v :+ buildKey)
+          case None => mapping.put(key, Seq(buildKey))
+        }
+    }
+    mapping.toMap
+  }
+
+  // Expands the given partitioning by substituting streamed keys with build keys.
+  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
+  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
+  // the expanded partitioning will have the following expressions:
+  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+  private def expandOutputPartitioning(partitioning: HashPartitioning): Seq[HashPartitioning] = {
+    def generateExprCombinations(
+        current: Seq[Expression],
+        accumulated: Seq[Expression],
+        all: mutable.ListBuffer[Seq[Expression]]): Unit = {
+      if (current.isEmpty) {
+        all += accumulated
+      } else {
+        generateExprCombinations(current.tail, accumulated :+ current.head, all)
+        val mapped = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+        if (mapped.isDefined) {
+          mapped.get.foreach(m => generateExprCombinations(current.tail, accumulated :+ m, all))
+        }
+      }
+    }
+
+    val all = mutable.ListBuffer[Seq[Expression]]()
+    generateExprCombinations(partitioning.expressions, Nil, all)
+    all.map(HashPartitioning(_, partitioning.numPartitions))
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
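A note on the expansion algorithm: generateExprCombinations walks the partitioning expressions left to right and, at each position, emits one branch for the streamed expression itself plus one branch per build key it is joined against, so n substitutable keys yield up to 2^n entries in the resulting PartitioningCollection (more with one-to-many mappings). The following self-contained sketch mirrors that recursion with plain strings in place of Catalyst expressions; all names here are illustrative, not Spark APIs.

  import scala.collection.mutable

  object ExpandPartitioningSketch {
    // Stand-in for streamedKeyToBuildKeyMapping, e.g. from a join condition
    // b = x AND c = y over the streamed keys Seq("b", "c").
    val streamedToBuild: Map[String, Seq[String]] = Map("b" -> Seq("x"), "c" -> Seq("y"))

    def expand(exprs: Seq[String]): Seq[Seq[String]] = {
      val all = mutable.ListBuffer[Seq[String]]()
      def gen(current: Seq[String], accumulated: Seq[String]): Unit = {
        if (current.isEmpty) {
          all += accumulated
        } else {
          // Keep the streamed expression as-is...
          gen(current.tail, accumulated :+ current.head)
          // ...and also branch once per build key mapped to it.
          streamedToBuild.get(current.head)
            .foreach(_.foreach(m => gen(current.tail, accumulated :+ m)))
        }
      }
      gen(exprs, Nil)
      all.toSeq
    }

    def main(args: Array[String]): Unit = {
      // Prints the four combinations from the comment above:
      // List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y)
      expand(Seq("a", "b", "c")).foreach(println)
    }
  }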
|
@@ -135,13 +197,13 @@ case class BroadcastHashJoinExec(
       ctx: CodegenContext,
       input: Seq[ExprCode]): (ExprCode, String) = {
     ctx.currentVars = input
-    if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
+    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
       // generate the join key as Long
-      val ev = streamedKeys.head.genCode(ctx)
+      val ev = streamedBoundKeys.head.genCode(ctx)
       (ev, ev.isNull)
     } else {
       // generate the join key as UnsafeRow
-      val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
+      val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
       (ev, s"${ev.value}.anyNull()")
     }
   }
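On the buildKeys/buildBoundKeys and streamedKeys/streamedBoundKeys split seen throughout this diff: codegen and the broadcast mode need join-key expressions whose attribute references are bound to the corresponding child's output, while the unbound keys remain useful for partitioning comparisons (as in streamedKeyToBuildKeyMapping above). A plausible shape for these helpers in the HashJoin trait, reconstructed for illustration rather than quoted from Spark's source:

  // Hypothetical reconstruction (not the exact Spark code): bound variants of
  // the join keys, resolved against each side's output attributes.
  import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences

  protected lazy val buildBoundKeys: Seq[Expression] =
    bindReferences(buildKeys, buildPlan.output)
  protected lazy val streamedBoundKeys: Seq[Expression] =
    bindReferences(streamedKeys, streamedPlan.output)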
BroadcastJoinSuite.scala
|
|
@@ -21,13 +21,15 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BitwiseAnd, BitwiseOr, Cast, Expression, Literal, ShiftLeft}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.BROADCAST
-import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
+import org.apache.spark.sql.execution.{DummySparkPlan, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
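DummySparkPlan, imported above, is the test-only stub that the last test in this suite uses to inject an arbitrary outputPartitioning into BroadcastHashJoinExec's children. Roughly, it looks like the sketch below (the exact field list and defaults are assumptions; the real definition lives in Spark's test sources):

  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
  import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnknownPartitioning}
  import org.apache.spark.sql.execution.SparkPlan

  // Rough sketch: a SparkPlan whose planning-time properties are constructor
  // parameters and whose doExecute is never meant to run.
  case class DummySparkPlan(
      override val children: Seq[SparkPlan] = Nil,
      override val outputOrdering: Seq[SortOrder] = Nil,
      override val outputPartitioning: Partitioning = UnknownPartitioning(0),
      override val requiredChildDistribution: Seq[Distribution] = Nil,
      override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil)
    extends SparkPlan {
    override protected def doExecute(): RDD[InternalRow] =
      throw new UnsupportedOperationException("DummySparkPlan cannot execute")
    override def output: Seq[Attribute] = Seq.empty
  }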
|
|
@@ -415,6 +417,192 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
       assert(e.getMessage.contains(s"Could not execute broadcast in $timeout secs."))
     }
   }
+
+  test("broadcast join where streamed side's output partitioning is HashPartitioning") {
+    withTable("t1", "t3") {
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
+        val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
+        val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2")
+        val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")
+        df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1")
+        df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3")
+        val t1 = spark.table("t1")
+        val t3 = spark.table("t3")
+
+        // join1 is a broadcast join where df2 is broadcasted. Note that output partitioning
+        // on the streamed side (t1) is HashPartitioning (bucketed files).
+        val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2"))
+        val plan1 = join1.queryExecution.executedPlan
+        assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
+        val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b }
+        assert(broadcastJoins.size == 1)
+        broadcastJoins(0).outputPartitioning match {
+          case p: PartitioningCollection =>
+            assert(p.partitionings.size == 4)
+            // Verify all the combinations of output partitioning.
+            Seq(Seq(t1("i1"), t1("j1")),
+              Seq(t1("i1"), df2("j2")),
+              Seq(df2("i2"), t1("j1")),
+              Seq(df2("i2"), df2("j2"))).foreach { expected =>
+              val expectedExpressions = expected.map(_.expr)
+              assert(p.partitionings.exists {
+                case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions)
+              })
+            }
+          case _ => fail()
+        }
+
+        // Join on the columns from the broadcasted side (i2, j2) and make sure output
+        // partitioning is maintained by checking no shuffle exchange is introduced.
+        val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
+        val plan2 = join2.queryExecution.executedPlan
+        assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
+        assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1)
+        assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty)
+
+        // Validate the data with broadcast join off.
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+          val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
+          QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq)
+        }
+      }
+    }
+  }
+
| test("broadcast join where streamed side's output partitioning is PartitioningCollection") { | ||
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { | ||
| val t1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1") | ||
| val t2 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i2", "j2") | ||
| val t3 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i3", "j3") | ||
| val t4 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i4", "j4") | ||
|
|
||
| // join1 is a sort merge join (shuffle on the both sides). | ||
| val join1 = t1.join(t2, t1("i1") === t2("i2")) | ||
| val plan1 = join1.queryExecution.executedPlan | ||
| assert(collect(plan1) { case s: SortMergeJoinExec => s }.size == 1) | ||
| assert(collect(plan1) { case e: ShuffleExchangeExec => e }.size == 2) | ||
|
|
||
| // join2 is a broadcast join where t3 is broadcasted. Note that output partitioning on the | ||
| // streamed side (join1) is PartitioningCollection (sort merge join) | ||
| val join2 = join1.join(t3, join1("i1") === t3("i3")) | ||
| val plan2 = join2.queryExecution.executedPlan | ||
| assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1) | ||
| assert(collect(plan2) { case e: ShuffleExchangeExec => e }.size == 2) | ||
| val broadcastJoins = collect(plan2) { case b: BroadcastHashJoinExec => b } | ||
| assert(broadcastJoins.size == 1) | ||
| broadcastJoins(0).outputPartitioning match { | ||
| case p: PartitioningCollection => | ||
| assert(p.partitionings.size == 3) | ||
| // Verify all the combinations of output partitioning. | ||
| Seq(Seq(t1("i1")), Seq(t2("i2")), Seq(t3("i3"))).foreach { expected => | ||
| val expectedExpressions = expected.map(_.expr) | ||
| assert(p.partitionings.exists { | ||
| case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions) | ||
| }) | ||
| } | ||
| case _ => fail() | ||
|
||
| } | ||
|
|
||
| // Join on the column from the broadcasted side (i3) and make sure output partitioning | ||
| // is maintained by checking no shuffle exchange is introduced. Note that one extra | ||
| // ShuffleExchangeExec is from t4, not from join2. | ||
| val join3 = join2.join(t4, join2("i3") === t4("i4")) | ||
| val plan3 = join3.queryExecution.executedPlan | ||
| assert(collect(plan3) { case s: SortMergeJoinExec => s }.size == 2) | ||
| assert(collect(plan3) { case b: BroadcastHashJoinExec => b }.size == 1) | ||
| assert(collect(plan3) { case e: ShuffleExchangeExec => e }.size == 3) | ||
|
|
||
| // Validate the data with boradcast join off. | ||
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { | ||
| val df = join2.join(t4, join2("i3") === t4("i4")) | ||
| QueryTest.sameRows(join3.collect().toSeq, df.collect().toSeq) | ||
imback82 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
| } | ||
|
|
||
| test("BroadcastHashJoinExec output partitioning scenarios for inner join") { | ||
| val l1 = AttributeReference("l1", LongType)() | ||
| val l2 = AttributeReference("l2", LongType)() | ||
| val l3 = AttributeReference("l3", LongType)() | ||
| val r1 = AttributeReference("r1", LongType)() | ||
| val r2 = AttributeReference("r2", LongType)() | ||
| val r3 = AttributeReference("r3", LongType)() | ||
|
|
||
| // Streamed side has a HashPartitioning. | ||
| var bhj = BroadcastHashJoinExec( | ||
| leftKeys = Seq(l2, l3), | ||
| rightKeys = Seq(r1, r2), | ||
| Inner, | ||
| BuildRight, | ||
| None, | ||
| left = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2, l3), 1)), | ||
| right = DummySparkPlan()) | ||
| var expected = PartitioningCollection(Seq( | ||
| HashPartitioning(Seq(l1, l2, l3), 1), | ||
| HashPartitioning(Seq(l1, l2, r2), 1), | ||
| HashPartitioning(Seq(l1, r1, l3), 1), | ||
| HashPartitioning(Seq(l1, r1, r2), 1))) | ||
| assert(bhj.outputPartitioning === expected) | ||
|
|
||
| // Streamed side has a PartitioningCollection. | ||
| bhj = BroadcastHashJoinExec( | ||
| leftKeys = Seq(l1, l2, l3), | ||
| rightKeys = Seq(r1, r2, r3), | ||
| Inner, | ||
| BuildRight, | ||
| None, | ||
| left = DummySparkPlan(outputPartitioning = PartitioningCollection(Seq( | ||
| HashPartitioning(Seq(l1, l2), 1), HashPartitioning(Seq(l3), 1)))), | ||
| right = DummySparkPlan()) | ||
| expected = PartitioningCollection(Seq( | ||
| HashPartitioning(Seq(l1, l2), 1), | ||
| HashPartitioning(Seq(l1, r2), 1), | ||
| HashPartitioning(Seq(r1, l2), 1), | ||
| HashPartitioning(Seq(r1, r2), 1), | ||
| HashPartitioning(Seq(l3), 1), | ||
| HashPartitioning(Seq(r3), 1))) | ||
| assert(bhj.outputPartitioning === expected) | ||
|
|
||
| // Streamed side has a nested PartitioningCollection. | ||
| bhj = BroadcastHashJoinExec( | ||
| leftKeys = Seq(l1, l2, l3), | ||
| rightKeys = Seq(r1, r2, r3), | ||
| Inner, | ||
| BuildRight, | ||
| None, | ||
| left = DummySparkPlan(outputPartitioning = PartitioningCollection(Seq( | ||
| PartitioningCollection(Seq(HashPartitioning(Seq(l1), 1), HashPartitioning(Seq(l2), 1))), | ||
| HashPartitioning(Seq(l3), 1)))), | ||
| right = DummySparkPlan()) | ||
| expected = PartitioningCollection(Seq( | ||
| PartitioningCollection(Seq( | ||
| HashPartitioning(Seq(l1), 1), | ||
| HashPartitioning(Seq(r1), 1), | ||
| HashPartitioning(Seq(l2), 1), | ||
| HashPartitioning(Seq(r2), 1))), | ||
| HashPartitioning(Seq(l3), 1), | ||
| HashPartitioning(Seq(r3), 1))) | ||
| assert(bhj.outputPartitioning === expected) | ||
|
|
||
| // One-to-mapping case ("l1" = "r1" AND "l1" = "r2") | ||
| bhj = BroadcastHashJoinExec( | ||
| leftKeys = Seq(l1, l1), | ||
| rightKeys = Seq(r1, r2), | ||
| Inner, | ||
| BuildRight, | ||
| None, | ||
| left = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2), 1)), | ||
| right = DummySparkPlan()) | ||
| expected = PartitioningCollection(Seq( | ||
| HashPartitioning(Seq(l1, l2), 1), | ||
| HashPartitioning(Seq(r1, l2), 1), | ||
| HashPartitioning(Seq(r2, l2), 1))) | ||
| assert(bhj.outputPartitioning === expected) | ||
| } | ||
|
|
||
| private def expressionsEqual(l: Seq[Expression], r: Seq[Expression]): Boolean = { | ||
| l.length == r.length && l.zip(r).forall { case (e1, e2) => e1.semanticEquals(e2) } | ||
| } | ||
| } | ||
|
|
 class BroadcastJoinSuite extends BroadcastJoinSuiteBase with DisableAdaptiveExecutionSuite
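Why the one-to-many mapping scenario in the test above expects three partitionings rather than four: both streamed keys are l1, so streamedKeyToBuildKeyMapping collapses to l1 -> Seq(r1, r2), and l2 has no build-side counterpart. A compact hand-check using the same string-based sketch idea from earlier (names illustrative only):

  import scala.collection.mutable

  object OneToManyMappingCheck {
    // From "l1" = "r1" AND "l1" = "r2": one streamed key, two build keys.
    val mapping = Map("l1" -> Seq("r1", "r2"))

    def expand(exprs: Seq[String]): Seq[Seq[String]] = {
      val all = mutable.ListBuffer[Seq[String]]()
      def gen(cur: Seq[String], acc: Seq[String]): Unit = {
        if (cur.isEmpty) all += acc
        else {
          gen(cur.tail, acc :+ cur.head)
          mapping.get(cur.head).foreach(_.foreach(m => gen(cur.tail, acc :+ m)))
        }
      }
      gen(exprs, Nil)
      all.toSeq
    }

    def main(args: Array[String]): Unit = {
      // Prints List(l1, l2), List(r1, l2), List(r2, l2): the three
      // partitionings the test asserts; l2 is never substituted.
      expand(Seq("l1", "l2")).foreach(println)
    }
  }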