Skip to content

Commit f51e31d

Browse files
committed
address review comments
1 parent 470d682 commit f51e31d

11 files changed

Lines changed: 208 additions & 30 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2253,7 +2253,7 @@ class Analyzer(
22532253
if left.resolved && right.resolved && j.duplicateResolved =>
22542254
commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint)
22552255
case j @ Join(left, right, NaturalJoin(joinType), condition, hint)
2256-
if j.resolvedExceptNatural =>
2256+
if j.resolvedExceptNatural =>
22572257
// find common column names from both sides
22582258
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
22592259
commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,40 @@ import org.apache.spark.sql.internal.SQLConf
3131
* Cost-based join reorder.
3232
* We may have several join reorder algorithms in the future. This class is the entry of these
3333
* algorithms, and chooses which one to use.
34+
*
35+
* Note that join strategy hints, e.g. the broadcast hint, do not interfere with the reordering.
36+
* Such hints will be applied on the equivalent counterparts (i.e., join between the same relations
37+
* regardless of the join order) of the original nodes after reordering.
38+
* For example, the plan before reordering is like:
39+
*
40+
* Join
41+
* / \
42+
* Hint1 t4
43+
* /
44+
* Join
45+
* / \
46+
* Join t3
47+
* / \
48+
* Hint2 t2
49+
* /
50+
* t1
51+
*
52+
* The original join order as illustrated above is "((t1 JOIN t2) JOIN t3) JOIN t4", and after
53+
* reordering, the new join order is "((t1 JOIN t3) JOIN t2) JOIN t4", so the new plan will be like:
54+
*
55+
* Join
56+
* / \
57+
* Hint1 t4
58+
* /
59+
* Join
60+
* / \
61+
* Join t2
62+
* / \
63+
* t1 t3
64+
*
65+
* "Hint1" is applied on "(t1 JOIN t3) JOIN t2" as it is equivalent to the original hinted node,
66+
* "(t1 JOIN t2) JOIN t3"; while "Hint2" has disappeared from the new plan since there is no
67+
* equivalent node to "t1 JOIN t2".
3468
*/
3569
object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
3670

@@ -40,9 +74,8 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
4074
if (!conf.cboEnabled || !conf.joinReorderEnabled) {
4175
plan
4276
} else {
43-
// Use a map to track the hints on the join items. If a join relation turns out unchanged
44-
// at the end of the join reorder, we can apply the original hint back to it if any.
45-
val hintMap = new mutable.HashMap[LogicalPlan, HintInfo]
77+
// Use a map to track the hints on the join items.
78+
val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
4679
val result = plan transformDown {
4780
// Start reordering with a joinable item, which is an InnerLike join with conditions.
4881
case j @ Join(_, _, _: InnerLike, Some(cond), _) =>
@@ -52,20 +85,18 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
5285
reorder(p, p.output, hintMap)
5386
}
5487
// After reordering is finished, convert OrderedJoin back to Join.
55-
// Note that this needs to be done bottom-up to make sure the hints can be mapped to any
56-
// unchanged relations.
57-
result transformUp {
88+
result transform {
5889
case OrderedJoin(left, right, jt, cond) =>
5990
Join(left, right, jt, cond,
60-
JoinHint(hintMap.get(left), hintMap.get(right)))
91+
JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
6192
}
6293
}
6394
}
6495

6596
private def reorder(
6697
plan: LogicalPlan,
6798
output: Seq[Attribute],
68-
hintMap: mutable.HashMap[LogicalPlan, HintInfo]): LogicalPlan = {
99+
hintMap: mutable.HashMap[AttributeSet, HintInfo]): LogicalPlan = {
69100
val (items, conditions) = extractInnerJoins(plan, hintMap)
70101
val result =
71102
// Do reordering if the number of items is appropriate and join conditions exist.
@@ -86,11 +117,11 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
86117
*/
87118
private def extractInnerJoins(
88119
plan: LogicalPlan,
89-
hintMap: mutable.HashMap[LogicalPlan, HintInfo]): (Seq[LogicalPlan], Set[Expression]) = {
120+
hintMap: mutable.HashMap[AttributeSet, HintInfo]): (Seq[LogicalPlan], Set[Expression]) = {
90121
plan match {
91122
case Join(left, right, _: InnerLike, Some(cond), hint) =>
92-
hint.leftHint.map(hintMap.put(left, _))
93-
hint.rightHint.map(hintMap.put(right, _))
123+
hint.leftHint.foreach(hintMap.put(left.outputSet, _))
124+
hint.rightHint.foreach(hintMap.put(right.outputSet, _))
94125
val (leftPlans, leftConditions) = extractInnerJoins(left, hintMap)
95126
val (rightPlans, rightConditions) = extractInnerJoins(right, hintMap)
96127
(leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
5151
input: Seq[(LogicalPlan, InnerLike)],
5252
conditions: Seq[Expression],
5353
leftPlans: Seq[LogicalPlan],
54-
hintMap: Map[Seq[LogicalPlan], HintInfo]): LogicalPlan = {
54+
hintMap: Map[AttributeSet, HintInfo]): LogicalPlan = {
5555
assert(input.size >= 2)
5656
if (input.size == 2) {
5757
val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
@@ -61,7 +61,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
6161
case (_, _) => Cross
6262
}
6363
val join = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
64-
JoinHint(hintMap.get(leftPlans), hintMap.get(Seq(right))))
64+
JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
6565
if (others.nonEmpty) {
6666
Filter(others.reduceLeft(And), join)
6767
} else {
@@ -85,7 +85,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
8585
val (joinConditions, others) = conditions.partition(
8686
e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
8787
val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And),
88-
JoinHint(hintMap.get(leftPlans), hintMap.get(Seq(right))))
88+
JoinHint(hintMap.get(left.outputSet), hintMap.get(right.outputSet)))
8989

9090
// should not have reference to same logical plan
9191
createOrderedJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,13 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
169169
*/
170170
def flattenJoin(
171171
plan: LogicalPlan,
172-
hintMap: mutable.HashMap[Seq[LogicalPlan], HintInfo],
172+
hintMap: mutable.HashMap[AttributeSet, HintInfo],
173173
parentJoinType: InnerLike = Inner)
174174
: (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match {
175175
case Join(left, right, joinType: InnerLike, cond, hint) =>
176176
val (plans, conditions) = flattenJoin(left, hintMap, joinType)
177-
hint.leftHint.map(hintMap.put(plans.map(_._1), _))
178-
hint.rightHint.map(hintMap.put(Seq(right), _))
177+
hint.leftHint.map(hintMap.put(left.outputSet, _))
178+
hint.rightHint.map(hintMap.put(right.outputSet, _))
179179
(plans ++ Seq((right, joinType)), conditions ++
180180
cond.toSeq.flatMap(splitConjunctivePredicates))
181181
case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, _)) =>
@@ -186,14 +186,14 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
186186
}
187187

188188
def unapply(plan: LogicalPlan)
189-
: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression], Map[Seq[LogicalPlan], HintInfo])]
189+
: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression], Map[AttributeSet, HintInfo])]
190190
= plan match {
191191
case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, _)) =>
192-
val hintMap = new mutable.HashMap[Seq[LogicalPlan], HintInfo]
192+
val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
193193
val flattened = flattenJoin(f, hintMap)
194194
Some((flattened._1, flattened._2, hintMap.toMap))
195195
case j @ Join(_, _, joinType, _, _) =>
196-
val hintMap = new mutable.HashMap[Seq[LogicalPlan], HintInfo]
196+
val hintMap = new mutable.HashMap[AttributeSet, HintInfo]
197197
val flattened = flattenJoin(j, hintMap)
198198
Some((flattened._1, flattened._2, hintMap.toMap))
199199
case _ => None

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,13 @@ case class Join(
355355
// Ignore hint for canonicalization
356356
protected override def doCanonicalize(): LogicalPlan =
357357
super.doCanonicalize().asInstanceOf[Join].copy(hint = JoinHint.NONE)
358+
359+
// Do not include an empty join hint in string description
360+
protected override def stringArgs: Iterator[Any] = super.stringArgs.filter { e =>
361+
(!e.isInstanceOf[JoinHint]
362+
|| e.asInstanceOf[JoinHint].leftHint.isDefined
363+
|| e.asInstanceOf[JoinHint].rightHint.isDefined)
364+
}
358365
}
359366

360367
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,7 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
4949
* Hint that is associated with a [[Join]] node, with [[HintInfo]] on its left child and on its
5050
* right child respectively.
5151
*/
52-
case class JoinHint(
53-
leftHint: Option[HintInfo],
54-
rightHint: Option[HintInfo]) {
52+
case class JoinHint(leftHint: Option[HintInfo], rightHint: Option[HintInfo]) {
5553

5654
override def toString: String = {
5755
Seq(
@@ -65,6 +63,12 @@ object JoinHint {
6563
val NONE = JoinHint(None, None)
6664
}
6765

66+
/**
67+
* The hint attributes to be applied on a specific node.
68+
*
69+
* @param broadcast If set to true, it indicates that the broadcast hash join is the preferred join
70+
* strategy and the node with this hint is preferred to be the build side.
71+
*/
6872
case class HintInfo(broadcast: Boolean = false) {
6973

7074
override def toString: String = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
454454
case Some(serde) => table.identifier :: serde :: Nil
455455
case _ => table.identifier :: Nil
456456
}
457-
case hint: JoinHint if hint.leftHint.isEmpty && hint.rightHint.isEmpty => Nil
458457
case other => other :: Nil
459458
}.mkString(", ")
460459

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
2121
import org.apache.spark.sql.catalyst.dsl.plans._
2222
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
2323
import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest}
24-
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan}
24+
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, EliminateResolvedHint, LocalRelation, LogicalPlan}
2525
import org.apache.spark.sql.catalyst.rules.RuleExecutor
2626
import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
2727
import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED}
@@ -31,6 +31,8 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
3131

3232
object Optimize extends RuleExecutor[LogicalPlan] {
3333
val batches =
34+
Batch("Resolve Hints", Once,
35+
EliminateResolvedHint) ::
3436
Batch("Operator Optimizations", FixedPoint(100),
3537
CombineFilters,
3638
PushDownPredicate,
@@ -42,6 +44,12 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
4244
CostBasedJoinReorder) :: Nil
4345
}
4446

47+
object ResolveHints extends RuleExecutor[LogicalPlan] {
48+
val batches =
49+
Batch("Resolve Hints", Once,
50+
EliminateResolvedHint) :: Nil
51+
}
52+
4553
var originalConfCBOEnabled = false
4654
var originalConfJoinReorderEnabled = false
4755

@@ -284,12 +292,85 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
284292
assertEqualPlans(originalPlan, bestPlan)
285293
}
286294

295+
test("hints preservation") {
296+
// Apply hints if we find an equivalent node in the new plan, otherwise discard them.
297+
val originalPlan =
298+
t1.join(t2.hint("broadcast")).hint("broadcast").join(t4.join(t3).hint("broadcast"))
299+
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
300+
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
301+
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
302+
303+
val bestPlan =
304+
t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
305+
.hint("broadcast")
306+
.join(
307+
t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
308+
.hint("broadcast"),
309+
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
310+
311+
assertEqualPlans(originalPlan, bestPlan)
312+
313+
val originalPlan2 =
314+
t1.join(t2).hint("broadcast").join(t3).hint("broadcast").join(t4.hint("broadcast"))
315+
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
316+
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
317+
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
318+
319+
val bestPlan2 =
320+
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
321+
.hint("broadcast")
322+
.join(
323+
t4.hint("broadcast")
324+
.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
325+
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
326+
.select(outputsOf(t1, t2, t3, t4): _*)
327+
328+
assertEqualPlans(originalPlan2, bestPlan2)
329+
330+
val originalPlan3 =
331+
t1.join(t4).hint("broadcast")
332+
.join(t2.hint("broadcast")).hint("broadcast")
333+
.join(t3.hint("broadcast"))
334+
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
335+
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
336+
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
337+
338+
val bestPlan3 =
339+
t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
340+
.join(
341+
t4.join(t3.hint("broadcast"),
342+
Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
343+
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
344+
.select(outputsOf(t1, t4, t2, t3): _*)
345+
346+
assertEqualPlans(originalPlan3, bestPlan3)
347+
348+
val originalPlan4 =
349+
t2.hint("broadcast")
350+
.join(t4).hint("broadcast")
351+
.join(t3.hint("broadcast")).hint("broadcast")
352+
.join(t1)
353+
.where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
354+
(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
355+
(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
356+
357+
val bestPlan4 =
358+
t1.join(t2.hint("broadcast"), Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
359+
.join(
360+
t4.join(t3.hint("broadcast"),
361+
Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
362+
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
363+
.select(outputsOf(t2, t4, t3, t1): _*)
364+
365+
assertEqualPlans(originalPlan4, bestPlan4)
366+
}
367+
287368
private def assertEqualPlans(
288369
originalPlan: LogicalPlan,
289370
groundTruthBestPlan: LogicalPlan): Unit = {
290371
val analyzed = originalPlan.analyze
291372
val optimized = Optimize.execute(analyzed)
292-
val expected = groundTruthBestPlan.analyze
373+
val expected = ResolveHints.execute(groundTruthBestPlan.analyze)
293374

294375
assert(analyzed.sameOutput(expected)) // if this fails, the expected plan itself is incorrect
295376
assert(analyzed.sameOutput(optimized))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,10 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite =>
165165
private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
166166
(plan1, plan2) match {
167167
case (j1: Join, j2: Join) =>
168-
(sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
169-
(sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
168+
(sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)
169+
&& j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) ||
170+
(sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)
171+
&& j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint)
170172
case (p1: Project, p2: Project) =>
171173
p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child)
172174
case _ =>

sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ import org.apache.spark.executor.DataReadMethod.DataReadMethod
2727
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
2828
import org.apache.spark.sql.catalyst.TableIdentifier
2929
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
30+
import org.apache.spark.sql.catalyst.plans.logical.Join
3031
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
3132
import org.apache.spark.sql.execution.columnar._
3233
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
34+
import org.apache.spark.sql.functions._
3335
import org.apache.spark.sql.internal.SQLConf
3436
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
3537
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
@@ -925,4 +927,23 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
925927
}
926928
}
927929
}
930+
931+
test("Cache should respect the broadcast hint") {
932+
val df = broadcast(spark.range(1000)).cache()
933+
val df2 = spark.range(1000).cache()
934+
df.count()
935+
df2.count()
936+
937+
// Test the broadcast hint.
938+
val joinPlan = df.join(df2, "id").queryExecution.optimizedPlan
939+
val hint = joinPlan.collect {
940+
case Join(_, _, _, _, hint) => hint
941+
}
942+
assert(hint.size == 1)
943+
assert(hint(0).leftHint.get.broadcast)
944+
assert(hint(0).rightHint.isEmpty)
945+
946+
// Clean-up
947+
df.unpersist()
948+
}
928949
}

0 commit comments

Comments (0)