Skip to content

Commit 7fa4ccc

Browse files
belieferchenzhx
authored and committed
[SPARK-39135][SQL] DS V2 aggregate partial push-down should support group by without aggregate functions
### What changes were proposed in this pull request? Currently, the SQL shown below is not supported by DS V2 aggregate partial push-down. `select key from tab group by key` ### Why are the changes needed? Make DS V2 aggregate partial push-down support group by without aggregate functions. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New tests. Closes apache#36492 from beliefer/SPARK-39135. Authored-by: Jiaan Geng <beliefer@163.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 9772774 commit 7fa4ccc

2 files changed

Lines changed: 52 additions & 1 deletion

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -286,7 +286,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
286286
private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
287287
// We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
288288
// If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down.
289-
agg.aggregateExpressions().exists {
289+
agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().exists {
290290
case sum: Sum => !sum.isDistinct
291291
case count: Count => !count.isDistinct
292292
case avg: Avg => !avg.isDistinct

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -670,6 +670,57 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
670670
checkAnswer(df, Seq(Row(5)))
671671
}
672672

673+
test("scan with aggregate push-down: GROUP BY without aggregate functions") {
674+
val df = sql("select name FROM h2.test.employee GROUP BY name")
675+
checkAggregateRemoved(df)
676+
checkPushedInfo(df,
677+
"PushedAggregates: [], PushedFilters: [], PushedGroupByExpressions: [NAME],")
678+
checkAnswer(df, Seq(Row("alex"), Row("amy"), Row("cathy"), Row("david"), Row("jen")))
679+
680+
val df2 = spark.read
681+
.option("partitionColumn", "dept")
682+
.option("lowerBound", "0")
683+
.option("upperBound", "2")
684+
.option("numPartitions", "2")
685+
.table("h2.test.employee")
686+
.groupBy($"name")
687+
.agg(Map.empty[String, String])
688+
checkAggregateRemoved(df2, false)
689+
checkPushedInfo(df2,
690+
"PushedAggregates: [], PushedFilters: [], PushedGroupByExpressions: [NAME],")
691+
checkAnswer(df2, Seq(Row("alex"), Row("amy"), Row("cathy"), Row("david"), Row("jen")))
692+
693+
val df3 = sql("SELECT CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END as" +
694+
" key FROM h2.test.employee GROUP BY key")
695+
checkAggregateRemoved(df3)
696+
checkPushedInfo(df3,
697+
"""
698+
|PushedAggregates: [],
699+
|PushedFilters: [],
700+
|PushedGroupByExpressions:
701+
|[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END],
702+
|""".stripMargin.replaceAll("\n", " "))
703+
checkAnswer(df3, Seq(Row(0), Row(9000)))
704+
705+
val df4 = spark.read
706+
.option("partitionColumn", "dept")
707+
.option("lowerBound", "0")
708+
.option("upperBound", "2")
709+
.option("numPartitions", "2")
710+
.table("h2.test.employee")
711+
.groupBy(when(($"SALARY" > 8000).and($"SALARY" < 10000), $"SALARY").otherwise(0).as("key"))
712+
.agg(Map.empty[String, String])
713+
checkAggregateRemoved(df4, false)
714+
checkPushedInfo(df4,
715+
"""
716+
|PushedAggregates: [],
717+
|PushedFilters: [],
718+
|PushedGroupByExpressions:
719+
|[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END],
720+
|""".stripMargin.replaceAll("\n", " "))
721+
checkAnswer(df4, Seq(Row(0), Row(9000)))
722+
}
723+
673724
test("scan with aggregate push-down: COUNT(col)") {
674725
val df = sql("select COUNT(DEPT) FROM h2.test.employee")
675726
checkAggregateRemoved(df)

0 commit comments

Comments (0)