[SPARK-37839][SQL] DS V2 supports partial aggregate push-down AVG (#35130)
Avg.java (new file, package `org.apache.spark.sql.connector.expressions.aggregate`):

```java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connector.expressions.aggregate;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
 * An aggregate function that returns the mean of all the values in a group.
 *
 * @since 3.3.0
 */
@Evolving
public final class Avg implements AggregateFunc {
  private final NamedReference column;
  private final boolean isDistinct;

  public Avg(NamedReference column, boolean isDistinct) {
    this.column = column;
    this.isDistinct = isDistinct;
  }

  public NamedReference column() { return column; }
  public boolean isDistinct() { return isDistinct; }

  @Override
  public String toString() {
    if (isDistinct) {
      return "AVG(DISTINCT " + column.describe() + ")";
    } else {
      return "AVG(" + column.describe() + ")";
    }
  }
}
```
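For context, a minimal sketch of how a connector-facing caller might construct and render this expression. `FieldReference` is assumed as the concrete `NamedReference` (it is the implementation the JDBC test suite below imports); the column name is hypothetical:

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.Avg

// Build AVG(DISTINCT SALARY) as a connector-side aggregate expression.
val avgExpr = new Avg(FieldReference("SALARY"), true)

// A JDBC-style connector can splice the rendered form into generated SQL.
assert(avgExpr.isDistinct)
assert(avgExpr.toString == "AVG(DISTINCT SALARY)")
```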
V2ScanRelationPushDown.scala (package `org.apache.spark.sql.execution.datasources.v2`):

```diff
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

 import scala.collection.mutable

-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation

@@ -88,25 +88,49 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     sHolder.builder match {
       case r: SupportsPushDownAggregates =>
         val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
-        var ordinal = 0
-        val aggregates = resultExpressions.flatMap { expr =>
-          expr.collect {
-            // Do not push down duplicated aggregate expressions. For example,
-            // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
-            // `max(a)` to the data source.
-            case agg: AggregateExpression
-                if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
-              aggExprToOutputOrdinal(agg.canonicalized) = ordinal
-              ordinal += 1
-              agg
-          }
-        }
+        val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal)
         val normalizedAggregates = DataSourceStrategy.normalizeExprs(
           aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
         val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
           groupingExpressions, sHolder.relation.output)
-        val pushedAggregates = PushDownUtils.pushAggregates(
-          r, normalizedAggregates, normalizedGroupingExpressions)
+        val translatedAggregates = DataSourceStrategy.translateAggregation(
+          normalizedAggregates, normalizedGroupingExpressions)
+        val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
+          if (translatedAggregates.isEmpty ||
+              r.supportCompletePushDown(translatedAggregates.get)) {
+            (resultExpressions, aggregates, translatedAggregates)
+          } else {
+            // The data source doesn't support the complete push-down of this aggregation.
+            // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be
+            // pushed, completely or partially.
+            var findAverage = false
+            val newResultExpressions = resultExpressions.map { expr =>
+              expr.transform {
+                case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
+                  findAverage = true
+                  val left = addCastIfNeeded(aggregate.Sum(avg.child)
+                    .toAggregateExpression(isDistinct), avg.dataType)
+                  val right = addCastIfNeeded(aggregate.Count(avg.child)
+                    .toAggregateExpression(isDistinct), avg.dataType)
+                  Divide(left, right)
+              }
+            }.asInstanceOf[Seq[NamedExpression]]
+            if (findAverage) {
+              // Because aggregate expressions changed, translate them again.
+              aggExprToOutputOrdinal.clear()
+              val newAggregates =
+                collectAggregates(newResultExpressions, aggExprToOutputOrdinal)
+              val newNormalizedAggregates = DataSourceStrategy.normalizeExprs(
+                newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
+              (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation(
+                newNormalizedAggregates, normalizedGroupingExpressions))
+            } else {
+              (resultExpressions, aggregates, translatedAggregates)
+            }
+          }
+        }
+
+        val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation)
         if (pushedAggregates.isEmpty) {
           aggNode // return original plan node
         } else if (!supportPartialAggPushDown(pushedAggregates.get) &&
@@ -129,7 +153,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
           //      +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
           // scalastyle:on
           val newOutput = scan.readSchema().toAttributes
-          assert(newOutput.length == groupingExpressions.length + aggregates.length)
+          assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
           val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
             case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
             case (_, b) => b

@@ -164,25 +188,25 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
             Project(projectExpressions, scanRelation)
           } else {
             val plan = Aggregate(
-              output.take(groupingExpressions.length), resultExpressions, scanRelation)
+              output.take(groupingExpressions.length), finalResultExpressions, scanRelation)

             // scalastyle:off
             // Change the optimized logical plan to reflect the pushed down aggregate
             // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
-            //      SELECT min(c1), max(c1) FROM t GROUP BY c2;
+            //      SELECT min(c1), max(c1), avg(c1) FROM t GROUP BY c2;
             // The original logical plan is
-            //   Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+            //   Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18, avg(c1#9) AS avg(c1)#19]
             //   +- RelationV2[c1#9, c2#10] ...
             //
-            // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
+            // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22, sum(c1)#23, count(c1)#24]
             // we have the following
-            //   !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
-            //   +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+            //   !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18, sum(c1#9)/count(c1#9) AS avg(c1)#19]
+            //   +- RelationV2[c2#10, min(c1)#21, max(c1)#22, sum(c1)#23, count(c1)#24] ...
             //
             // We want to change it to
             // == Optimized Logical Plan ==
-            //   Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
-            //   +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+            //   Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18, sum(sum(c1)#23)/sum(count(c1)#24) AS avg(c1)#19]
+            //   +- RelationV2[c2#10, min(c1)#21, max(c1)#22, sum(c1)#23, count(c1)#24] ...
             // scalastyle:on
             plan.transformExpressions {
               case agg: AggregateExpression =>

@@ -210,16 +234,33 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     }
   }

+  private def collectAggregates(
+      resultExpressions: Seq[NamedExpression],
+      aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = {
+    var ordinal = 0
+    resultExpressions.flatMap { expr =>
+      expr.collect {
+        // Do not push down duplicated aggregate expressions. For example,
+        // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
+        // `max(a)` to the data source.
+        case agg: AggregateExpression
+            if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
+          aggExprToOutputOrdinal(agg.canonicalized) = ordinal
+          ordinal += 1
+          agg
+      }
+    }
+  }
+
   private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
     // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
     agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc])
   }

-  private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
-    if (aggAttribute.dataType == aggDataType) {
-      aggAttribute
+  private def addCastIfNeeded(expression: Expression, aggDataType: DataType) =
```

A reviewer left a suggested change on the new `addCastIfNeeded` signature, renaming its second parameter:

```diff
-  private def addCastIfNeeded(expression: Expression, aggDataType: DataType) =
+  private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) =
```
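Why the `AVG` to `SUM / COUNT` rewrite enables partial push-down: AVG is not decomposable across partitions, since an average of per-partition averages is generally wrong, whereas per-partition sums and counts recombine exactly. That is what the final plan's `sum(sum(c1))/sum(count(c1))` computes. A minimal standalone sketch in plain Scala; the partition split is hypothetical, and the values mirror the test data below:

```scala
// Five salaries split across two partitions (hypothetical split).
val p1 = Seq(9000.0, 12000.0)
val p2 = Seq(10000.0, 10000.0, 12000.0)

// Wrong: averaging the per-partition averages.
val avgOfAvgs = (p1.sum / p1.size + p2.sum / p2.size) / 2   // ~10583.33

// Right: re-aggregating pushed-down partial SUMs and COUNTs.
val partialSums = Seq(p1.sum, p2.sum)       // Seq(21000.0, 32000.0)
val partialCounts = Seq(p1.size, p2.size)   // Seq(2, 3)
val trueAvg = partialSums.sum / partialCounts.sum   // 10600.0, the AVG over all rows

assert(avgOfAvgs != trueAvg)
```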
JDBCV2Suite.scala:

```diff
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
 import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{lit, sum, udf}
+import org.apache.spark.sql.functions.{avg, count, lit, sum, udf}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils

@@ -874,4 +874,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     checkAnswer(df, Seq(Row(2)))
     // scalastyle:on
   }
+
+  test("scan with aggregate push-down: partial push-down SUM, AVG, COUNT") {
+    val df = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+    checkAggregateRemoved(df, false)
+    checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5)))
+
+    val df2 = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .groupBy($"name")
+      .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+    checkAggregateRemoved(df2, false)
+    checkAnswer(df2, Seq(
+      Row("alex", 12000.00, 12000.000000, 1),
+      Row("amy", 10000.00, 10000.000000, 1),
+      Row("cathy", 9000.00, 9000.000000, 1),
+      Row("david", 10000.00, 10000.000000, 1),
+      Row("jen", 12000.00, 12000.000000, 1)))
+  }
 }
```

On the first `checkAggregateRemoved(df, false)`, a contributor commented: "This only verifies the aggregates are not completely pushed down. We still need to check if the partial push down works, right?"
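One way to address that comment, sketched under the assumption that the JDBC scan reports its pushed-down aggregates in the plan description and that `checkKeywordsExistsInExplain` from the suite's `ExplainSuiteHelper` mixin applies here; the exact `PushedAggregates` rendering is an assumption, not taken from this diff:

```scala
// Hypothetical follow-up assertion inside the test: besides checking that a
// final Aggregate remains (i.e. the push-down is only partial), confirm the
// scan itself carries the rewritten SUM/COUNT aggregates in its explain output.
checkKeywordsExistsInExplain(df, "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]")
```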