Changes from 15 commits
@@ -2665,6 +2665,15 @@ object SQLConf {
.checkValue(_ > 0, "The difference must be positive.")
.createWithDefault(4)

val BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT =
buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit")
.doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " +
"This configuration is applicable only for inner joins.")
.version("3.1.0")
.intConf
.checkValue(_ > 0, "The value must be positive.")
Contributor: Shall we allow 0, which disables this feature? It can be useful when we want to benchmark the improvement of this feature.

Contributor (author): 1 should retain the same behavior (although it creates a PartitioningCollection with the original HashPartitioning). I added an explicit check for 0 to disable this feature. Thanks.
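For reference, the explicit check for 0 that the author mentions lands in a later commit and is not shown in this range; one plausible shape for it (illustrative only) is to guard the expansion before it starts:

```scala
// Illustrative sketch; the actual follow-up commit is not part of this diff.
// A limit of 0 bypasses the expansion entirely and passes the streamed
// side's partitioning through unchanged.
override lazy val outputPartitioning: Partitioning = {
  val expandLimit = sqlContext.conf.getConf(
    SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT)
  joinType match {
    case Inner if expandLimit > 0 =>
      streamedPlan.outputPartitioning match {
        case h: HashPartitioning => expandOutputPartitioning(h)
        case c: PartitioningCollection => expandOutputPartitioning(c)
        case other => other
      }
    case _ => streamedPlan.outputPartitioning
  }
}
```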

.createWithDefault(8)

/**
* Holds information about keys that have been deprecated.
*
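As a usage note, the new limit is an ordinary SQL conf; a minimal sketch of setting it follows (the session itself is illustrative):

```scala
import org.apache.spark.sql.SparkSession

// Any existing SparkSession works the same way; this one is illustrative.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("output-partitioning-expand-limit")
  .getOrCreate()

// Cap how many HashPartitionings an inner broadcast hash join may expand
// the streamed side's output partitioning into (default: 8).
spark.conf.set(
  "spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit", "16")
```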
@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.joins

import scala.collection.mutable

import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -26,9 +28,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, LongType}

/**
@@ -51,7 +54,7 @@ case class BroadcastHashJoinExec(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildKeys)
val mode = HashedRelationBroadcastMode(buildBoundKeys)
buildSide match {
case BuildLeft =>
BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -60,6 +63,81 @@ }
}
}

override lazy val outputPartitioning: Partitioning = {
joinType match {
case Inner =>
streamedPlan.outputPartitioning match {
case h: HashPartitioning => expandOutputPartitioning(h)
case c: PartitioningCollection => expandOutputPartitioning(c)
case other => other
}
case _ => streamedPlan.outputPartitioning
}
}

// A one-to-many mapping from a streamed key to build keys.
private lazy val streamedKeyToBuildKeyMapping = {
val mapping = mutable.Map.empty[Expression, Seq[Expression]]
streamedKeys.zip(buildKeys).foreach {
case (streamedKey, buildKey) =>
val key = streamedKey.canonicalized
mapping.get(key) match {
case Some(v) => mapping.put(key, v :+ buildKey)
case None => mapping.put(key, Seq(buildKey))
}
}
mapping.toMap
}

// Expands the given partitioning collection recursively.
private def expandOutputPartitioning(
partitioning: PartitioningCollection): PartitioningCollection = {
PartitioningCollection(partitioning.partitionings.flatMap {
case h: HashPartitioning => expandOutputPartitioning(h).partitionings
case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
case other => Seq(other)
})
}

// Expands the given hash partitioning by substituting streamed keys with build keys.
// For example, if the expressions for the given partitioning are Seq("a", "b", "c")
// where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
// the expanded partitioning will have the following expressions:
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as a PartitioningCollection.
private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
val maxNumCombinations = sqlContext.conf.getConf(
SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT)
var currentNumCombinations = 0

def generateExprCombinations(
current: Seq[Expression],
accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
if (currentNumCombinations > maxNumCombinations) {
Nil
} else if (current.isEmpty) {
currentNumCombinations += 1
Seq(accumulated)
} else {
val buildKeys = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
generateExprCombinations(current.tail, accumulated :+ current.head) ++
buildKeys.map { bKeys =>
bKeys.flatMap { bKey =>
if (currentNumCombinations < maxNumCombinations) {
Contributor: Do we need this if? I think generateExprCombinations will return Nil when hitting the upper bound.

Contributor (author): I wanted to avoid unnecessary recursion (and creating new Seqs, etc.), but I removed the check for simplicity.

generateExprCombinations(current.tail, accumulated :+ bKey)
} else {
Nil
}
}
}.getOrElse(Nil)
}
}

PartitioningCollection(
generateExprCombinations(partitioning.expressions, Nil)
.map(HashPartitioning(_, partitioning.numPartitions)))
}
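As a sanity check on the combination generation, the following standalone sketch reproduces the Seq("a", "b", "c") example from the comment above, with plain Strings standing in for Catalyst expressions (the object and names are illustrative, not part of the patch):

```scala
// Standalone sketch of the key-combination expansion; Strings stand in for
// Catalyst Expressions, and everything here is illustrative only.
object ExpandSketch {
  // One-to-many mapping from streamed keys to build keys, matching the
  // example above: streamed keys "b", "c" map to build keys "x", "y".
  val streamedToBuild: Map[String, Seq[String]] =
    Map("b" -> Seq("x"), "c" -> Seq("y"))
  val limit = 8
  var count = 0

  def combinations(current: Seq[String], acc: Seq[String]): Seq[Seq[String]] = {
    if (count > limit) {
      Nil // the expansion limit was hit; stop producing combinations
    } else if (current.isEmpty) {
      count += 1
      Seq(acc) // one complete combination
    } else {
      // Keep the streamed key as-is, then additionally substitute each
      // matching build key.
      combinations(current.tail, acc :+ current.head) ++
        streamedToBuild.getOrElse(current.head, Nil).flatMap { bKey =>
          combinations(current.tail, acc :+ bKey)
        }
    }
  }

  def main(args: Array[String]): Unit = {
    // Prints: List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y)
    combinations(Seq("a", "b", "c"), Nil).foreach(println)
  }
}
```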

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")

@@ -135,13 +213,13 @@ case class BroadcastHashJoinExec(
ctx: CodegenContext,
input: Seq[ExprCode]): (ExprCode, String) = {
ctx.currentVars = input
if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
// generate the join key as Long
val ev = streamedKeys.head.genCode(ctx)
val ev = streamedBoundKeys.head.genCode(ctx)
(ev, ev.isNull)
} else {
// generate the join key as UnsafeRow
val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
(ev, s"${ev.value}.anyNull()")
}
}
@@ -62,21 +62,30 @@ trait HashJoin extends BaseJoinExec
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
buildSide match {
case BuildLeft => (lkeys, rkeys)
case BuildRight => (rkeys, lkeys)
case BuildLeft => (leftKeys, rightKeys)
case BuildRight => (rightKeys, leftKeys)
}
}

private lazy val (buildOutput, streamedOutput) = {
buildSide match {
case BuildLeft => (left.output, right.output)
case BuildRight => (right.output, left.output)
}
}

protected lazy val buildBoundKeys =
bindReferences(HashJoin.rewriteKeyExpr(buildKeys), buildOutput)

protected lazy val streamedBoundKeys =
bindReferences(HashJoin.rewriteKeyExpr(streamedKeys), streamedOutput)
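The split above keeps buildKeys and streamedKeys unbound so BroadcastHashJoinExec can reuse the streamed keys in outputPartitioning, while the bound variants serve execution. A minimal sketch of what binding does (the column names are made up):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// Hypothetical two-column output of a child plan.
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

// Binding rewrites the attribute into a positional BoundReference(0, ...),
// which is what projections and hashed-relation builds need at runtime;
// outputPartitioning, by contrast, must keep the unbound AttributeReference
// so it can be matched against required distributions downstream.
val bound: Expression = BindReferences.bindReference(a: Expression, Seq(a, b))
```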

protected def buildSideKeyGenerator(): Projection =
UnsafeProjection.create(buildKeys)
UnsafeProjection.create(buildBoundKeys)

protected def streamSideKeyGenerator(): UnsafeProjection =
UnsafeProjection.create(streamedKeys)
UnsafeProjection.create(streamedBoundKeys)

@transient private[this] lazy val boundCondition = if (condition.isDefined) {
Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
@@ -55,7 +55,8 @@ case class ShuffledHashJoinExec(
val buildTime = longMetric("buildTime")
val start = System.nanoTime()
val context = TaskContext.get()
val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
val relation = HashedRelation(
iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager())
buildTime += NANOSECONDS.toMillis(System.nanoTime() - start)
buildDataSize += relation.estimatedSize
// This relation is usually used until the end of task.
@@ -565,7 +565,7 @@ class AdaptiveQueryExecSuite
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 2)
val smj2 = findTopLevelSortMergeJoin(adaptivePlan)
assert(smj2.size == 2, origPlan.toString)
assert(smj2.size == 1, origPlan.toString)
}
}
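With the streamed side's partitioning now propagated through the broadcast hash join, the adaptive plan can satisfy one of the two joins without a separate sort-merge join, so the expected count drops from 2 to 1 (this reading is inferred from the patch's intent; the diff itself does not explain the change).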
