Skip to content

Commit a34e2af

Browse files
committed
DistinctKeyVisitor
1 parent 8ac519f commit a34e2af

16 files changed

Lines changed: 2809 additions & 2864 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import org.apache.spark.sql.catalyst.analysis.PullOutNondeterministic
21-
import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet, ExpressionSet}
21+
import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet}
2222
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
23-
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, LogicalPlan, Project}
23+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
2424
import org.apache.spark.sql.catalyst.rules.Rule
2525
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
2626

@@ -48,9 +48,9 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
4848
newAggregate
4949
}
5050

51-
case agg @ Aggregate(groupingExps, _, j: Join) if agg.groupOnly &&
52-
j.distinctAttributes.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
53-
Project(agg.output, j)
51+
case agg @ Aggregate(groupingExps, _, child) if agg.groupOnly && child.deterministic &&
52+
child.distinctKeys.exists(_.subsetOf(AttributeSet(groupingExps))) =>
53+
Project(agg.aggregateExpressions, child)
5454
}
5555

5656
private def isLowerRedundant(upper: Aggregate, lower: Aggregate): Boolean = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctAttributesVisitor.scala

Lines changed: 0 additions & 100 deletions
This file was deleted.
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.plans.logical
19+
20+
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeSet, ExpressionSet, NamedExpression}
21+
import org.apache.spark.sql.catalyst.plans.LeftExistence
22+
23+
/**
24+
 * A visitor pattern for traversing a [[LogicalPlan]] tree and propagating the distinct keys.
25+
*/
26+
object DistinctKeyVisitor extends LogicalPlanVisitor[Set[AttributeSet]] {

  /**
   * Propagates distinct keys through a projection. A key survives only if every one of its
   * member expressions is reproduced by `projectList`, either directly or behind an [[Alias]];
   * surviving keys are rewritten in terms of the projection's output attributes.
   */
  private def projectDistinctKeys(
      keys: Set[ExpressionSet], projectList: Seq[NamedExpression]): Set[AttributeSet] = {
    val keyExpressions = keys.flatMap(_.toSet)
    // Output entries that semantically reproduce one of the key expressions.
    val preserved = projectList.filter {
      case alias: Alias => keyExpressions.exists(_.semanticEquals(alias.child))
      case other => keyExpressions.exists(_.semanticEquals(other))
    }.toSet
    // Among candidate subsets of the smallest key's size, keep those whose underlying
    // expressions exactly form a known key, expressed via the output attributes.
    preserved
      .subsets(keys.map(_.size).min)
      .filter { candidate =>
        val underlying = candidate.map {
          case alias: Alias => alias.child
          case other => other
        }
        keys.exists(_.equals(ExpressionSet(underlying)))
      }
      .map(candidate => AttributeSet(candidate.map(_.toAttribute)))
      .toSet
  }

  // By default a node contributes no distinct keys.
  override def default(p: LogicalPlan): Set[AttributeSet] = Set.empty[AttributeSet]

  override def visitAggregate(p: Aggregate): Set[AttributeSet] = {
    // ExpressionSet deduplicates semantically equal grouping expressions (e.g. GROUP BY a, a).
    projectDistinctKeys(Set(ExpressionSet(p.groupingExpressions)), p.aggregateExpressions)
  }

  override def visitDistinct(p: Distinct): Set[AttributeSet] = Set(p.outputSet)

  override def visitExcept(p: Except): Set[AttributeSet] = {
    if (!p.isAll && p.deterministic) Set(p.outputSet) else default(p)
  }

  override def visitExpand(p: Expand): Set[AttributeSet] = default(p)

  // Filtering never duplicates rows, so the child's distinct keys still hold.
  override def visitFilter(p: Filter): Set[AttributeSet] = p.child.distinctKeys

  override def visitGenerate(p: Generate): Set[AttributeSet] = default(p)

  override def visitGlobalLimit(p: GlobalLimit): Set[AttributeSet] = p.child.distinctKeys

  override def visitIntersect(p: Intersect): Set[AttributeSet] = {
    if (!p.isAll && p.deterministic) Set(p.outputSet) else default(p)
  }

  override def visitJoin(p: Join): Set[AttributeSet] = p.joinType match {
    // Semi/anti/existence joins emit each left row at most once, keeping the left keys valid.
    case LeftExistence(_) => p.left.distinctKeys
    case _ => default(p)
  }

  override def visitLocalLimit(p: LocalLimit): Set[AttributeSet] = p.child.distinctKeys

  override def visitPivot(p: Pivot): Set[AttributeSet] = default(p)

  override def visitProject(p: Project): Set[AttributeSet] = {
    val childKeys = p.child.distinctKeys
    if (childKeys.isEmpty) {
      default(p)
    } else {
      projectDistinctKeys(childKeys.map(ExpressionSet(_)), p.projectList)
    }
  }

  // Repartitioning only moves rows between partitions; row multiplicity is unchanged.
  override def visitRepartition(p: Repartition): Set[AttributeSet] = p.child.distinctKeys

  override def visitRepartitionByExpr(p: RepartitionByExpression): Set[AttributeSet] =
    p.child.distinctKeys

  override def visitSample(p: Sample): Set[AttributeSet] = default(p)

  override def visitScriptTransform(p: ScriptTransformation): Set[AttributeSet] = default(p)

  override def visitUnion(p: Union): Set[AttributeSet] = default(p)

  override def visitWindow(p: Window): Set[AttributeSet] = p.child.distinctKeys

  override def visitTail(p: Tail): Set[AttributeSet] = p.child.distinctKeys

  override def visitSort(p: Sort): Set[AttributeSet] = p.child.distinctKeys

  override def visitRebalancePartitions(p: RebalancePartitions): Set[AttributeSet] =
    p.child.distinctKeys

  override def visitWithCTE(p: WithCTE): Set[AttributeSet] = default(p)
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ abstract class LogicalPlan
3131
extends QueryPlan[LogicalPlan]
3232
with AnalysisHelper
3333
with LogicalPlanStats
34-
with LogicalPlanDistinctAttributes
34+
with LogicalPlanDistinctKeys
3535
with QueryPlanConstraints
3636
with Logging {
3737

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctAttributes.scala renamed to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,15 @@
1717

1818
package org.apache.spark.sql.catalyst.plans.logical
1919

20-
import org.apache.spark.sql.catalyst.expressions.ExpressionSet
20+
import org.apache.spark.sql.catalyst.expressions.AttributeSet
2121

2222
/**
 * A trait that adds distinct keys to [[LogicalPlan]]: each key is a set of attributes whose
 * combined values are distinct across the plan's output rows. For example:
 * {{{
 *   SELECT a, b, SUM(c) FROM Tab1 GROUP BY a, b
 *   // returns a, b
 * }}}
 */
trait LogicalPlanDistinctKeys { self: LogicalPlan =>
  // Computed at most once per plan node; delegated to the visitor.
  lazy val distinctKeys: Set[AttributeSet] = DistinctKeyVisitor.visit(self)
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,19 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
230230
}
231231
}
232232

233+
test("SPARK-36194: Remove aggregation from aggregation") {
234+
val originalQuery = relation
235+
.groupBy('a)('a, count('b).as("cnt"))
236+
.groupBy('a, 'cnt)('a, 'cnt)
237+
.analyze
238+
val correctAnswer = relation
239+
.groupBy('a)('a, count('b).as("cnt"))
240+
.select('a, 'cnt)
241+
.analyze
242+
val optimized = Optimize.execute(originalQuery)
243+
comparePlans(optimized, correctAnswer)
244+
}
245+
233246
test("SPARK-36194: Negative case: The grouping expressions not same") {
234247
Seq(LeftSemi, LeftAnti).foreach { joinType =>
235248
val originalQuery = x.groupBy('a, 'b)('a, 'b)
@@ -273,4 +286,13 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
273286
comparePlans(optimized, originalQuery.analyze)
274287
}
275288
}
289+
290+
test("SPARK-36194: Negative case: Remove aggregation from contains non-deterministic") {
291+
val query = relation
292+
.groupBy('a)('a, (count('b) + rand(0)).as("cnt"))
293+
.groupBy('a, 'cnt)('a, 'cnt)
294+
.analyze
295+
val optimized = Optimize.execute(query)
296+
comparePlans(optimized, query)
297+
}
276298
}

0 commit comments

Comments
 (0)