Commit 2634588

[SPARK-32816][SQL] Fix analyzer bug when aggregating multiple distinct DECIMAL columns
This PR fixes a conflict between `RewriteDistinctAggregates` and `DecimalAggregates`. In some cases, `DecimalAggregates` wraps the decimal column in `UnscaledValue`, using different rules for different aggregates. As a result, the same distinct column used in different aggregates can turn into different distinct columns after `DecimalAggregates` runs. For example, `avg(distinct decimal_col), sum(distinct decimal_col)` may change to `avg(distinct UnscaledValue(decimal_col)), sum(distinct decimal_col)`.

We assume that after `RewriteDistinctAggregates` there is at most one distinct column in the aggregates, but `DecimalAggregates` breaks this assumption. To fix this, we have to switch the order of these two rules.

Why the change is needed: bug fix.
User-facing change: no.
How it was tested: added test cases.

Closes #29673 from linhongliu-db/SPARK-32816.

Authored-by: Linhong Liu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 40ef5c9)
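The conflict described in the commit message can be sketched with a small toy model. This is not Catalyst code: `Col`, `UnscaledValue`, `AvgDistinct`, `SumDistinct` and `RuleOrderSketch` are hypothetical stand-ins, and the two thresholds (18 digits for SUM, 15 for AVG) are an assumption about when `DecimalAggregates` applies its rewrite.

```scala
// Toy model only: these are NOT Spark's Catalyst classes, all names are hypothetical.
sealed trait Expr
case class Col(name: String, precision: Int, scale: Int) extends Expr
case class UnscaledValue(child: Expr) extends Expr

sealed trait Agg { def child: Expr }
case class AvgDistinct(child: Expr) extends Agg
case class SumDistinct(child: Expr) extends Agg

object RuleOrderSketch {
  // Assumed thresholds: SUM is rewritten when precision + 10 fits in 18 digits,
  // AVG when precision + 4 fits in 15 digits.
  val MaxLongDigits = 18
  val MaxDoubleDigits = 15

  // Stand-in for DecimalAggregates: wraps the column per aggregate, using different rules.
  def decimalAggregates(agg: Agg): Agg = agg match {
    case SumDistinct(c @ Col(_, p, _)) if p + 10 <= MaxLongDigits => SumDistinct(UnscaledValue(c))
    case AvgDistinct(c @ Col(_, p, _)) if p + 4 <= MaxDoubleDigits => AvgDistinct(UnscaledValue(c))
    case other => other
  }

  // Number of different expressions that appear under DISTINCT.
  def distinctGroups(aggs: Seq[Agg]): Int = aggs.map(_.child).distinct.size

  def main(args: Array[String]): Unit = {
    val col = Col("decimal_col", 9, 0)
    val aggs = Seq(AvgDistinct(col), SumDistinct(col))
    println(distinctGroups(aggs))                        // 1: both aggregates share decimal_col
    println(distinctGroups(aggs.map(decimalAggregates))) // 2: only AVG got wrapped
  }
}
```

Under these assumptions, a DECIMAL(9, 0) column is rewritten for AVG but not for SUM, so a distinct rewrite that ran earlier would have planned for one distinct column and later faced two.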
1 parent 698ac6a commit 2634588

3 files changed

Lines changed: 16 additions & 2 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 4 additions & 1 deletion
@@ -134,7 +134,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteNonCorrelatedExists,
       ComputeCurrentTime,
       GetCurrentDatabase(catalogManager),
-      RewriteDistinctAggregates,
       ReplaceDeduplicateWithAggregate) ::
     //////////////////////////////////////////////////////////////////////////////////////////
     // Optimizer rules start here
@@ -185,6 +184,10 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateSorts) :+
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) :+
+    // This batch must run after "Decimal Optimizations", as that one may change the
+    // aggregate distinct column
+    Batch("Distinct Aggregate Rewrite", Once,
+      RewriteDistinctAggregates) :+
     Batch("Object Expressions Optimization", fixedPoint,
       EliminateMapObjects,
       CombineTypedFilters,

sql/core/src/test/resources/sql-tests/inputs/group-by.sql

Lines changed: 3 additions & 0 deletions
@@ -166,3 +166,6 @@ SELECT * FROM (SELECT COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L;
 SELECT count(*) FROM test_agg WHERE count(*) > 1L;
 SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L;
 SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1;
+
+-- Aggregate with multiple distinct decimal columns
+SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col);

sql/core/src/test/resources/sql-tests/results/group-by.sql.out

Lines changed: 9 additions & 1 deletion
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 56
+-- Number of queries: 57
 
 
 -- !query
@@ -573,3 +573,11 @@ org.apache.spark.sql.AnalysisException
 Aggregate/Window/Generate expressions are not valid in where clause of the query.
 Expression in where clause: [(((test_agg.`k` = 1) OR (test_agg.`k` = 2)) OR (((count(1) + 1L) > 1L) OR (max(test_agg.`k`) > 1)))]
 Invalid expressions: [count(1), max(test_agg.`k`)];
+
+
+-- !query
+SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col)
+-- !query schema
+struct<avg(DISTINCT decimal_col):decimal(13,4),sum(DISTINCT decimal_col):decimal(19,0)>
+-- !query output
+1.0000 1
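
The result types in the expected schema can be checked by hand with a minimal sketch. It assumes the usual widening rules for decimal aggregates (SUM adds 10 digits of precision; AVG adds 4 digits of precision and 4 of scale, both capped at 38), which agree with the schema shown here. `DecimalResultTypes` and its members are hypothetical names, not Spark classes.

```scala
// Hypothetical helper, not Spark's DecimalType: models only precision and scale.
object DecimalResultTypes {
  val MaxPrecision = 38 // assumed upper bound on decimal precision

  case class DecimalType(precision: Int, scale: Int) {
    override def toString: String = s"decimal($precision,$scale)"
  }

  // Caps precision and scale at the assumed maximum.
  def bounded(precision: Int, scale: Int): DecimalType =
    DecimalType(math.min(precision, MaxPrecision), math.min(scale, MaxPrecision))

  // Assumed rule for SUM over decimal(p, s): decimal(p + 10, s).
  def sumType(in: DecimalType): DecimalType = bounded(in.precision + 10, in.scale)

  // Assumed rule for AVG over decimal(p, s): decimal(p + 4, s + 4).
  def avgType(in: DecimalType): DecimalType = bounded(in.precision + 4, in.scale + 4)

  def main(args: Array[String]): Unit = {
    val input = DecimalType(9, 0)
    println(avgType(input)) // decimal(13,4)
    println(sumType(input)) // decimal(19,0)
  }
}
```

For the test query's DECIMAL(9, 0) input this gives decimal(13,4) for AVG and decimal(19,0) for SUM, matching the `-- !query schema` line.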
