diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 06e3888e7debd..8c420838ca274 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -58,7 +58,8 @@ class SparkOptimizer(
     Batch("InjectRuntimeFilter", FixedPoint(1),
       InjectRuntimeFilter) :+
     Batch("MergeScalarSubqueries", Once,
-      MergeScalarSubqueries) :+
+      MergeScalarSubqueries,
+      RewriteDistinctAggregates) :+
     Batch("Pushdown Filters from PartitionPruning", fixedPoint,
       PushDownPredicates) :+
     Batch("Cleanup filters that cannot be pushed down", Once,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 5d667bbdd8cde..e61f43d8847a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -2598,4 +2598,29 @@ class SubquerySuite extends QueryTest
         Row("aa"))
     }
   }
+
+  test("SPARK-42346: Rewrite distinct aggregates after merging subqueries") {
+    withTempView("t1") {
+      Seq((1, 2), (3, 4)).toDF("c1", "c2").createOrReplaceTempView("t1")
+
+      checkAnswer(sql(
+        """
+          |SELECT
+          |  (SELECT count(distinct c1) FROM t1),
+          |  (SELECT count(distinct c2) FROM t1)
+          |""".stripMargin),
+        Row(2, 2))
+
+      // In this case we don't merge the subqueries as `RewriteDistinctAggregates` kicks off for the
+      // 2 subqueries first but `MergeScalarSubqueries` is not prepared for the `Expand` nodes that
+      // are inserted by the rewrite.
+      checkAnswer(sql(
+        """
+          |SELECT
+          |  (SELECT count(distinct c1) + sum(distinct c2) FROM t1),
+          |  (SELECT count(distinct c2) + sum(distinct c1) FROM t1)
+          |""".stripMargin),
+        Row(8, 6))
+    }
+  }
 }
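
Not part of the patch above: a minimal, hedged sketch of how the new rule ordering could be observed from spark-shell or a small driver program. The object name DistinctAggMergeDemo, the local[*] master and the plan expectations in the comments are illustrative assumptions; only the sample data and SQL are taken from the new test in SubquerySuite.scala.

import org.apache.spark.sql.SparkSession

object DistinctAggMergeDemo {
  def main(args: Array[String]): Unit = {
    // Local session purely for experimentation (assumed settings, not from the patch).
    val spark = SparkSession.builder()
      .appName("distinct-agg-subquery-merge")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Same sample data as the new test case.
    Seq((1, 2), (3, 4)).toDF("c1", "c2").createOrReplaceTempView("t1")

    // First query from the test: two scalar subqueries over t1, each with a single
    // distinct aggregate. With RewriteDistinctAggregates now running after
    // MergeScalarSubqueries in the same batch, the subqueries are expected to be
    // merged before the distinct rewrite introduces Expand nodes.
    val df = spark.sql(
      """
        |SELECT
        |  (SELECT count(distinct c1) FROM t1),
        |  (SELECT count(distinct c2) FROM t1)
        |""".stripMargin)

    // Prints the parsed/analyzed/optimized/physical plans; the optimized plan should
    // show both scalar subqueries referring to one shared, merged aggregation.
    df.explain(true)
    df.show()

    spark.stop()
  }
}

The second query in the test (mixing count(distinct ...) and sum(distinct ...)) is the case the in-test comment calls out: RewriteDistinctAggregates rewrites those subqueries with Expand nodes, so MergeScalarSubqueries does not merge them.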