@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.joins

import scala.collection.mutable

import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -26,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{BooleanType, LongType}
@@ -51,7 +53,7 @@ case class BroadcastHashJoinExec(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildKeys)
val mode = HashedRelationBroadcastMode(buildBoundKeys)
buildSide match {
case BuildLeft =>
BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -60,6 +62,66 @@
}
}

override def outputPartitioning: Partitioning = {
Member:
val or lazy val?

Contributor Author:
Changed it to lazy val

joinType match {
case _: InnerLike =>
streamedPlan.outputPartitioning match {
case h: HashPartitioning => PartitioningCollection(expandOutputPartitioning(h))
case c: PartitioningCollection =>
def expand(partitioning: PartitioningCollection): Partitioning = {
Member:
Could you pull out this inner method and define it outside as private? Also, we need to assign a reasonable method name.

Contributor Author:
I overloaded expandOutputPartitioning so that it takes either a HashPartitioning or a PartitioningCollection and returns a PartitioningCollection.
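A minimal sketch of what that later revision might look like, assuming only the description above (the overload is not part of the diff shown here; it reuses the Seq-returning expandOutputPartitioning(HashPartitioning) defined further down):

```scala
// Hedged sketch, not the merged code: a PartitioningCollection overload that
// flattens nested collections and expands each HashPartitioning member.
private def expandOutputPartitioning(
    partitioning: PartitioningCollection): PartitioningCollection = {
  PartitioningCollection(partitioning.partitionings.flatMap {
    case h: HashPartitioning => expandOutputPartitioning(h) // Seq[HashPartitioning]
    case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
    case other => Seq(other)
  })
}
```

With such an overload, the inner expand helper in outputPartitioning below would no longer be needed.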

PartitioningCollection(partitioning.partitionings.flatMap {
case h: HashPartitioning => expandOutputPartitioning(h)
case c: PartitioningCollection => Seq(expand(c))
case other => Seq(other)
})
}
expand(c)
case other => other
}
case _ => streamedPlan.outputPartitioning
}
}

// A one-to-many mapping from a streamed key to build keys.
private lazy val streamedKeyToBuildKeyMapping = {
val mapping = mutable.Map.empty[Expression, Seq[Expression]]
streamedKeys.zip(buildKeys).foreach {
case (streamedKey, buildKey) =>
val key = streamedKey.canonicalized
mapping.get(key) match {
case Some(v) => mapping.put(key, v :+ buildKey)
case None => mapping.put(key, Seq(buildKey))
}
}
mapping.toMap
}

// Expands the given partitioning by substituting streamed keys with build keys.
// For example, if the expressions for the given partitioning are Seq("a", "b", "c")
// where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
// the expanded partitioning will have the following expressions:
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
private def expandOutputPartitioning(partitioning: HashPartitioning): Seq[HashPartitioning] = {
def generateExprCombinations(
current: Seq[Expression],
accumulated: Seq[Expression],
all: mutable.ListBuffer[Seq[Expression]]): Unit = {
Contributor Author:
Changed. Thanks for the suggestion.

if (current.isEmpty) {
all += accumulated
} else {
generateExprCombinations(current.tail, accumulated :+ current.head, all)
val mapped = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
if (mapped.isDefined) {
mapped.get.foreach(m => generateExprCombinations(current.tail, accumulated :+ m, all))
}
}
}

val all = mutable.ListBuffer[Seq[Expression]]()
generateExprCombinations(partitioning.expressions, Nil, all)
all.map(HashPartitioning(_, partitioning.numPartitions))
}
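For illustration only (not part of the file): a standalone sketch of the combination expansion described in the comment above, using plain strings in place of expressions; the object and method names are made up for the example:

```scala
object ExpandSketch {
  // One-to-many mapping from streamed keys to build keys, mirroring
  // streamedKeyToBuildKeyMapping: "b" -> "x", "c" -> "y".
  val mapping: Map[String, Seq[String]] = Map("b" -> Seq("x"), "c" -> Seq("y"))

  // Each key is either kept or substituted with any build key it maps to.
  def combinations(current: Seq[String], accumulated: Seq[String]): Seq[Seq[String]] =
    if (current.isEmpty) {
      Seq(accumulated)
    } else {
      val kept = combinations(current.tail, accumulated :+ current.head)
      val substituted = mapping.getOrElse(current.head, Nil)
        .flatMap(m => combinations(current.tail, accumulated :+ m))
      kept ++ substituted
    }

  def main(args: Array[String]): Unit = {
    // Prints the four expansions listed in the comment:
    // List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y)
    combinations(Seq("a", "b", "c"), Nil).foreach(println)
  }
}
```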

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")

@@ -135,13 +197,13 @@ case class BroadcastHashJoinExec(
ctx: CodegenContext,
input: Seq[ExprCode]): (ExprCode, String) = {
ctx.currentVars = input
if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
// generate the join key as Long
val ev = streamedKeys.head.genCode(ctx)
val ev = streamedBoundKeys.head.genCode(ctx)
(ev, ev.isNull)
} else {
// generate the join key as UnsafeRow
val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
(ev, s"${ev.value}.anyNull()")
}
}
@@ -62,21 +62,30 @@ trait HashJoin extends BaseJoinExec {
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
buildSide match {
case BuildLeft => (lkeys, rkeys)
case BuildRight => (rkeys, lkeys)
case BuildLeft => (leftKeys, rightKeys)
case BuildRight => (rightKeys, leftKeys)
}
}

private lazy val (buildOutput, streamedOutput) = {
buildSide match {
case BuildLeft => (left.output, right.output)
case BuildRight => (right.output, left.output)
}
}

protected lazy val buildBoundKeys =
bindReferences(HashJoin.rewriteKeyExpr(buildKeys), buildOutput)

protected lazy val streamedBoundKeys =
bindReferences(HashJoin.rewriteKeyExpr(streamedKeys), streamedOutput)

protected def buildSideKeyGenerator(): Projection =
UnsafeProjection.create(buildKeys)
UnsafeProjection.create(buildBoundKeys)

protected def streamSideKeyGenerator(): UnsafeProjection =
UnsafeProjection.create(streamedKeys)
UnsafeProjection.create(streamedBoundKeys)

@transient private[this] lazy val boundCondition = if (condition.isDefined) {
Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
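One way to read this refactor (an interpretation, not stated in the diff): buildKeys and streamedKeys now stay unbound so that outputPartitioning in BroadcastHashJoinExec can match them against the streamed child's output attributes, while buildBoundKeys and streamedBoundKeys carry the positional references needed at execution and codegen time. A small, hedged illustration of what binding does, with made-up attribute names:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.LongType

object BindSketch {
  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", LongType)()
    val b = AttributeReference("b", LongType)()
    // Binding replaces the named attribute with a positional reference into
    // the given output; here "a" sits at ordinal 1.
    val bound = BindReferences.bindReference(a: Expression, Seq(b, a))
    println(bound) // e.g. input[1, bigint, true]
  }
}
```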
@@ -55,7 +55,8 @@ case class ShuffledHashJoinExec(
val buildTime = longMetric("buildTime")
val start = System.nanoTime()
val context = TaskContext.get()
val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
val relation = HashedRelation(
iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager())
buildTime += NANOSECONDS.toMillis(System.nanoTime() - start)
buildDataSize += relation.estimatedSize
// This relation is usually used until the end of task.
@@ -554,7 +554,7 @@ class AdaptiveQueryExecSuite
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 2)
val smj2 = findTopLevelSortMergeJoin(adaptivePlan)
assert(smj2.size == 2, origPlan.toString)
assert(smj2.size == 1, origPlan.toString)
}
}

@@ -21,13 +21,15 @@ import scala.reflect.ClassTag

import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BitwiseAnd, BitwiseOr, Cast, Expression, Literal, ShiftLeft}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.BROADCAST
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.execution.{DummySparkPlan, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
@@ -415,6 +417,192 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
assert(e.getMessage.contains(s"Could not execute broadcast in $timeout secs."))
}
}

test("broadcast join where streamed side's output partitioning is HashPartitioning") {
withTable("t1", "t3") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2")
val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")
df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1")
df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3")
val t1 = spark.table("t1")
val t3 = spark.table("t3")

// join1 is a broadcast join where df2 is broadcasted. Note that output partitioning on the
// streamed side (t1) is HashPartitioning (bucketed files).
val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2"))
val plan1 = join1.queryExecution.executedPlan
assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b }
assert(broadcastJoins.size == 1)
broadcastJoins(0).outputPartitioning match {
case p: PartitioningCollection =>
assert(p.partitionings.size == 4)
// Verify all the combinations of output partitioning.
Seq(Seq(t1("i1"), t1("j1")),
Seq(t1("i1"), df2("j2")),
Seq(df2("i2"), t1("j1")),
Seq(df2("i2"), df2("j2"))).foreach { expected =>
val expectedExpressions = expected.map(_.expr)
assert(p.partitionings.exists {
case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions)
})
}
case _ => fail()
}

// Join on the column from the broadcasted side (i2, j2) and make sure output partitioning
// is maintained by checking no shuffle exchange is introduced.
val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
val plan2 = join2.queryExecution.executedPlan
assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1)
assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty)

// Validate the data with broadcast join off.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq)
}
}
}
}

test("broadcast join where streamed side's output partitioning is PartitioningCollection") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
val t1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
val t2 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i2", "j2")
val t3 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i3", "j3")
val t4 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i4", "j4")

// join1 is a sort merge join (shuffle on the both sides).
val join1 = t1.join(t2, t1("i1") === t2("i2"))
val plan1 = join1.queryExecution.executedPlan
assert(collect(plan1) { case s: SortMergeJoinExec => s }.size == 1)
assert(collect(plan1) { case e: ShuffleExchangeExec => e }.size == 2)

// join2 is a broadcast join where t3 is broadcasted. Note that output partitioning on the
// streamed side (join1) is PartitioningCollection (sort merge join)
val join2 = join1.join(t3, join1("i1") === t3("i3"))
val plan2 = join2.queryExecution.executedPlan
assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
assert(collect(plan2) { case e: ShuffleExchangeExec => e }.size == 2)
val broadcastJoins = collect(plan2) { case b: BroadcastHashJoinExec => b }
assert(broadcastJoins.size == 1)
broadcastJoins(0).outputPartitioning match {
case p: PartitioningCollection =>
assert(p.partitionings.size == 3)
// Verify all the combinations of output partitioning.
Seq(Seq(t1("i1")), Seq(t2("i2")), Seq(t3("i3"))).foreach { expected =>
val expectedExpressions = expected.map(_.expr)
assert(p.partitionings.exists {
case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions)
})
}
case _ => fail()
Member:
nit: For better test error messages,

assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection])
val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection]
...

Or, could you add error messages in fail()?

Contributor Author:
Changed as suggested.

}

// Join on the column from the broadcasted side (i3) and make sure output partitioning
// is maintained by checking no shuffle exchange is introduced. Note that one extra
// ShuffleExchangeExec is from t4, not from join2.
val join3 = join2.join(t4, join2("i3") === t4("i4"))
val plan3 = join3.queryExecution.executedPlan
assert(collect(plan3) { case s: SortMergeJoinExec => s }.size == 2)
assert(collect(plan3) { case b: BroadcastHashJoinExec => b }.size == 1)
assert(collect(plan3) { case e: ShuffleExchangeExec => e }.size == 3)

// Validate the data with broadcast join off.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = join2.join(t4, join2("i3") === t4("i4"))
QueryTest.sameRows(join3.collect().toSeq, df.collect().toSeq)
}
}
}

test("BroadcastHashJoinExec output partitioning scenarios for inner join") {
val l1 = AttributeReference("l1", LongType)()
val l2 = AttributeReference("l2", LongType)()
val l3 = AttributeReference("l3", LongType)()
val r1 = AttributeReference("r1", LongType)()
val r2 = AttributeReference("r2", LongType)()
val r3 = AttributeReference("r3", LongType)()

// Streamed side has a HashPartitioning.
var bhj = BroadcastHashJoinExec(
leftKeys = Seq(l2, l3),
rightKeys = Seq(r1, r2),
Inner,
BuildRight,
None,
left = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2, l3), 1)),
right = DummySparkPlan())
var expected = PartitioningCollection(Seq(
HashPartitioning(Seq(l1, l2, l3), 1),
HashPartitioning(Seq(l1, l2, r2), 1),
HashPartitioning(Seq(l1, r1, l3), 1),
HashPartitioning(Seq(l1, r1, r2), 1)))
assert(bhj.outputPartitioning === expected)

// Streamed side has a PartitioningCollection.
bhj = BroadcastHashJoinExec(
leftKeys = Seq(l1, l2, l3),
rightKeys = Seq(r1, r2, r3),
Inner,
BuildRight,
None,
left = DummySparkPlan(outputPartitioning = PartitioningCollection(Seq(
HashPartitioning(Seq(l1, l2), 1), HashPartitioning(Seq(l3), 1)))),
right = DummySparkPlan())
expected = PartitioningCollection(Seq(
HashPartitioning(Seq(l1, l2), 1),
HashPartitioning(Seq(l1, r2), 1),
HashPartitioning(Seq(r1, l2), 1),
HashPartitioning(Seq(r1, r2), 1),
HashPartitioning(Seq(l3), 1),
HashPartitioning(Seq(r3), 1)))
assert(bhj.outputPartitioning === expected)

// Streamed side has a nested PartitioningCollection.
bhj = BroadcastHashJoinExec(
leftKeys = Seq(l1, l2, l3),
rightKeys = Seq(r1, r2, r3),
Inner,
BuildRight,
None,
left = DummySparkPlan(outputPartitioning = PartitioningCollection(Seq(
PartitioningCollection(Seq(HashPartitioning(Seq(l1), 1), HashPartitioning(Seq(l2), 1))),
HashPartitioning(Seq(l3), 1)))),
right = DummySparkPlan())
expected = PartitioningCollection(Seq(
PartitioningCollection(Seq(
HashPartitioning(Seq(l1), 1),
HashPartitioning(Seq(r1), 1),
HashPartitioning(Seq(l2), 1),
HashPartitioning(Seq(r2), 1))),
HashPartitioning(Seq(l3), 1),
HashPartitioning(Seq(r3), 1)))
assert(bhj.outputPartitioning === expected)

// One-to-many mapping case ("l1" = "r1" AND "l1" = "r2")
bhj = BroadcastHashJoinExec(
leftKeys = Seq(l1, l1),
rightKeys = Seq(r1, r2),
Inner,
BuildRight,
None,
left = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2), 1)),
right = DummySparkPlan())
expected = PartitioningCollection(Seq(
HashPartitioning(Seq(l1, l2), 1),
HashPartitioning(Seq(r1, l2), 1),
HashPartitioning(Seq(r2, l2), 1)))
assert(bhj.outputPartitioning === expected)
}

private def expressionsEqual(l: Seq[Expression], r: Seq[Expression]): Boolean = {
l.length == r.length && l.zip(r).forall { case (e1, e2) => e1.semanticEquals(e2) }
}
}

class BroadcastJoinSuite extends BroadcastJoinSuiteBase with DisableAdaptiveExecutionSuite