
Commit 990a4dc

beliefer authored and chenzhx committed
[SPARK-39819][SQL] DS V2 aggregate push down can work with Top N or Paging (Sort with expressions)
### What changes were proposed in this pull request?
Currently, DS V2 aggregate push-down cannot work with DS V2 Top N push-down (`ORDER BY col LIMIT m`) or DS V2 Paging push-down (`ORDER BY col LIMIT m OFFSET n`). Pushing the aggregate down together with Top N or Paging gives better performance. This PR only allows the aggregate to be pushed down with Top N or Paging when the ORDER BY expressions are GROUP BY expressions.

The ideas of this PR are:
1. When assigning the expected output of `ScanBuilderHolder`, also hold a map from the expected output attributes to the original expressions (which contain the original columns).
2. When trying to push down Top N or Paging, restore the original expressions for `SortOrder`.

### Why are the changes needed?
Letting DS V2 aggregate push-down work with Top N or Paging (sort with group expressions) gives users better performance.

### Does this PR introduce _any_ user-facing change?
No. New feature.

### How was this patch tested?
New test cases.

Closes apache#37320 from beliefer/SPARK-39819_new.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
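For example (a hedged sketch of the user-facing pattern, built on the `h2.test.employee` table used by the JDBC suite below), an aggregate followed by Top N or Paging can now be pushed down as a whole:

val df = spark.read
  .table("h2.test.employee")
  .groupBy("DEPT").sum("SALARY")
  .orderBy("DEPT")
  .offset(1)   // Paging: OFFSET n ...
  .limit(1)    // ... LIMIT m; use limit alone for Top N
df.explain()   // the V2 scan node should report PushedAggregates, PushedTopN and PushedOffset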
1 parent 6cf42aa commit 990a4dc

File tree

3 files changed (+244, -51 lines)


sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java

Lines changed: 5 additions & 0 deletions
@@ -42,4 +42,9 @@ public Cast(Expression expression, DataType dataType) {
 
   @Override
   public Expression[] children() { return new Expression[]{ expression() }; }
+
+  @Override
+  public String toString() {
+    return "CAST(" + expression.describe() + " AS " + dataType.typeName() + ")";
+  }
 }
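As a quick illustration (a hedged sketch, not part of the diff; `FieldReference` and `StringType` are assumed here purely for the example), the new `toString` renders a V2 `Cast` as follows:

import org.apache.spark.sql.connector.expressions.{Cast, FieldReference}
import org.apache.spark.sql.types.StringType

// Build a V2 Cast over a plain column reference and print it.
val cast = new Cast(FieldReference("DEPT"), StringType)
println(cast)  // expected: CAST(DEPT AS string)

This is the form that shows up in pushed-info strings such as `PushedGroupByExpressions: [CAST(DEPT AS string)]` in the tests below.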

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 24 additions & 13 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -189,12 +189,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     // +- ScanBuilderHolder[group_col_0#10, agg_func_0#21, agg_func_1#22]
     // Later, we build the `Scan` instance and convert ScanBuilderHolder to DataSourceV2ScanRelation.
     // scalastyle:on
-    val groupOutput = normalizedGroupingExpr.zipWithIndex.map { case (e, i) =>
-      AttributeReference(s"group_col_$i", e.dataType)()
+    val groupOutputMap = normalizedGroupingExpr.zipWithIndex.map { case (e, i) =>
+      AttributeReference(s"group_col_$i", e.dataType)() -> e
     }
-    val aggOutput = finalAggExprs.zipWithIndex.map { case (e, i) =>
-      AttributeReference(s"agg_func_$i", e.dataType)()
+    val groupOutput = groupOutputMap.unzip._1
+    val aggOutputMap = finalAggExprs.zipWithIndex.map { case (e, i) =>
+      AttributeReference(s"agg_func_$i", e.dataType)() -> e
     }
+    val aggOutput = aggOutputMap.unzip._1
     val newOutput = groupOutput ++ aggOutput
     val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
     normalizedGroupingExpr.zipWithIndex.foreach { case (expr, ordinal) =>
@@ -204,6 +206,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     }
 
     holder.pushedAggregate = Some(translatedAgg)
+    holder.pushedAggOutputMap = AttributeMap(groupOutputMap ++ aggOutputMap)
     holder.output = newOutput
     logInfo(
       s"""
@@ -406,15 +409,21 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
           sHolder.pushedLimit = Some(limit)
         }
         (operation, isPushed && !isPartiallyPushed)
-      case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder))
-          // Without building the Scan, we do not know the resulting column names after aggregate
-          // push-down, and thus can't push down Top-N which needs to know the ordering column names.
-          // TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same
-          // columns, which we know the resulting column names: the original table columns.
-          if sHolder.pushedAggregate.isEmpty && filter.isEmpty &&
-            CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
+      case s @ Sort(order, _, operation @ PhysicalOperation(project, Nil, sHolder: ScanBuilderHolder))
+          if CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
         val aliasMap = getAliasMap(project)
-        val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
+        val aliasReplacedOrder = order.map(replaceAlias(_, aliasMap))
+        val newOrder = if (sHolder.pushedAggregate.isDefined) {
+          // `ScanBuilderHolder` has different output columns after aggregate push-down. Here we
+          // replace the attributes in ordering expressions with the original table output columns.
+          aliasReplacedOrder.map {
+            _.transform {
+              case a: Attribute => sHolder.pushedAggOutputMap.getOrElse(a, a)
+            }.asInstanceOf[SortOrder]
+          }
+        } else {
+          aliasReplacedOrder.asInstanceOf[Seq[SortOrder]]
+        }
         val normalizedOrders = DataSourceStrategy.normalizeExprs(
           newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
         val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
@@ -544,6 +553,8 @@ case class ScanBuilderHolder(
   var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate]
 
   var pushedAggregate: Option[Aggregation] = None
+
+  var pushedAggOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression]
 }
 
 // A wrapper for v1 scan to carry the translated filters and the handled ones, along with
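To make the ordering rewrite above concrete, here is a hedged, self-contained sketch (attribute names and types are made up for illustration; this is not the rule itself) of how `pushedAggOutputMap` restores the original expression behind a synthetic aggregate-output column before Top-N translation:

import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeMap, AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.types.IntegerType

// Suppose aggregate push-down replaced the scan output with a synthetic column
// group_col_0 that stands for the original table column DEPT.
val dept = AttributeReference("DEPT", IntegerType)()
val groupCol0 = AttributeReference("group_col_0", IntegerType)()
val pushedAggOutputMap = AttributeMap[Expression](Seq(groupCol0 -> dept))

// ORDER BY group_col_0 is rewritten to ORDER BY DEPT, which the data source understands.
val order = SortOrder(groupCol0, Ascending)
val restored = order.transform {
  case a: Attribute => pushedAggOutputMap.getOrElse(a, a)
}.asInstanceOf[SortOrder]
// restored.child now references DEPT instead of group_col_0.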

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala

Lines changed: 215 additions & 38 deletions
@@ -774,59 +774,46 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df5,
       Seq(Row(1, "cathy", 9000.00, 1200.0, false), Row(1, "amy", 10000.00, 1000.0, true)))
 
+    val name = udf { (x: String) => x.matches("cat|dav|amy") }
+    val sub = udf { (x: String) => x.substring(0, 3) }
     val df6 = spark.read
       .table("h2.test.employee")
-      .groupBy("DEPT").sum("SALARY")
-      .orderBy("DEPT")
+      .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
+      .filter(name($"shortName"))
+      .sort($"SALARY".desc)
       .limit(1)
+    // LIMIT is pushed down only if all the filters are pushed down
     checkSortRemoved(df6, false)
     checkLimitRemoved(df6, false)
-    checkPushedInfo(df6,
-      "PushedAggregates: [SUM(SALARY)]",
-      "PushedFilters: []",
-      "PushedGroupByExpressions: [DEPT]")
-    checkAnswer(df6, Seq(Row(1, 19000.00)))
+    checkPushedInfo(df6, "PushedFilters: []")
+    checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy")))
 
-    val name = udf { (x: String) => x.matches("cat|dav|amy") }
-    val sub = udf { (x: String) => x.substring(0, 3) }
     val df7 = spark.read
       .table("h2.test.employee")
-      .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
-      .filter(name($"shortName"))
-      .sort($"SALARY".desc)
+      .sort(sub($"NAME"))
       .limit(1)
-    // LIMIT is pushed down only if all the filters are pushed down
     checkSortRemoved(df7, false)
     checkLimitRemoved(df7, false)
     checkPushedInfo(df7, "PushedFilters: []")
-    checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy")))
+    checkAnswer(df7, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
 
     val df8 = spark.read
-      .table("h2.test.employee")
-      .sort(sub($"NAME"))
-      .limit(1)
-    checkSortRemoved(df8, false)
-    checkLimitRemoved(df8, false)
-    checkPushedInfo(df8, "PushedFilters: []")
-    checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
-
-    val df9 = spark.read
       .table("h2.test.employee")
       .select($"DEPT", $"name", $"SALARY",
         when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
       .sort("key", "dept", "SALARY")
       .limit(3)
-    checkSortRemoved(df9)
-    checkLimitRemoved(df9)
-    checkPushedInfo(df9,
+    checkSortRemoved(df8)
+    checkLimitRemoved(df8)
+    checkPushedInfo(df8,
       "PushedFilters: []",
-      "PushedTopN: " +
-        "ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
-        "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,")
-    checkAnswer(df9,
+      "PushedTopN: ORDER BY " +
+        "[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END" +
+        " ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3")
+    checkAnswer(df8,
       Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))
 
-    val df10 = spark.read
+    val df9 = spark.read
       .option("partitionColumn", "dept")
      .option("lowerBound", "0")
       .option("upperBound", "2")
@@ -836,14 +823,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
       .orderBy($"key", $"dept", $"SALARY")
       .limit(3)
-    checkSortRemoved(df10, false)
-    checkLimitRemoved(df10, false)
-    checkPushedInfo(df10,
+    checkSortRemoved(df9, false)
+    checkLimitRemoved(df9, false)
+    checkPushedInfo(df9,
       "PushedFilters: []",
-      "PushedTopN: " +
-        "ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
-        "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,")
-    checkAnswer(df10,
+      "PushedTopN: ORDER BY " +
+        "[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
+        "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3")
+    checkAnswer(df9,
       Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))
   }
 
@@ -872,6 +859,196 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df2, Seq(Row(2, "david", 10000.00)))
   }
 
+  test("scan with aggregate push-down, top N push-down and offset push-down") {
+    val df1 = spark.read
+      .table("h2.test.employee")
+      .groupBy("DEPT").sum("SALARY")
+      .orderBy("DEPT")
+
+    val paging1 = df1.offset(1).limit(1)
+    checkSortRemoved(paging1)
+    checkLimitRemoved(paging1)
+    checkPushedInfo(paging1,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
+      "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 2")
+    checkAnswer(paging1, Seq(Row(2, 22000.00)))
+
+    val topN1 = df1.limit(1)
+    checkSortRemoved(topN1)
+    checkLimitRemoved(topN1)
+    checkPushedInfo(topN1,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN1, Seq(Row(1, 19000.00)))
+
+    val df2 = spark.read
+      .table("h2.test.employee")
+      .select($"DEPT".cast("string").as("my_dept"), $"SALARY")
+      .groupBy("my_dept").sum("SALARY")
+      .orderBy("my_dept")
+
+    val paging2 = df2.offset(1).limit(1)
+    checkSortRemoved(paging2)
+    checkLimitRemoved(paging2)
+    checkPushedInfo(paging2,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [CAST(DEPT AS string)]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
+      "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 2")
+    checkAnswer(paging2, Seq(Row("2", 22000.00)))
+
+    val topN2 = df2.limit(1)
+    checkSortRemoved(topN2)
+    checkLimitRemoved(topN2)
+    checkPushedInfo(topN2,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [CAST(DEPT AS string)]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN2, Seq(Row("1", 19000.00)))
+
+    val df3 = spark.read
+      .table("h2.test.employee")
+      .groupBy("dept").sum("SALARY")
+      .orderBy($"dept".cast("string"))
+
+    val paging3 = df3.offset(1).limit(1)
+    checkSortRemoved(paging3)
+    checkLimitRemoved(paging3)
+    checkPushedInfo(paging3,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
+      "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 2")
+    checkAnswer(paging3, Seq(Row(2, 22000.00)))
+
+    val topN3 = df3.limit(1)
+    checkSortRemoved(topN3)
+    checkLimitRemoved(topN3)
+    checkPushedInfo(topN3,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN3, Seq(Row(1, 19000.00)))
+
+    val df4 = spark.read
+      .table("h2.test.employee")
+      .groupBy("DEPT", "IS_MANAGER").sum("SALARY")
+      .orderBy("DEPT", "IS_MANAGER")
+
+    val paging4 = df4.offset(1).limit(1)
+    checkSortRemoved(paging4)
+    checkLimitRemoved(paging4)
+    checkPushedInfo(paging4,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT, IS_MANAGER]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
+      "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 2")
+    checkAnswer(paging4, Seq(Row(1, true, 10000.00)))
+
+    val topN4 = df4.limit(1)
+    checkSortRemoved(topN4)
+    checkLimitRemoved(topN4)
+    checkPushedInfo(topN4,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT, IS_MANAGER]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN4, Seq(Row(1, false, 9000.00)))
+
+    val df5 = spark.read
+      .table("h2.test.employee")
+      .select($"SALARY", $"IS_MANAGER", $"DEPT".cast("string").as("my_dept"))
+      .groupBy("my_dept", "IS_MANAGER").sum("SALARY")
+      .orderBy("my_dept", "IS_MANAGER")
+
+    val paging5 = df5.offset(1).limit(1)
+    checkSortRemoved(paging5)
+    checkLimitRemoved(paging5)
+    checkPushedInfo(paging5,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [CAST(DEPT AS string), IS_MANAGER]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
+      "PushedTopN: " +
+        "ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 2")
+    checkAnswer(paging5, Seq(Row("1", true, 10000.00)))
+
+    val topN5 = df5.limit(1)
+    checkSortRemoved(topN5)
+    checkLimitRemoved(topN5)
+    checkPushedInfo(topN5,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [CAST(DEPT AS string), IS_MANAGER]",
+      "PushedFilters: []",
+      "PushedTopN: " +
+        "ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN5, Seq(Row("1", false, 9000.00)))
+
+    val df6 = spark.read
+      .table("h2.test.employee")
+      .select($"DEPT", $"SALARY")
+      .groupBy("dept").agg(sum("SALARY"))
+      .orderBy(sum("SALARY"))
+
+    val paging6 = df6.offset(1).limit(1)
+    checkSortRemoved(paging6)
+    checkLimitRemoved(paging6)
+    checkPushedInfo(paging6,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
+      "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 2")
+    checkAnswer(paging6, Seq(Row(1, 19000.00)))
+
+    val topN6 = df6.limit(1)
+    checkSortRemoved(topN6)
+    checkLimitRemoved(topN6)
+    checkPushedInfo(topN6,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN6, Seq(Row(6, 12000.00)))
+
+    val df7 = spark.read
+      .table("h2.test.employee")
+      .select($"DEPT", $"SALARY")
+      .groupBy("dept").agg(sum("SALARY").as("total"))
+      .orderBy("total")
+
+    val paging7 = df7.offset(1).limit(1)
+    checkSortRemoved(paging7)
+    checkLimitRemoved(paging7)
+    checkPushedInfo(paging7,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
+      "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 2")
+    checkAnswer(paging7, Seq(Row(1, 19000.00)))
+
+    val topN7 = df7.limit(1)
+    checkSortRemoved(topN7)
+    checkLimitRemoved(topN7)
+    checkPushedInfo(topN7,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN7, Seq(Row(6, 12000.00)))
+  }
+
   test("scan with filter push-down") {
     val df = spark.table("h2.test.people").filter($"id" > 1)
     checkFiltersRemoved(df)
