From f369a62b7db9ef6b4795320df5d2c9f88cbcfdea Mon Sep 17 00:00:00 2001 From: gaoyajun02 Date: Thu, 29 Jul 2021 15:21:42 +0800 Subject: [PATCH 1/2] [SPARK-36339] References to grouping that are not part of aggregation should be replaced --- .../sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f031f0816db1..1a70588c2119 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -580,7 +580,7 @@ class Analyzer(override val catalogManager: CatalogManager) aggregations: Seq[NamedExpression], groupByAliases: Seq[Alias], groupingAttrs: Seq[Expression], - gid: Attribute): Seq[NamedExpression] = aggregations.map { + gid: Attribute): Seq[NamedExpression] = aggregations.map { agg => // collect all the found AggregateExpression, so we can check an expression is part of // any AggregateExpression or not. val aggsBuffer = ArrayBuffer[Expression]() @@ -588,7 +588,7 @@ class Analyzer(override val catalogManager: CatalogManager) def isPartOfAggregation(e: Expression): Boolean = { aggsBuffer.exists(a => a.find(_ eq e).isDefined) } - replaceGroupingFunc(_, groupByExprs, gid).transformDown { + replaceGroupingFunc(agg, groupByExprs, gid).transformDown { // AggregateExpression should be computed on the unmodified value of its argument // expressions, so we should not replace any references to grouping expression // inside it. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ed3b47947a5d..05563aad76ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3405,6 +3405,26 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } + test("SPARK-36339: References to grouping attributes should be replaced") { + withTempView("t") { + Seq("a", "a", "b").toDF("x").createOrReplaceTempView("t") + checkAnswer( + sql( + """ + |select x, count(x) c from t + |group by x grouping sets(x) + """.stripMargin), + Seq(Row("a", 2), Row("b", 1))) + checkAnswer( + sql( + """ + |select count(x) c, x from t + |group by x grouping sets(x) + """.stripMargin), + Seq(Row(2, "a"), Row(1, "b"))) + } + } + test("SPARK-31166: UNION map and other maps should not fail") { checkAnswer( sql("(SELECT map()) UNION ALL (SELECT map(1, 2))"), From 321374d87a1f50e458be42e55b08472befa846e3 Mon Sep 17 00:00:00 2001 From: gaoyajun02 Date: Thu, 5 Aug 2021 23:00:34 +0800 Subject: [PATCH 2/2] Remove irrelevant queries in unit tests --- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 05563aad76ed..032ddbbcebf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3408,13 +3408,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark test("SPARK-36339: References to grouping attributes should be replaced") { withTempView("t") { Seq("a", "a", "b").toDF("x").createOrReplaceTempView("t") - checkAnswer( - sql( - """ - |select x, count(x) c from t - |group by x grouping sets(x) - 
""".stripMargin), - Seq(Row("a", 2), Row("b", 1))) checkAnswer( sql( """