Skip to content

Commit 08989f0

Browse files
belieferchenzhx
authored and committed
[SPARK-37839][SQL] DS V2 supports partial aggregate push-down AVG
### What changes were proposed in this pull request? `max`, `min`, `count`, `sum`, and `avg` are the most commonly used aggregation functions. Currently, DS V2 supports complete aggregate push-down of `avg`, but supporting partial aggregate push-down of `avg` is also very useful. The aggregate push-down algorithm is: 1. Spark translates the group expressions of `Aggregate` to DS V2 `Aggregation`. 2. Spark calls `supportCompletePushDown` to check whether it can completely push down the aggregate. 3. If `supportCompletePushDown` returns true, we preserve the aggregate expressions as the final aggregate expressions. Otherwise, we split `AVG` into two functions: `SUM` and `COUNT`. 4. Spark translates the final aggregate expressions and group expressions of `Aggregate` to DS V2 `Aggregation` again, and pushes the `Aggregation` to the JDBC source. 5. Spark constructs the final aggregate. ### Why are the changes needed? DS V2 supports partial aggregate push-down of `AVG`. ### Does this PR introduce _any_ user-facing change? Yes. DS V2 can now partially push down `AVG`. ### How was this patch tested? New tests. Closes apache#35130 from beliefer/SPARK-37839. Authored-by: Jiaan Geng <beliefer@163.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 0e50f11 commit 08989f0

8 files changed

Lines changed: 250 additions & 72 deletions

File tree

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.connector.expressions.aggregate;
19+
20+
import org.apache.spark.annotation.Evolving;
21+
import org.apache.spark.sql.connector.expressions.NamedReference;
22+
23+
/**
24+
* An aggregate function that returns the mean of all the values in a group.
25+
*
26+
* @since 3.3.0
27+
*/
28+
@Evolving
29+
public final class Avg implements AggregateFunc {
30+
private final NamedReference column;
31+
private final boolean isDistinct;
32+
33+
public Avg(NamedReference column, boolean isDistinct) {
34+
this.column = column;
35+
this.isDistinct = isDistinct;
36+
}
37+
38+
public NamedReference column() { return column; }
39+
public boolean isDistinct() { return isDistinct; }
40+
41+
@Override
42+
public String toString() {
43+
if (isDistinct) {
44+
return "AVG(DISTINCT " + column.describe() + ")";
45+
} else {
46+
return "AVG(" + column.describe() + ")";
47+
}
48+
}
49+
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
* <p>
3232
* The currently supported SQL aggregate functions:
3333
* <ol>
34-
* <li><pre>AVG(input1)</pre> Since 3.3.0</li>
3534
* <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
3635
* <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
3736
* <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ case class Average(
6969
case _ => DoubleType
7070
}
7171

72-
private lazy val sumDataType = child.dataType match {
72+
lazy val sumDataType = child.dataType match {
7373
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
7474
case _: YearMonthIntervalType => YearMonthIntervalType()
7575
case _: DayTimeIntervalType => DayTimeIntervalType()

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
4141
import org.apache.spark.sql.connector.catalog.SupportsRead
4242
import org.apache.spark.sql.connector.catalog.TableCapability._
4343
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
44-
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
44+
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
4545
import org.apache.spark.sql.errors.QueryCompilationErrors
4646
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
4747
import org.apache.spark.sql.execution.command._
@@ -717,7 +717,7 @@ object DataSourceStrategy
717717
case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
718718
Some(new Sum(FieldReference(name), aggregates.isDistinct))
719719
case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
720-
Some(new GeneralAggregateFunc("AVG", aggregates.isDistinct, Array(FieldReference(name))))
720+
Some(new Avg(FieldReference(name), aggregates.isDistinct))
721721
case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
722722
Some(new GeneralAggregateFunc("VAR_POP", aggregates.isDistinct, Array(FieldReference(name))))
723723
case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
@@ -746,6 +746,31 @@ object DataSourceStrategy
746746
}
747747
}
748748

749+
/**
 * Translates aggregate expressions and group by expressions to the DS V2 representation.
 *
 * @return the translated aggregation, or None when any aggregate or grouping
 *         expression cannot be translated.
 */
protected[sql] def translateAggregation(
    aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = {

  // A grouping expression is pushable only when it is a plain (non-nested) column.
  def toFieldReference(e: Expression): Option[FieldReference] = e match {
    case PushableColumnWithoutNestedColumn(name) =>
      Some(FieldReference.column(name).asInstanceOf[FieldReference])
    case _ => None
  }

  val translatedAggs = aggregates.flatMap(translateAggregate)
  val translatedGroupBys = groupBy.flatMap(toFieldReference)

  // Push down only when every expression translated; a partial translation
  // would silently drop aggregates or grouping columns.
  if (translatedAggs.length == aggregates.length &&
      translatedGroupBys.length == groupBy.length) {
    Some(new Aggregation(translatedAggs.toArray, translatedGroupBys.toArray))
  } else {
    None
  }
}
773+
749774
protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
750775
def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match {
751776
case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2
2020
import scala.collection.mutable
2121

2222
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
23-
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2423
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
25-
import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
26-
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
24+
import org.apache.spark.sql.connector.expressions.SortOrder
2725
import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
28-
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
29-
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
26+
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
27+
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
3028
import org.apache.spark.sql.internal.SQLConf
3129
import org.apache.spark.sql.sources
3230
import org.apache.spark.sql.types.StructType
@@ -103,38 +101,6 @@ object PushDownUtils extends PredicateHelper {
103101
}
104102
}
105103

106-
/**
107-
* Pushes down aggregates to the data source reader
108-
*
109-
* @return pushed aggregation.
110-
*/
111-
def pushAggregates(
112-
scanBuilder: ScanBuilder,
113-
aggregates: Seq[AggregateExpression],
114-
groupBy: Seq[Expression]): Option[Aggregation] = {
115-
116-
def columnAsString(e: Expression): Option[FieldReference] = e match {
117-
case PushableColumnWithoutNestedColumn(name) =>
118-
Some(FieldReference(name).asInstanceOf[FieldReference])
119-
case _ => None
120-
}
121-
122-
scanBuilder match {
123-
case r: SupportsPushDownAggregates if aggregates.nonEmpty =>
124-
val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
125-
val translatedGroupBys = groupBy.flatMap(columnAsString)
126-
127-
if (translatedAggregates.length != aggregates.length ||
128-
translatedGroupBys.length != groupBy.length) {
129-
return None
130-
}
131-
132-
val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
133-
Some(agg).filter(r.pushAggregation)
134-
case _ => None
135-
}
136-
}
137-
138104
/**
139105
* Pushes down TableSample to the data source Scan
140106
*/

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 83 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,18 @@ package org.apache.spark.sql.execution.datasources.v2
1919

2020
import scala.collection.mutable
2121

22-
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
22+
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
2323
import org.apache.spark.sql.catalyst.expressions.aggregate
2424
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2525
import org.apache.spark.sql.catalyst.planning.ScanOperation
2626
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
2727
import org.apache.spark.sql.catalyst.rules.Rule
2828
import org.apache.spark.sql.connector.expressions.SortOrder
29-
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
29+
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, GeneralAggregateFunc}
3030
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
3131
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
3232
import org.apache.spark.sql.sources
33-
import org.apache.spark.sql.types.{DataType, LongType, StructType}
33+
import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
3434
import org.apache.spark.sql.util.SchemaUtils._
3535

3636
object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -86,27 +86,68 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
8686
case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
8787
if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
8888
sHolder.builder match {
89-
case _: SupportsPushDownAggregates =>
89+
case r: SupportsPushDownAggregates =>
9090
val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
91-
var ordinal = 0
92-
val aggregates = resultExpressions.flatMap { expr =>
93-
expr.collect {
94-
// Do not push down duplicated aggregate expressions. For example,
95-
// `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
96-
// `max(a)` to the data source.
97-
case agg: AggregateExpression
98-
if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
99-
aggExprToOutputOrdinal(agg.canonicalized) = ordinal
100-
ordinal += 1
101-
agg
102-
}
103-
}
91+
val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal)
10492
val normalizedAggregates = DataSourceStrategy.normalizeExprs(
10593
aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
10694
val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
10795
groupingExpressions, sHolder.relation.output)
108-
val pushedAggregates = PushDownUtils.pushAggregates(
109-
sHolder.builder, normalizedAggregates, normalizedGroupingExpressions)
96+
val translatedAggregates = DataSourceStrategy.translateAggregation(
97+
normalizedAggregates, normalizedGroupingExpressions)
98+
val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
99+
if (translatedAggregates.isEmpty ||
100+
r.supportCompletePushDown(translatedAggregates.get) ||
101+
translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
102+
(resultExpressions, aggregates, translatedAggregates)
103+
} else {
104+
// scalastyle:off
105+
// The data source doesn't support the complete push-down of this aggregation.
106+
// Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be
107+
// pushed, completely or partially.
108+
// e.g. TABLE t (c1 INT, c2 INT, c3 INT)
109+
// SELECT avg(c1) FROM t GROUP BY c2;
110+
// The original logical plan is
111+
// Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19]
112+
// +- ScanOperation[...]
113+
//
114+
// After convert avg(c1#9) to sum(c1#9)/count(c1#9)
115+
// we have the following
116+
// Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
117+
// +- ScanOperation[...]
118+
// scalastyle:on
119+
val newResultExpressions = resultExpressions.map { expr =>
120+
expr.transform {
121+
case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
122+
val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
123+
val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct)
124+
// Closely follow `Average.evaluateExpression`
125+
avg.dataType match {
126+
case _: YearMonthIntervalType =>
127+
If(EqualTo(count, Literal(0L)),
128+
Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
129+
case _: DayTimeIntervalType =>
130+
If(EqualTo(count, Literal(0L)),
131+
Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count))
132+
case _ =>
133+
// TODO deal with the overflow issue
134+
Divide(addCastIfNeeded(sum, avg.dataType),
135+
addCastIfNeeded(count, avg.dataType), false)
136+
}
137+
}
138+
}.asInstanceOf[Seq[NamedExpression]]
139+
// Because aggregate expressions changed, translate them again.
140+
aggExprToOutputOrdinal.clear()
141+
val newAggregates =
142+
collectAggregates(newResultExpressions, aggExprToOutputOrdinal)
143+
val newNormalizedAggregates = DataSourceStrategy.normalizeExprs(
144+
newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
145+
(newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation(
146+
newNormalizedAggregates, normalizedGroupingExpressions))
147+
}
148+
}
149+
150+
val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation)
110151
if (pushedAggregates.isEmpty) {
111152
aggNode // return original plan node
112153
} else if (!supportPartialAggPushDown(pushedAggregates.get) &&
@@ -129,7 +170,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
129170
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
130171
// scalastyle:on
131172
val newOutput = scan.readSchema().toAttributes
132-
assert(newOutput.length == groupingExpressions.length + aggregates.length)
173+
assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
133174
val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
134175
case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
135176
case (_, b) => b
@@ -164,7 +205,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
164205
Project(projectExpressions, scanRelation)
165206
} else {
166207
val plan = Aggregate(
167-
output.take(groupingExpressions.length), resultExpressions, scanRelation)
208+
output.take(groupingExpressions.length), finalResultExpressions, scanRelation)
168209

169210
// scalastyle:off
170211
// Change the optimized logical plan to reflect the pushed down aggregate
@@ -210,16 +251,33 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
210251
}
211252
}
212253

254+
/**
 * Collects the distinct aggregate expressions appearing in `resultExpressions`,
 * recording each one's output ordinal in `aggExprToOutputOrdinal` (keyed by the
 * canonicalized expression).
 */
private def collectAggregates(resultExpressions: Seq[NamedExpression],
    aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = {
  var nextOrdinal = 0
  resultExpressions.flatMap { resultExpr =>
    resultExpr.collect {
      // Do not push down duplicated aggregate expressions. For example,
      // `SELECT max(a) + 1, max(a) + 2 FROM ...` should push only one
      // `max(a)` to the data source.
      case aggExpr: AggregateExpression
          if !aggExprToOutputOrdinal.contains(aggExpr.canonicalized) =>
        aggExprToOutputOrdinal(aggExpr.canonicalized) = nextOrdinal
        nextOrdinal += 1
        aggExpr
    }
  }
}
270+
213271
private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
  // Partial aggregate push-down requires knowing each function's agg buffer;
  // `GeneralAggregateFunc` buffers are opaque, so reject if any is present.
  !agg.aggregateExpressions().exists(_.isInstanceOf[GeneralAggregateFunc])
}
217275

218-
private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
219-
if (aggAttribute.dataType == aggDataType) {
220-
aggAttribute
276+
private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) =
277+
if (expression.dataType == expectedDataType) {
278+
expression
221279
} else {
222-
Cast(aggAttribute, aggDataType)
280+
Cast(expression, expectedDataType)
223281
}
224282

225283
def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform {

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
3030
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
3131
import org.apache.spark.sql.connector.catalog.TableChange
3232
import org.apache.spark.sql.connector.catalog.TableChange._
33-
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
33+
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
3434
import org.apache.spark.sql.errors.QueryCompilationErrors
3535
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
3636
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -217,10 +217,11 @@ abstract class JdbcDialect extends Serializable with Logging{
217217
Some(s"SUM($distinct$column)")
218218
case _: CountStar =>
219219
Some("COUNT(*)")
220-
case f: GeneralAggregateFunc if f.name() == "AVG" =>
221-
assert(f.inputs().length == 1)
222-
val distinct = if (f.isDistinct) "DISTINCT " else ""
223-
Some(s"AVG($distinct${f.inputs().head})")
220+
case avg: Avg =>
221+
if (avg.column.fieldNames.length != 1) return None
222+
val distinct = if (avg.isDistinct) "DISTINCT " else ""
223+
val column = quoteIdentifier(avg.column.fieldNames.head)
224+
Some(s"AVG($distinct$column)")
224225
case _ => None
225226
}
226227
}

0 commit comments

Comments
 (0)