Skip to content

Commit 888f8f0

Browse files
gaoyajun02gaoyajun02
authored and committed
[SPARK-36339][SQL] References to grouping attributes that are not part of an aggregation should be replaced
### What changes were proposed in this pull request? Currently, a reference to a grouping column that follows an aggregate expression in the select list of a GROUPING SETS query is reported as an error, e.g. ``` SELECT count(name) c, name FROM VALUES ('Alice'), ('Bob') people(name) GROUP BY name GROUPING SETS(name); ``` Error in query: expression 'people.`name`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;; ### Why are the changes needed? The anonymous function passed to `map` in the `constructAggregateExprs` function used the underscore placeholder, so only the final `replaceGroupingFunc(_, ...)` expression became the function body, and the `aggsBuffer` declared above it was created once and shared by all aggregation expressions. The fix names the parameter (`agg =>`) instead of using the underscore, so that the whole block is the function body and each aggregation expression gets its own `aggsBuffer`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit tests. Closes #33574 from gaoyajun02/SPARK-36339. Lead-authored-by: gaoyajun02 <[email protected]> Co-authored-by: gaoyajun02 <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 7bb53b8 commit 888f8f0

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,15 +580,15 @@ class Analyzer(override val catalogManager: CatalogManager)
580580
aggregations: Seq[NamedExpression],
581581
groupByAliases: Seq[Alias],
582582
groupingAttrs: Seq[Expression],
583-
gid: Attribute): Seq[NamedExpression] = aggregations.map {
583+
gid: Attribute): Seq[NamedExpression] = aggregations.map { agg =>
584584
// collect all the found AggregateExpression, so we can check an expression is part of
585585
// any AggregateExpression or not.
586586
val aggsBuffer = ArrayBuffer[Expression]()
587587
// Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
588588
def isPartOfAggregation(e: Expression): Boolean = {
589589
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
590590
}
591-
replaceGroupingFunc(_, groupByExprs, gid).transformDown {
591+
replaceGroupingFunc(agg, groupByExprs, gid).transformDown {
592592
// AggregateExpression should be computed on the unmodified value of its argument
593593
// expressions, so we should not replace any references to grouping expression
594594
// inside it.

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3405,6 +3405,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
34053405
}
34063406
}
34073407

3408+
test("SPARK-36339: References to grouping attributes should be replaced") {
3409+
withTempView("t") {
3410+
Seq("a", "a", "b").toDF("x").createOrReplaceTempView("t")
3411+
checkAnswer(
3412+
sql(
3413+
"""
3414+
|select count(x) c, x from t
3415+
|group by x grouping sets(x)
3416+
""".stripMargin),
3417+
Seq(Row(2, "a"), Row(1, "b")))
3418+
}
3419+
}
3420+
34083421
test("SPARK-31166: UNION map<null, null> and other maps should not fail") {
34093422
checkAnswer(
34103423
sql("(SELECT map()) UNION ALL (SELECT map(1, 2))"),

0 commit comments

Comments
 (0)