[SPARK-31869][SQL] BroadcastHashJoinExec can utilize the build side for its output partitioning #28676
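As context for the change, here is a hypothetical end-to-end illustration of the query shape this PR improves (the table names, data, and session setup below are assumptions for the demo, not from the PR). After an inner broadcast hash join, the output can now also advertise the build-side join keys in its partitioning, letting a downstream join keyed on those columns skip a shuffle:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

object BhjOutputPartitioningDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("bhj-demo").getOrCreate()
    import spark.implicits._

    // Keep the second join a shuffle join so the effect is visible in explain().
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

    val t1 = Seq((1, "a1"), (2, "a2")).toDF("a", "v1").repartition($"a")
    val t2 = Seq((1, "b1"), (2, "b2")).toDF("x", "v2")
    val t3 = Seq((1, "c1"), (2, "c2")).toDF("k", "v3").repartition($"k")

    // Inner broadcast hash join on a = x. The streamed side (t1) is hash
    // partitioned on "a"; after this PR the join output also advertises
    // HashPartitioning on the build key "x".
    val joined = t1.join(broadcast(t2), $"a" === $"x")

    // Keyed on "x": the expanded output partitioning lets the planner skip
    // the Exchange that was previously inserted on `joined`.
    joined.join(t3, $"x" === $"k").explain()
  }
}
```

The change touches two files: a new SQL conf in SQLConf.scala and the output-partitioning expansion in BroadcastHashJoinExec.scala.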
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

```diff
@@ -2665,6 +2665,15 @@ object SQLConf {
       .checkValue(_ > 0, "The difference must be positive.")
       .createWithDefault(4)
 
+  val BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT =
+    buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit")
+      .doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " +
+        "This configuration is applicable only for inner joins.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ > 0, "The value must be positive.")
+      .createWithDefault(8)
+
   /**
    * Holds information about keys that have been deprecated.
    *
```
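The new conf is an ordinary runtime SQL conf, so it can be tuned per session. A minimal usage sketch (assumes an active `SparkSession` named `spark`):

```scala
// Hypothetical tuning: allow more expanded partitionings when the streamed
// side is partitioned on many join keys. The value must be positive.
spark.conf.set("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit", "16")
```

Setting the limit to 1 effectively disables the expansion, since the first combination generated is always the original streamed-side key sequence.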
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala

```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.joins
 
+import scala.collection.mutable
+
 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -26,9 +28,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, LongType}
 
 /**
@@ -51,7 +54,7 @@ case class BroadcastHashJoinExec(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = HashedRelationBroadcastMode(buildKeys)
+    val mode = HashedRelationBroadcastMode(buildBoundKeys)
     buildSide match {
       case BuildLeft =>
         BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -60,6 +63,81 @@ case class BroadcastHashJoinExec(
     }
   }
 
+  override lazy val outputPartitioning: Partitioning = {
+    joinType match {
+      case Inner =>
+        streamedPlan.outputPartitioning match {
+          case h: HashPartitioning => expandOutputPartitioning(h)
+          case c: PartitioningCollection => expandOutputPartitioning(c)
+          case other => other
+        }
+      case _ => streamedPlan.outputPartitioning
+    }
+  }
+
+  // A one-to-many mapping from a streamed key to build keys.
+  private lazy val streamedKeyToBuildKeyMapping = {
+    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+    streamedKeys.zip(buildKeys).foreach {
+      case (streamedKey, buildKey) =>
+        val key = streamedKey.canonicalized
+        mapping.get(key) match {
+          case Some(v) => mapping.put(key, v :+ buildKey)
+          case None => mapping.put(key, Seq(buildKey))
+        }
+    }
+    mapping.toMap
+  }
+
+  // Expands the given partitioning collection recursively.
+  private def expandOutputPartitioning(
+      partitioning: PartitioningCollection): PartitioningCollection = {
+    PartitioningCollection(partitioning.partitionings.flatMap {
+      case h: HashPartitioning => expandOutputPartitioning(h).partitionings
+      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+      case other => Seq(other)
+    })
+  }
+
+  // Expands the given hash partitioning by substituting streamed keys with build keys.
+  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
+  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
+  // the expanded partitioning will have the following expressions:
+  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+  // The expanded expressions are returned as a PartitioningCollection.
+  private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
+    val maxNumCombinations = sqlContext.conf.getConf(
+      SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT)
+    var currentNumCombinations = 0
+
+    def generateExprCombinations(
+        current: Seq[Expression],
+        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
+      if (currentNumCombinations > maxNumCombinations) {
+        Nil
+      } else if (current.isEmpty) {
+        currentNumCombinations += 1
+        Seq(accumulated)
+      } else {
+        val buildKeys = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+        generateExprCombinations(current.tail, accumulated :+ current.head) ++
+          buildKeys.map { bKeys =>
+            bKeys.flatMap { bKey =>
+              if (currentNumCombinations < maxNumCombinations) {
+                generateExprCombinations(current.tail, accumulated :+ bKey)
+              } else {
+                Nil
+              }
+            }
+          }.getOrElse(Nil)
+      }
+    }
+
+    PartitioningCollection(
+      generateExprCombinations(partitioning.expressions, Nil)
+        .map(HashPartitioning(_, partitioning.numPartitions)))
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
@@ -135,13 +213,13 @@ case class BroadcastHashJoinExec(
       ctx: CodegenContext,
       input: Seq[ExprCode]): (ExprCode, String) = {
     ctx.currentVars = input
-    if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
+    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
       // generate the join key as Long
-      val ev = streamedKeys.head.genCode(ctx)
+      val ev = streamedBoundKeys.head.genCode(ctx)
       (ev, ev.isNull)
     } else {
       // generate the join key as UnsafeRow
-      val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
+      val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
       (ev, s"${ev.value}.anyNull()")
     }
   }
```
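To make the expansion concrete, here is a simplified standalone sketch of `generateExprCombinations`, using plain strings in place of Catalyst `Expression`s and omitting the `maxNumCombinations` bookkeeping (a deliberate simplification, not the PR's exact code):

```scala
object ExpandPartitioningSketch {
  // For each streamed key in `current`, emit both the key itself and every
  // build-side key it is equated with, preserving positional order.
  def generateExprCombinations(
      current: Seq[String],
      accumulated: Seq[String],
      streamedToBuild: Map[String, Seq[String]]): Seq[Seq[String]] = {
    if (current.isEmpty) {
      Seq(accumulated)
    } else {
      val withStreamedKey =
        generateExprCombinations(current.tail, accumulated :+ current.head, streamedToBuild)
      val withBuildKeys = streamedToBuild.getOrElse(current.head, Nil).flatMap { bKey =>
        generateExprCombinations(current.tail, accumulated :+ bKey, streamedToBuild)
      }
      withStreamedKey ++ withBuildKeys
    }
  }

  def main(args: Array[String]): Unit = {
    // Streamed keys "b" and "c" are joined against build keys "x" and "y".
    val mapping = Map("b" -> Seq("x"), "c" -> Seq("y"))
    generateExprCombinations(Seq("a", "b", "c"), Nil, mapping).foreach(println)
    // Prints the four combinations from the comment in the diff:
    // List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y)
  }
}
```

With n substitutable keys this generates up to 2^n combinations, which is why the real implementation caps the count with the new conf.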