From 6949db5ba042d22191479f08e0b5e8127bf15aab Mon Sep 17 00:00:00 2001 From: maryannxue Date: Sun, 10 Feb 2019 13:24:52 -0600 Subject: [PATCH 1/4] fix --- .../optimizer/CostBasedJoinReorder.scala | 15 +-- .../org/apache/spark/sql/JoinHintSuite.scala | 92 ++++++++++--------- 2 files changed, 55 insertions(+), 52 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 6540e95b01e3..4123238e4ab5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -51,7 +51,7 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { } // After reordering is finished, convert OrderedJoin back to Join. result transform { - case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE) + case OrderedJoin(left, right, jt, cond, hint) => Join(left, right, jt, cond, hint) } } } @@ -77,13 +77,13 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { */ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { plan match { - case Join(left, right, _: InnerLike, Some(cond), _) => + case Join(left, right, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE => val (leftPlans, leftConditions) = extractInnerJoins(left) val (rightPlans, rightConditions) = extractInnerJoins(right) (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ leftConditions ++ rightConditions) - case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) - if projectList.forall(_.isInstanceOf[Attribute]) => + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), hint)) + if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE => extractInnerJoins(j) case _ => (Seq(plan), Set()) @@ -91,10 +91,10 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, jt: InnerLike, Some(cond), _) => + case j @ Join(left, right, jt: InnerLike, Some(cond), hint) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) - OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) + OrderedJoin(replacedLeft, replacedRight, jt, Some(cond), hint) case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) => p.copy(child = replaceWithOrderedJoin(j)) case _ => @@ -107,7 +107,8 @@ case class OrderedJoin( left: LogicalPlan, right: LogicalPlan, joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + condition: Option[Expression], + hint: JoinHint) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index 30a3d54fd833..67f0f1a6fd23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -102,58 +102,60 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { } test("hints prevent join reorder") { - withTempView("a", "b", "c") { - df1.createOrReplaceTempView("a") - df2.createOrReplaceTempView("b") - df3.createOrReplaceTempView("c") - verifyJoinHint( - sql("select /*+ broadcast(a, c)*/ * from a, b, c " + - "where a.a1 = b.b1 and b.b1 = c.c1"), - JoinHint( - None, - Some(HintInfo(broadcast = true))) :: - JoinHint( - Some(HintInfo(broadcast = true)), - None):: Nil - ) - verifyJoinHint( - sql("select /*+ broadcast(a, c)*/ * from a, c, b " + - "where a.a1 = b.b1 and b.b1 = c.c1"), - JoinHint.NONE :: + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + withTempView("a", "b", "c") { + df1.createOrReplaceTempView("a") + df2.createOrReplaceTempView("b") + df3.createOrReplaceTempView("c") + verifyJoinHint( + sql("select /*+ broadcast(a, c)*/ * from a, b, c " + + "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint( - Some(HintInfo(broadcast = true)), - Some(HintInfo(broadcast = true))):: Nil - ) - verifyJoinHint( - sql("select /*+ broadcast(b, c)*/ * from a, c, b " + - "where a.a1 = b.b1 and b.b1 = c.c1"), - JoinHint( - None, - Some(HintInfo(broadcast = true))) :: + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + verifyJoinHint( + sql("select /*+ broadcast(a, c)*/ * from a, c, b " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint.NONE :: + JoinHint( + Some(HintInfo(broadcast = true)), + Some(HintInfo(broadcast = true))) :: Nil + ) + verifyJoinHint( + sql("select /*+ broadcast(b, c)*/ * from a, c, b " + + "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint( None, - Some(HintInfo(broadcast = true))):: Nil - ) - - verifyJoinHint( - df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") - .join(df3, 'b1 === 'c1 && 'a1 < 10), - JoinHint( - Some(HintInfo(broadcast = true)), - None) :: - JoinHint.NONE:: Nil - ) + Some(HintInfo(broadcast = true))) :: + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: Nil + ) - verifyJoinHint( - df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") - .join(df3, 'b1 === 'c1 && 'a1 < 10) - .join(df, 'b1 === 'id), - JoinHint.NONE :: + verifyJoinHint( + df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") + .join(df3, 'b1 === 'c1 && 'a1 < 10), JoinHint( Some(HintInfo(broadcast = true)), None) :: - JoinHint.NONE:: Nil - ) + JoinHint.NONE :: Nil + ) + + verifyJoinHint( + df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") + .join(df3, 'b1 === 'c1 && 'a1 < 10) + .join(df, 'b1 === 'id), + JoinHint.NONE :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: + JoinHint.NONE :: Nil + ) + } } } From eb7045685f8826f4d6a6761ffbf1e0b2489582c8 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 10 Feb 2019 11:57:34 -0800 Subject: [PATCH 2/4] Update to one-liner --- .../catalyst/optimizer/CostBasedJoinReorder.scala | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 4123238e4ab5..4a7ef1ac6b80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -51,7 +51,7 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { } // After reordering is finished, convert OrderedJoin back to Join. result transform { - case OrderedJoin(left, right, jt, cond, hint) => Join(left, right, jt, cond, hint) + case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE) } } } @@ -77,13 +77,13 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { */ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { plan match { - case Join(left, right, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE => + case Join(left, right, _: InnerLike, Some(cond), _) => val (leftPlans, leftConditions) = extractInnerJoins(left) val (rightPlans, rightConditions) = extractInnerJoins(right) (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ leftConditions ++ rightConditions) - case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), hint)) - if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE => + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) + if projectList.forall(_.isInstanceOf[Attribute]) => extractInnerJoins(j) case _ => (Seq(plan), Set()) @@ -91,10 +91,10 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, jt: InnerLike, Some(cond), hint) => + case j @ Join(left, right, jt: InnerLike, Some(cond), JoinHint.NONE) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) - OrderedJoin(replacedLeft, replacedRight, jt, Some(cond), hint) + OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) => p.copy(child = replaceWithOrderedJoin(j)) case _ => @@ -107,8 +107,7 @@ case class OrderedJoin( left: LogicalPlan, right: LogicalPlan, joinType: JoinType, - condition: Option[Expression], - hint: JoinHint) extends BinaryNode { + condition: Option[Expression]) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output } From 534602bfd23f2ef79658fc536597cc38ef7fa7d9 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 11 Feb 2019 13:16:12 -0600 Subject: [PATCH 3/4] Revert "Update to one-liner" This reverts commit eb7045685f8826f4d6a6761ffbf1e0b2489582c8. --- .../catalyst/optimizer/CostBasedJoinReorder.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 4a7ef1ac6b80..4123238e4ab5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -51,7 +51,7 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { } // After reordering is finished, convert OrderedJoin back to Join. result transform { - case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE) + case OrderedJoin(left, right, jt, cond, hint) => Join(left, right, jt, cond, hint) } } } @@ -77,13 +77,13 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { */ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { plan match { - case Join(left, right, _: InnerLike, Some(cond), _) => + case Join(left, right, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE => val (leftPlans, leftConditions) = extractInnerJoins(left) val (rightPlans, rightConditions) = extractInnerJoins(right) (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ leftConditions ++ rightConditions) - case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) - if projectList.forall(_.isInstanceOf[Attribute]) => + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), hint)) + if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE => extractInnerJoins(j) case _ => (Seq(plan), Set()) @@ -91,10 +91,10 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, jt: InnerLike, Some(cond), JoinHint.NONE) => + case j @ Join(left, right, jt: InnerLike, Some(cond), hint) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) - OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) + OrderedJoin(replacedLeft, replacedRight, jt, Some(cond), hint) case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) => p.copy(child = replaceWithOrderedJoin(j)) case _ => @@ -107,7 +107,8 @@ case class OrderedJoin( left: LogicalPlan, right: LogicalPlan, joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + condition: Option[Expression], + hint: JoinHint) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output } From d7c156794ce40838339053b7ca327cb84de411c3 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 11 Feb 2019 14:41:22 -0600 Subject: [PATCH 4/4] fix --- .../optimizer/CostBasedJoinReorder.scala | 23 ++++++++--------- .../catalyst/optimizer/JoinReorderSuite.scala | 25 +++++++++++++++++++ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 4123238e4ab5..f92d8f5b8e53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -43,15 +43,15 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { val result = plan transformDown { // Start reordering with a joinable item, which is an InnerLike join with conditions. // Avoid reordering if a join hint is present. - case j @ Join(_, _, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE => + case j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE) => reorder(j, j.output) - case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), hint)) - if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE => + case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE)) + if projectList.forall(_.isInstanceOf[Attribute]) => reorder(p, p.output) } // After reordering is finished, convert OrderedJoin back to Join. result transform { - case OrderedJoin(left, right, jt, cond, hint) => Join(left, right, jt, cond, hint) + case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE) } } } @@ -77,13 +77,13 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { */ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { plan match { - case Join(left, right, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE => + case Join(left, right, _: InnerLike, Some(cond), JoinHint.NONE) => val (leftPlans, leftConditions) = extractInnerJoins(left) val (rightPlans, rightConditions) = extractInnerJoins(right) (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ leftConditions ++ rightConditions) - case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), hint)) - if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE => + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE)) + if projectList.forall(_.isInstanceOf[Attribute]) => extractInnerJoins(j) case _ => (Seq(plan), Set()) @@ -91,11 +91,11 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, jt: InnerLike, Some(cond), hint) => + case j @ Join(left, right, jt: InnerLike, Some(cond), JoinHint.NONE) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) - OrderedJoin(replacedLeft, replacedRight, jt, Some(cond), hint) - case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) => + OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) + case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE)) => p.copy(child = replaceWithOrderedJoin(j)) case _ => plan @@ -107,8 +107,7 @@ case class OrderedJoin( left: LogicalPlan, right: LogicalPlan, joinType: JoinType, - condition: Option[Expression], - hint: JoinHint) extends BinaryNode { + condition: Option[Expression]) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index f1da0a8e865b..18516ee7872a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -312,6 +312,14 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) assertEqualPlans(originalPlan2, originalPlan2) + + val originalPlan3 = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t4).hint("broadcast") + .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100"))) + + assertEqualPlans(originalPlan3, originalPlan3) } test("reorder below and above the hint node") { @@ -342,6 +350,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(t4.hint("broadcast")) assertEqualPlans(originalPlan2, bestPlan2) + + val originalPlan3 = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .hint("broadcast") + .join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100"))) + + val bestPlan3 = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(outputsOf(t1, t2, t3): _*) + .hint("broadcast") + .join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100"))) + + assertEqualPlans(originalPlan3, bestPlan3) } private def assertEqualPlans(