@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
* An aggregate function that returns the mean of all the values in a group.
*
* @since 3.3.0
*/
@Evolving
public final class Avg implements AggregateFunc {
private final NamedReference column;
private final boolean isDistinct;

public Avg(NamedReference column, boolean isDistinct) {
this.column = column;
this.isDistinct = isDistinct;
}

public NamedReference column() { return column; }
public boolean isDistinct() { return isDistinct; }

@Override
public String toString() {
if (isDistinct) {
return "AVG(DISTINCT " + column.describe() + ")";
} else {
return "AVG(" + column.describe() + ")";
}
}
}
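A minimal usage sketch (not part of this PR) of the new Avg node, using only the API shown above; the column name is made up:

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.Avg

// Build the DS v2 representation of AVG(salary) and AVG(DISTINCT salary).
val avg = new Avg(FieldReference.column("salary"), false)
val distinctAvg = new Avg(FieldReference.column("salary"), true)

println(avg)          // AVG(salary)
println(distinctAvg)  // AVG(DISTINCT salary)
```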
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
@@ -720,7 +720,7 @@ object DataSourceStrategy
case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
Some(new Sum(FieldReference.column(name), agg.isDistinct))
case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference.column(name))))
Some(new Avg(FieldReference.column(name), agg.isDistinct))
case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc(
"VAR_POP", agg.isDistinct, Array(FieldReference.column(name))))
@@ -107,14 +107,12 @@ object PushDownUtils extends PredicateHelper {
}

/**
* Pushes down aggregates to the data source reader
* Translate aggregate expressions and group by expressions.
*
* @return pushed aggregation.
* @return translated aggregation.
*/
def pushAggregates(
scanBuilder: SupportsPushDownAggregates,
aggregates: Seq[AggregateExpression],
groupBy: Seq[Expression]): Option[Aggregation] = {
def translateAggregation(
Contributor: shall we put it in DataSourceStrategy?

Contributor Author: OK

aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = {

def columnAsString(e: Expression): Option[FieldReference] = e match {
case PushableColumnWithoutNestedColumn(name) =>
@@ -130,8 +128,17 @@ object PushDownUtils extends PredicateHelper {
return None
}

val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
Some(agg).filter(scanBuilder.pushAggregation)
Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray))
}

/**
* Pushes down aggregates to the data source reader
*
* @return pushed aggregation.
*/
def pushAggregates(
scanBuilder: SupportsPushDownAggregates, aggOpt: Option[Aggregation]): Option[Aggregation] = {
aggOpt.filter(scanBuilder.pushAggregation)
Contributor: Doesn't seem necessary to create a method for only one line of code.

Contributor Author: Yeah. I'm a little uncertain, so I'm waiting for your comment.

}
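A short sketch (hypothetical helper, not in the PR) of how the two methods are meant to compose after the split; V2ScanRelationPushDown below uses exactly this translate-then-push pattern:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.SupportsPushDownAggregates
import org.apache.spark.sql.execution.datasources.v2.PushDownUtils

// Translate catalyst aggregates into the DS v2 representation, then offer the
// result to the scan builder; None means nothing was pushed down.
def tryPushAggregates(
    builder: SupportsPushDownAggregates,
    aggregates: Seq[AggregateExpression],
    groupBy: Seq[Expression]): Option[Aggregation] = {
  val translated = PushDownUtils.translateAggregation(aggregates, groupBy)
  PushDownUtils.pushAggregates(builder, translated)
}
```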

/**
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -87,9 +87,26 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
sHolder.builder match {
case r: SupportsPushDownAggregates =>
val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
groupingExpressions, sHolder.relation.output)
val translatedGroupBys =
PushDownUtils.translateAggregation(Seq.empty, normalizedGroupingExpressions)
Contributor: why can't we put the real aggregate functions here?

Contributor Author (@beliefer, Jan 12, 2022): Because the aggregates are collected from Aggregate. I want to call r.supportCompletePushDown and split AVG into two functions, SUM and COUNT, first, so we collect AggregateExpression just once.

Contributor: This assumes the v2 source only needs the group-by columns to decide if it supports complete pushdown or not. Do you have strong evidence to indicate it's true for all v2 sources?

Contributor Author: Let's pass the aggregate expressions too for all the DS V2.
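For context, a small worked example (not from the PR, made-up data) of why AVG has to be rewritten as SUM/COUNT when only partial aggregation can be pushed down:

```scala
// Two hypothetical source partitions of column c1; each partition can only
// return partial results, and Spark merges them afterwards.
val partitions = Seq(Seq(1.0, 2.0), Seq(3.0))

// Push SUM and COUNT: merge the partial sums and counts, then divide.
val (sum, count) = partitions
  .map(p => (p.sum, p.size))
  .reduce((a, b) => (a._1 + b._1, a._2 + b._2))
val correctAvg = sum / count                                                // 2.0

// Pushing AVG itself and averaging the partial averages would be wrong.
val avgOfAvgs = partitions.map(p => p.sum / p.size).sum / partitions.size   // 2.25
```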

val finalResultExpressions = if (r.supportCompletePushDown(translatedGroupBys.get)) {
resultExpressions
} else {
resultExpressions.map { expr =>
expr.transform {
case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
Divide(Cast(aggregate.Sum(avg.child).toAggregateExpression(isDistinct),
avg.dataType), Cast(
aggregate.Count(avg.child).toAggregateExpression(isDistinct), avg.dataType))
}
}.asInstanceOf[Seq[NamedExpression]]
}

val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
var ordinal = 0
val aggregates = resultExpressions.flatMap { expr =>
val aggregates = finalResultExpressions.flatMap { expr =>
expr.collect {
// Do not push down duplicated aggregate expressions. For example,
// `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
@@ -103,10 +120,10 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
val normalizedAggregates = DataSourceStrategy.normalizeExprs(
aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
groupingExpressions, sHolder.relation.output)
val pushedAggregates = PushDownUtils.pushAggregates(
r, normalizedAggregates, normalizedGroupingExpressions)
val translatedAggregates = PushDownUtils.translateAggregation(
normalizedAggregates, normalizedGroupingExpressions)

val pushedAggregates = PushDownUtils.pushAggregates(r, translatedAggregates)
if (pushedAggregates.isEmpty) {
aggNode // return original plan node
} else if (!supportPartialAggPushDown(pushedAggregates.get) &&
@@ -164,25 +181,25 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
Project(projectExpressions, scanRelation)
} else {
val plan = Aggregate(
output.take(groupingExpressions.length), resultExpressions, scanRelation)
output.take(groupingExpressions.length), finalResultExpressions, scanRelation)

// scalastyle:off
// Change the optimized logical plan to reflect the pushed down aggregate
// e.g. TABLE t (c1 INT, c2 INT, c3 INT)
// SELECT min(c1), max(c1) FROM t GROUP BY c2;
// SELECT min(c1), max(c1), avg(c1) FROM t GROUP BY c2;
Contributor: When we reach here, the AVG translation is already done and it's weird to see we mention it in the comment.

// The original logical plan is
// Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
// Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18, avg(c1#9) AS avg(c1)#19]
// +- RelationV2[c1#9, c2#10] ...
//
// After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
// After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22, sum(c1)#23, count(c1)#24]
// we have the following
// !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
// !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18, sum(c1#9)/count(c1#9) AS avg(c1)#19]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22, sum(c1)#23, count(c1)#24] ...
//
// We want to change it to
// == Optimized Logical Plan ==
// Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
// Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18, sum(sum(c1)#23)/sum(count(c1)#24) AS avg(c1)#19]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22, sum(c1)#23, count(c1)#24] ...
// scalastyle:on
plan.transformExpressions {
case agg: AggregateExpression =>
@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -220,10 +220,11 @@ abstract class JdbcDialect extends Serializable with Logging{
Some(s"SUM($distinct$column)")
case _: CountStar =>
Some("COUNT(*)")
case f: GeneralAggregateFunc if f.name() == "AVG" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"AVG($distinct${f.inputs().head})")
case avg: Avg =>
if (avg.column.fieldNames.length != 1) return None
val distinct = if (avg.isDistinct) "DISTINCT " else ""
val column = quoteIdentifier(avg.column.fieldNames.head)
Some(s"AVG($distinct$column)")
case _ => None
}
}
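A small sketch of the SQL fragments this case would now emit, assuming the hunk above sits inside JdbcDialect.compileAggregate and the default identifier quoting applies; the URL and column are made up:

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.Avg
import org.apache.spark.sql.jdbc.JdbcDialects

val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")
dialect.compileAggregate(new Avg(FieldReference.column("SALARY"), false)) // Some(AVG("SALARY"))
dialect.compileAggregate(new Avg(FieldReference.column("SALARY"), true))  // Some(AVG(DISTINCT "SALARY"))
```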
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.functions.{lit, sum, udf}
import org.apache.spark.sql.functions.{avg, count, lit, max, min, sum, udf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils

@@ -874,4 +874,60 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df, Seq(Row(2)))
// scalastyle:on
}

test("scan with aggregate push-down: partial push-down MAX, AVG, MIN") {
val df = spark.read
.option("partitionColumn", "dept")
.option("lowerBound", "0")
.option("upperBound", "2")
.option("numPartitions", "2")
.table("h2.test.employee")
.agg(max($"SALARY").as("max"), avg($"SALARY").as("avg"), min($"SALARY").as("min"))
checkAggregateRemoved(df, false)
checkAnswer(df, Seq(Row(12000.00, 10600.000000, 9000.00)))

val df2 = spark.read
.option("partitionColumn", "dept")
.option("lowerBound", "0")
.option("upperBound", "2")
.option("numPartitions", "2")
.table("h2.test.employee")
.groupBy($"name")
.agg(max($"SALARY").as("max"), avg($"SALARY").as("avg"), min($"SALARY").as("min"))
checkAggregateRemoved(df2, false)
checkAnswer(df2, Seq(
Row("alex", 12000.00, 12000.000000, 12000.00),
Row("amy", 10000.00, 10000.000000, 10000.00),
Row("cathy", 9000.00, 9000.000000, 9000.00),
Row("david", 10000.00, 10000.000000, 10000.00),
Row("jen", 12000.00, 12000.000000, 12000.00)))
}

test("scan with aggregate push-down: partial push-down SUM, AVG, COUNT") {
val df = spark.read
.option("partitionColumn", "dept")
.option("lowerBound", "0")
.option("upperBound", "2")
.option("numPartitions", "2")
.table("h2.test.employee")
.agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
checkAggregateRemoved(df, false)
Contributor: This only verifies the aggregates are not completely pushed down. We still need to check if the partial push down works, right?

checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5)))

val df2 = spark.read
.option("partitionColumn", "dept")
.option("lowerBound", "0")
.option("upperBound", "2")
.option("numPartitions", "2")
.table("h2.test.employee")
.groupBy($"name")
.agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
checkAggregateRemoved(df2, false)
checkAnswer(df2, Seq(
Row("alex", 12000.00, 12000.000000, 1),
Row("amy", 10000.00, 10000.000000, 1),
Row("cathy", 9000.00, 9000.000000, 1),
Row("david", 10000.00, 10000.000000, 1),
Row("jen", 12000.00, 12000.000000, 1)))
}
}