Skip to content

Commit 416cd1f

Browse files
cloud-fan authored and gatorsmile committed
[SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set
## What changes were proposed in this pull request?

bring back apache#21443

This is a different approach: just change the check to count distinct columns with `toSet`

## How was this patch tested?

a new test to verify the planner behavior.

Author: Wenchen Fan <[email protected]>
Author: Takeshi Yamamuro <[email protected]>

Closes apache#21487 from cloud-fan/back.
1 parent a2166ec commit 416cd1f

4 files changed

Lines changed: 38 additions & 4 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,9 +384,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
384384

385385
val (functionsWithDistinct, functionsWithoutDistinct) =
386386
aggregateExpressions.partition(_.isDistinct)
387-
if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
387+
if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) {
388388
// This is a sanity check. We should not reach here when we have multiple distinct
389-
// column sets. Our MultipleDistinctRewriter should take care this case.
389+
// column sets. Our `RewriteDistinctAggregates` should take care this case.
390390
sys.error("You hit a query analyzer bug. Please report your query to " +
391391
"Spark user mailing list.")
392392
}

sql/core/src/test/resources/sql-tests/inputs/group-by.sql

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,8 @@ SELECT 1 from (
6868
FROM (select 1 as x) a
6969
WHERE false
7070
) b
71-
where b.z != b.z
71+
where b.z != b.z;
72+
73+
-- SPARK-24369 multiple distinct aggregations having the same argument set
74+
SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
75+
FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y);

sql/core/src/test/resources/sql-tests/results/group-by.sql.out

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 26
2+
-- Number of queries: 27
33

44

55
-- !query 0
@@ -241,3 +241,12 @@ where b.z != b.z
241241
struct<1:int>
242242
-- !query 25 output
243243

244+
245+
246+
-- !query 26
247+
SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
248+
FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y)
249+
-- !query 26 schema
250+
struct<corr(DISTINCT CAST(x AS DOUBLE), CAST(y AS DOUBLE)):double,corr(DISTINCT CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,count(1):bigint>
251+
-- !query 26 output
252+
1.0 1.0 3

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,27 @@ class PlannerSuite extends SharedSQLContext {
6969
testPartialAggregationPlan(query)
7070
}
7171

72+
test("mixed aggregates with same distinct columns") {
73+
def assertNoExpand(plan: SparkPlan): Unit = {
74+
assert(plan.collect { case e: ExpandExec => e }.isEmpty)
75+
}
76+
77+
withTempView("v") {
78+
Seq((1, 1.0, 1.0), (1, 2.0, 2.0)).toDF("i", "j", "k").createTempView("v")
79+
// one distinct column
80+
val query1 = sql("SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i")
81+
assertNoExpand(query1.queryExecution.executedPlan)
82+
83+
// 2 distinct columns
84+
val query2 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT j, k) FROM v GROUP BY i")
85+
assertNoExpand(query2.queryExecution.executedPlan)
86+
87+
// 2 distinct columns with different order
88+
val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
89+
assertNoExpand(query3.queryExecution.executedPlan)
90+
}
91+
}
92+
7293
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
7394
def checkPlan(fieldTypes: Seq[DataType]): Unit = {
7495
withTempView("testLimit") {

0 commit comments

Comments (0)