Skip to content

Commit 6f78684

Browse files
committed
[SPARK-27485] EnsureRequirements.reorder should handle duplicate expressions gracefully
1 parent f241fc7 commit 6f78684

3 files changed

Lines changed: 87 additions & 33 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.physical._
2525
import org.apache.spark.sql.catalyst.rules.Rule
2626
import org.apache.spark.sql.execution._
27-
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
28-
SortMergeJoinExec}
27+
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
2928
import org.apache.spark.sql.internal.SQLConf
3029

3130
/**
@@ -117,25 +116,41 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
117116
}
118117

119118
private def reorder(
120-
leftKeys: Seq[Expression],
121-
rightKeys: Seq[Expression],
119+
leftKeys: IndexedSeq[Expression],
120+
rightKeys: IndexedSeq[Expression],
122121
expectedOrderOfKeys: Seq[Expression],
123122
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
124-
val leftKeysBuffer = ArrayBuffer[Expression]()
125-
val rightKeysBuffer = ArrayBuffer[Expression]()
126-
val pickedIndexes = mutable.Set[Int]()
127-
val keysAndIndexes = currentOrderOfKeys.zipWithIndex
128-
129-
expectedOrderOfKeys.foreach(expression => {
130-
val index = keysAndIndexes.find { case (e, idx) =>
131-
// As we may have the same key used many times, we need to filter out its occurrence we
132-
// have already used.
133-
e.semanticEquals(expression) && !pickedIndexes.contains(idx)
134-
}.map(_._2).get
135-
pickedIndexes += index
136-
leftKeysBuffer.append(leftKeys(index))
137-
rightKeysBuffer.append(rightKeys(index))
138-
})
123+
if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
124+
return (leftKeys, rightKeys)
125+
}
126+
127+
// Build a lookup between an expression and the positions it holds in the current key sequence.
128+
val keyToIndexMap = mutable.Map.empty[Expression, mutable.BitSet]
129+
currentOrderOfKeys.zipWithIndex.foreach {
130+
case (key, index) =>
131+
keyToIndexMap.getOrElseUpdate(key.canonicalized, mutable.BitSet.empty).add(index)
132+
}
133+
134+
// Reorder the keys.
135+
val leftKeysBuffer = new ArrayBuffer[Expression](leftKeys.size)
136+
val rightKeysBuffer = new ArrayBuffer[Expression](rightKeys.size)
137+
val iterator = expectedOrderOfKeys.iterator
138+
while (iterator.hasNext) {
139+
// Lookup the current index of this key.
140+
keyToIndexMap.get(iterator.next().canonicalized) match {
141+
case Some(indices) if indices.nonEmpty =>
142+
// Take the first available index from the map.
143+
val index = indices.firstKey
144+
indices.remove(index)
145+
146+
// Add the keys for that index to the reordered keys.
147+
leftKeysBuffer += leftKeys(index)
148+
rightKeysBuffer += rightKeys(index)
149+
case _ =>
150+
// The expression cannot be found, or we have exhausted all indices for that expression.
151+
return (leftKeys, rightKeys)
152+
}
153+
}
139154
(leftKeysBuffer, rightKeysBuffer)
140155
}
141156

@@ -145,20 +160,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
145160
leftPartitioning: Partitioning,
146161
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
147162
if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
148-
leftPartitioning match {
149-
case HashPartitioning(leftExpressions, _)
150-
if leftExpressions.length == leftKeys.length &&
151-
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
152-
reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
153-
154-
case _ => rightPartitioning match {
155-
case HashPartitioning(rightExpressions, _)
156-
if rightExpressions.length == rightKeys.length &&
157-
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
158-
reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
159-
160-
case _ => (leftKeys, rightKeys)
161-
}
163+
(leftPartitioning, rightPartitioning) match {
164+
case (HashPartitioning(leftExpressions, _), _) =>
165+
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
166+
case (_, HashPartitioning(rightExpressions, _)) =>
167+
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
168+
case _ =>
169+
(leftKeys, rightKeys)
162170
}
163171
} else {
164172
(leftKeys, rightKeys)

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,26 @@ class JoinSuite extends QueryTest with SharedSQLContext {
897897
}
898898
}
899899

900+
test("SPARK-27485: EnsureRequirements should not fail join with duplicate keys") {
901+
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2",
902+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
903+
val tbl_a = spark.range(40)
904+
.select($"id" as "x", $"id" % 10 as "y")
905+
.repartition(2, $"x", $"y", $"x")
906+
.as("tbl_a")
907+
908+
val tbl_b = spark.range(20)
909+
.select($"id" as "x", $"id" % 2 as "y1", $"id" % 20 as "y2")
910+
.as("tbl_b")
911+
912+
val res = tbl_a
913+
.join(tbl_b,
914+
$"tbl_a.x" === $"tbl_b.x" && $"tbl_a.y" === $"tbl_b.y1" && $"tbl_a.y" === $"tbl_b.y2")
915+
.select($"tbl_a.x")
916+
checkAnswer(res, Row(0L) :: Row(1L) :: Nil)
917+
}
918+
}
919+
900920
test("SPARK-26352: join reordering should not change the order of columns") {
901921
withTable("tab1", "tab2", "tab3") {
902922
spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1")

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,32 @@ class PlannerSuite extends SharedSQLContext {
696696
}
697697
}
698698

699+
test("SPARK-27485: EnsureRequirements.reorder should handle duplicate expressions") {
700+
val plan1 = DummySparkPlan(
701+
outputPartitioning = HashPartitioning(exprA :: exprB :: exprA :: Nil, 5))
702+
val plan2 = DummySparkPlan()
703+
val smjExec = SortMergeJoinExec(
704+
leftKeys = exprA :: exprB :: exprB :: Nil,
705+
rightKeys = exprA :: exprC :: exprC :: Nil,
706+
joinType = Inner,
707+
condition = None,
708+
left = plan1,
709+
right = plan2)
710+
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
711+
outputPlan match {
712+
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
713+
SortExec(_, _,
714+
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _), _),
715+
SortExec(_, _,
716+
ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _), _)) =>
717+
assert(leftKeys === smjExec.leftKeys)
718+
assert(rightKeys === smjExec.rightKeys)
719+
assert(leftKeys === leftPartitioningExpressions)
720+
assert(rightKeys === rightPartitioningExpressions)
721+
case _ => fail(outputPlan.toString)
722+
}
723+
}
724+
699725
test("SPARK-24500: create union with stream of children") {
700726
val df = Union(Stream(
701727
Range(1, 1, 1, 1),

0 commit comments

Comments
 (0)