Skip to content

Commit d7c1567

Browse files
committed
fix
1 parent 534602b commit d7c1567

2 files changed

Lines changed: 36 additions & 12 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
4343
val result = plan transformDown {
4444
// Start reordering with a joinable item, which is an InnerLike join with conditions.
4545
// Avoid reordering if a join hint is present.
46-
case j @ Join(_, _, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE =>
46+
case j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE) =>
4747
reorder(j, j.output)
48-
case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), hint))
49-
if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE =>
48+
case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE))
49+
if projectList.forall(_.isInstanceOf[Attribute]) =>
5050
reorder(p, p.output)
5151
}
5252
// After reordering is finished, convert OrderedJoin back to Join.
5353
result transform {
54-
case OrderedJoin(left, right, jt, cond, hint) => Join(left, right, jt, cond, hint)
54+
case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE)
5555
}
5656
}
5757
}
@@ -77,25 +77,25 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
7777
*/
7878
private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
7979
plan match {
80-
case Join(left, right, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE =>
80+
case Join(left, right, _: InnerLike, Some(cond), JoinHint.NONE) =>
8181
val (leftPlans, leftConditions) = extractInnerJoins(left)
8282
val (rightPlans, rightConditions) = extractInnerJoins(right)
8383
(leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
8484
leftConditions ++ rightConditions)
85-
case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), hint))
86-
if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE =>
85+
case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE))
86+
if projectList.forall(_.isInstanceOf[Attribute]) =>
8787
extractInnerJoins(j)
8888
case _ =>
8989
(Seq(plan), Set())
9090
}
9191
}
9292

9393
private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match {
94-
case j @ Join(left, right, jt: InnerLike, Some(cond), hint) =>
94+
case j @ Join(left, right, jt: InnerLike, Some(cond), JoinHint.NONE) =>
9595
val replacedLeft = replaceWithOrderedJoin(left)
9696
val replacedRight = replaceWithOrderedJoin(right)
97-
OrderedJoin(replacedLeft, replacedRight, jt, Some(cond), hint)
98-
case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) =>
97+
OrderedJoin(replacedLeft, replacedRight, jt, Some(cond))
98+
case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE)) =>
9999
p.copy(child = replaceWithOrderedJoin(j))
100100
case _ =>
101101
plan
@@ -107,8 +107,7 @@ case class OrderedJoin(
107107
left: LogicalPlan,
108108
right: LogicalPlan,
109109
joinType: JoinType,
110-
condition: Option[Expression],
111-
hint: JoinHint) extends BinaryNode {
110+
condition: Option[Expression]) extends BinaryNode {
112111
override def output: Seq[Attribute] = left.output ++ right.output
113112
}
114113

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,14 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
312312
.join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
313313

314314
assertEqualPlans(originalPlan2, originalPlan2)
315+
316+
val originalPlan3 =
317+
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
318+
.join(t4).hint("broadcast")
319+
.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
320+
.join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100")))
321+
322+
assertEqualPlans(originalPlan3, originalPlan3)
315323
}
316324

317325
test("reorder below and above the hint node") {
@@ -342,6 +350,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
342350
.join(t4.hint("broadcast"))
343351

344352
assertEqualPlans(originalPlan2, bestPlan2)
353+
354+
val originalPlan3 =
355+
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
356+
.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
357+
.hint("broadcast")
358+
.join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
359+
.join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100")))
360+
361+
val bestPlan3 =
362+
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
363+
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
364+
.select(outputsOf(t1, t2, t3): _*)
365+
.hint("broadcast")
366+
.join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
367+
.join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100")))
368+
369+
assertEqualPlans(originalPlan3, bestPlan3)
345370
}
346371

347372
private def assertEqualPlans(

0 commit comments

Comments
 (0)