Commit c16a66a
[SPARK-36194][SQL] Add a logical plan visitor to propagate the distinct attributes
### What changes were proposed in this pull request?

1. This PR adds a new logical plan visitor named `DistinctKeyVisitor` to find all the distinct attributes in the current logical plan. For example:

   ```scala
   spark.sql("CREATE TABLE t(a int, b int, c int) using parquet")
   spark.sql("SELECT a, b, a % 10, max(c), sum(b) FROM t GROUP BY a, b").queryExecution.analyzed.distinctKeys
   ```

   The output is: `{a#1, b#2}`.

2. Enhances `RemoveRedundantAggregates` to remove an aggregation if it is group-only and its child already guarantees distinct output. For example:

   ```sql
   set spark.sql.autoBroadcastJoinThreshold=-1; -- avoid PushDownLeftSemiAntiJoin
   create table t1 using parquet as select id a, id as b from range(10);
   create table t2 using parquet as select id as a, id as b from range(8);
   select t11.a, t11.b from (select distinct a, b from t1) t11 left semi join t2 on (t11.a = t2.a) group by t11.a, t11.b;
   ```

Before this PR:

```
== Optimized Logical Plan ==
Aggregate [a#6L, b#7L], [a#6L, b#7L], Statistics(sizeInBytes=1492.0 B)
+- Join LeftSemi, (a#6L = a#8L), Statistics(sizeInBytes=1492.0 B)
   :- Aggregate [a#6L, b#7L], [a#6L, b#7L], Statistics(sizeInBytes=1492.0 B)
   :  +- Filter isnotnull(a#6L), Statistics(sizeInBytes=1492.0 B)
   :     +- Relation default.t1[a#6L,b#7L] parquet, Statistics(sizeInBytes=1492.0 B)
   +- Project [a#8L], Statistics(sizeInBytes=984.0 B)
      +- Filter isnotnull(a#8L), Statistics(sizeInBytes=1476.0 B)
         +- Relation default.t2[a#8L,b#9L] parquet, Statistics(sizeInBytes=1476.0 B)
```

After this PR:

```
== Optimized Logical Plan ==
Join LeftSemi, (a#6L = a#8L), Statistics(sizeInBytes=1492.0 B)
:- Aggregate [a#6L, b#7L], [a#6L, b#7L], Statistics(sizeInBytes=1492.0 B)
:  +- Filter isnotnull(a#6L), Statistics(sizeInBytes=1492.0 B)
:     +- Relation default.t1[a#6L,b#7L] parquet, Statistics(sizeInBytes=1492.0 B)
+- Project [a#8L], Statistics(sizeInBytes=984.0 B)
   +- Filter isnotnull(a#8L), Statistics(sizeInBytes=1476.0 B)
      +- Relation default.t2[a#8L,b#9L] parquet, Statistics(sizeInBytes=1476.0 B)
```

### Why are the changes needed?

Improve query performance.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests and a TPC-DS benchmark:

SQL | Before this PR (seconds) | After this PR (seconds)
-- | -- | --
q14a | 206 | 193
q38 | 59 | 41
q87 | 127 | 113

Closes #35779 from wangyum/SPARK-36194.

Authored-by: Yuming Wang <[email protected]>
Signed-off-by: Yuming Wang <[email protected]>
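A hedged way to reproduce the plan change above from spark-shell (assuming the `t1` and `t2` tables created by the SQL in the example, and an available `spark` session):

```scala
// With this PR, the "Optimized Logical Plan" section of the output no
// longer shows the redundant top-level Aggregate over the semi join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
val q = spark.sql(
  """SELECT t11.a, t11.b
    |FROM (SELECT DISTINCT a, b FROM t1) t11
    |LEFT SEMI JOIN t2 ON (t11.a = t2.a)
    |GROUP BY t11.a, t11.b""".stripMargin)
q.explain("extended")
```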
1 parent 0005b41 commit c16a66a

File tree

31 files changed: +4770, -4643 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala

Lines changed: 6 additions & 2 deletions
```diff
@@ -18,9 +18,9 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.analysis.PullOutNondeterministic
-import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet}
+import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet, ExpressionSet}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
@@ -47,6 +47,10 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
       } else {
         newAggregate
       }
+
+    case agg @ Aggregate(groupingExps, _, child)
+        if agg.groupOnly && child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
+      Project(agg.aggregateExpressions, child)
   }
 
   private def isLowerRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
```
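The new case also fires when the redundant aggregate is not directly on top of another aggregate. A hedged sketch (`t1` is the table from the description; names are illustrative):

```scala
// The Filter between the two aggregates means the pre-existing
// Aggregate-over-Aggregate case does not match directly, but the new case
// removes the outer group-only Aggregate because the Filter child's
// distinctKeys already contain {a, b}.
val df = spark.sql(
  "SELECT a, b FROM (SELECT DISTINCT a, b FROM t1) t WHERE a > 0 GROUP BY a, b")
df.queryExecution.optimizedPlan  // expected: a single Aggregate remains
```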
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala

Lines changed: 140 additions & 0 deletions

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExpressionSet, NamedExpression}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, LeftSemiOrAnti, RightOuter}

/**
 * A visitor pattern for traversing a [[LogicalPlan]] tree and propagating the distinct attributes.
 */
object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {

  private def projectDistinctKeys(
      keys: Set[ExpressionSet], projectList: Seq[NamedExpression]): Set[ExpressionSet] = {
    val outputSet = ExpressionSet(projectList.map(_.toAttribute))
    val aliases = projectList.filter(_.isInstanceOf[Alias])
    if (aliases.isEmpty) {
      keys.filter(_.subsetOf(outputSet))
    } else {
      // Rewrite each key through the aliases, so e.g. `a AS x` turns the key {a} into {x}.
      val aliasedDistinctKeys = keys.map { expressionSet =>
        expressionSet.map { expression =>
          expression transform {
            case expr: Expression =>
              // TODO: Expand distinctKeys for redundant aliases on the same expression
              aliases
                .collectFirst { case a: Alias if a.child.semanticEquals(expr) => a.toAttribute }
                .getOrElse(expr)
          }
        }
      }
      aliasedDistinctKeys.collect {
        case es: ExpressionSet if es.subsetOf(outputSet) => ExpressionSet(es)
      } ++ keys.filter(_.subsetOf(outputSet))
    }.filter(_.nonEmpty)
  }

  override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet]

  override def visitAggregate(p: Aggregate): Set[ExpressionSet] = {
    val groupingExps = ExpressionSet(p.groupingExpressions) // handle group by a, a
    projectDistinctKeys(Set(groupingExps), p.aggregateExpressions)
  }

  override def visitDistinct(p: Distinct): Set[ExpressionSet] = Set(ExpressionSet(p.output))

  override def visitExcept(p: Except): Set[ExpressionSet] =
    if (!p.isAll) Set(ExpressionSet(p.output)) else default(p)

  override def visitExpand(p: Expand): Set[ExpressionSet] = default(p)

  override def visitFilter(p: Filter): Set[ExpressionSet] = p.child.distinctKeys

  override def visitGenerate(p: Generate): Set[ExpressionSet] = default(p)

  override def visitGlobalLimit(p: GlobalLimit): Set[ExpressionSet] = {
    p.maxRows match {
      case Some(value) if value <= 1 => Set(ExpressionSet(p.output))
      case _ => p.child.distinctKeys
    }
  }

  override def visitIntersect(p: Intersect): Set[ExpressionSet] = {
    if (!p.isAll) Set(ExpressionSet(p.output)) else default(p)
  }

  override def visitJoin(p: Join): Set[ExpressionSet] = {
    p match {
      // Semi/anti joins only filter the left side, so its keys still hold.
      case Join(_, _, LeftSemiOrAnti(_), _, _) =>
        p.left.distinctKeys
      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, left, right, _)
          if left.distinctKeys.nonEmpty || right.distinctKeys.nonEmpty =>
        val rightJoinKeySet = ExpressionSet(rightKeys)
        val leftJoinKeySet = ExpressionSet(leftKeys)
        joinType match {
          // Both sides are unique on their join keys: rows match at most
          // one-to-one, so keys from both sides survive.
          case Inner if left.distinctKeys.exists(_.subsetOf(leftJoinKeySet)) &&
            right.distinctKeys.exists(_.subsetOf(rightJoinKeySet)) =>
            left.distinctKeys ++ right.distinctKeys
          // One side is unique on its join keys: each row from the other
          // side matches at most once, so that side's keys survive.
          case Inner | LeftOuter if right.distinctKeys.exists(_.subsetOf(rightJoinKeySet)) =>
            p.left.distinctKeys
          case Inner | RightOuter if left.distinctKeys.exists(_.subsetOf(leftJoinKeySet)) =>
            p.right.distinctKeys
          case _ =>
            default(p)
        }
      case _ => default(p)
    }
  }

  override def visitLocalLimit(p: LocalLimit): Set[ExpressionSet] = p.child.distinctKeys

  override def visitPivot(p: Pivot): Set[ExpressionSet] = default(p)

  override def visitProject(p: Project): Set[ExpressionSet] = {
    if (p.child.distinctKeys.nonEmpty) {
      projectDistinctKeys(p.child.distinctKeys, p.projectList)
    } else {
      default(p)
    }
  }

  override def visitRepartition(p: Repartition): Set[ExpressionSet] = p.child.distinctKeys

  override def visitRepartitionByExpr(p: RepartitionByExpression): Set[ExpressionSet] =
    p.child.distinctKeys

  override def visitSample(p: Sample): Set[ExpressionSet] = {
    if (!p.withReplacement) p.child.distinctKeys else default(p)
  }

  override def visitScriptTransform(p: ScriptTransformation): Set[ExpressionSet] = default(p)

  override def visitUnion(p: Union): Set[ExpressionSet] = default(p)

  override def visitWindow(p: Window): Set[ExpressionSet] = p.child.distinctKeys

  override def visitTail(p: Tail): Set[ExpressionSet] = p.child.distinctKeys

  override def visitSort(p: Sort): Set[ExpressionSet] = p.child.distinctKeys

  override def visitRebalancePartitions(p: RebalancePartitions): Set[ExpressionSet] =
    p.child.distinctKeys

  override def visitWithCTE(p: WithCTE): Set[ExpressionSet] = p.plan.distinctKeys
}
```
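To see `projectDistinctKeys` rewriting keys through aliases, one can inspect the analyzed plan. A hedged sketch using the table `t` from the description:

```scala
// visitAggregate starts from the grouping keys {a, b}; the Alias `a AS x`
// maps the key to {x, b}, which is a subset of the output and is kept.
val plan = spark.sql(
  "SELECT a AS x, b, max(c) FROM t GROUP BY a, b").queryExecution.analyzed
println(plan.distinctKeys)  // expected: a single ExpressionSet over x and b
```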

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@ abstract class LogicalPlan
   extends QueryPlan[LogicalPlan]
   with AnalysisHelper
   with LogicalPlanStats
+  with LogicalPlanDistinctKeys
   with QueryPlanConstraints
   with Logging {
```
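With the mixin in place, every analyzed plan exposes `distinctKeys`. A hedged one-liner (session name assumed):

```scala
// SELECT DISTINCT analyzes to a Distinct node, whose whole output is
// reported as a distinct key by visitDistinct.
spark.sql("SELECT DISTINCT id FROM range(10)")
  .queryExecution.analyzed.distinctKeys  // Set({id#...})
```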
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala

Lines changed: 34 additions & 0 deletions

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.ExpressionSet
import org.apache.spark.sql.internal.SQLConf.PROPAGATE_DISTINCT_KEYS_ENABLED

/**
 * A trait to add distinct attributes to [[LogicalPlan]]. For example:
 * {{{
 *   SELECT a, b, SUM(c) FROM Tab1 GROUP BY a, b
 *   // returns a, b
 * }}}
 */
trait LogicalPlanDistinctKeys { self: LogicalPlan =>
  lazy val distinctKeys: Set[ExpressionSet] = {
    if (conf.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)) DistinctKeyVisitor.visit(self) else Set.empty
  }
}
```

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 9 additions & 0 deletions
```diff
@@ -744,6 +744,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val PROPAGATE_DISTINCT_KEYS_ENABLED =
+    buildConf("spark.sql.optimizer.propagateDistinctKeys.enabled")
+      .internal()
+      .doc("When true, the query optimizer will propagate a set of distinct attributes from the " +
+        "current node and use it to optimize query.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
```

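The propagation is on by default; a hedged sketch of opting out at runtime:

```scala
// Turning the internal conf off makes distinctKeys evaluate to Set.empty,
// which disables the new RemoveRedundantAggregates case.
spark.conf.set("spark.sql.optimizer.propagateDistinctKeys.enabled", "false")
```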