@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
* An aggregate function that returns the mean of all the values in a group.
*
* @since 3.3.0
*/
@Evolving
public final class Avg implements AggregateFunc {
private final NamedReference column;
private final boolean isDistinct;

public Avg(NamedReference column, boolean isDistinct) {
this.column = column;
this.isDistinct = isDistinct;
}

public NamedReference column() { return column; }
public boolean isDistinct() { return isDistinct; }

@Override
public String toString() {
if (isDistinct) {
return "AVG(DISTINCT " + column.describe() + ")";
} else {
return "AVG(" + column.describe() + ")";
}
}
}
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
@@ -720,7 +720,7 @@ object DataSourceStrategy
case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
Some(new Sum(FieldReference.column(name), agg.isDistinct))
case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference.column(name))))
Some(new Avg(FieldReference.column(name), agg.isDistinct))
case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc(
"VAR_POP", agg.isDistinct, Array(FieldReference.column(name))))
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -129,7 +129,6 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
// scalastyle:on
val newOutput = scan.readSchema().toAttributes
assert(newOutput.length == groupingExpressions.length + aggregates.length)
val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
case (_, b) => b
@@ -169,25 +168,26 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
// scalastyle:off
// Change the optimized logical plan to reflect the pushed down aggregate
// e.g. TABLE t (c1 INT, c2 INT, c3 INT)
// SELECT min(c1), max(c1) FROM t GROUP BY c2;
// SELECT min(c1), max(c1), avg(c1) FROM t GROUP BY c2;
Contributor:
When we reach here, the AVG translation is already done and it's weird to see we mention it in the comment.

// The original logical plan is
// Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
// Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18, avg(c1#9) AS avg(c1)#19]
// +- RelationV2[c1#9, c2#10] ...
//
// After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
// After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22, sum(c1)#23, count(c1)#24]
// we have the following
// !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
// !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18, avg(c1#9) AS avg(c1)#19]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22, sum(c1)#23, count(c1)#24] ...
//
// We want to change it to
// == Optimized Logical Plan ==
// Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
// Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18, sum(sum(c1)#23)/sum(count(c1)#24) AS avg(c1)#19]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22, sum(c1)#23, count(c1)#24] ...
// scalastyle:on
plan.transformExpressions {
var skip = 0
plan.transformExpressionsUp {
case agg: AggregateExpression =>
val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
val aggAttribute = aggOutput(ordinal)
val aggAttribute = aggOutput(ordinal + skip)
val aggFunction: aggregate.AggregateFunction =
agg.aggregateFunction match {
case max: aggregate.Max =>
@@ -200,7 +200,19 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
aggregate.Sum(addCastIfNeeded(aggAttribute, LongType))
case other => other
}
agg.copy(aggregateFunction = aggFunction)

aggFunction match {
case avg: aggregate.Average =>
skip += 1
val countAttribute = aggOutput(ordinal + skip)
val divide = Divide(
aggregate.Sum(addCastIfNeeded(aggAttribute, avg.dataType))
.toAggregateExpression(),
aggregate.Sum(addCastIfNeeded(countAttribute, avg.dataType))
.toAggregateExpression())
Cast(divide, avg.dataType)
case _ => agg.copy(aggregateFunction = aggFunction)
}
}
}
}
@@ -83,7 +83,8 @@ case class JDBCScanBuilder(
if (!jdbcOptions.pushDownAggregate) return false

val dialect = JdbcDialects.get(jdbcOptions.url)
val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate)
val compiledAggs = aggregation.aggregateExpressions.flatMap(
dialect.compileAggregate(_, supportCompletePushDown(aggregation)))
if (compiledAggs.length != aggregation.aggregateExpressions.length) return false

val groupByCols = aggregation.groupByColumns.map { col =>
69 changes: 37 additions & 32 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -28,38 +28,43 @@ private object H2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")

override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "COVAR_POP" =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "CORR" =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})")
case _ => None
override def compileAggregate(
aggFunction: AggregateFunc, supportCompletePushDown: Boolean = true): Option[String] = {
super.compileAggregate(aggFunction, supportCompletePushDown).orElse(
if (supportCompletePushDown) {
aggFunction match {
case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"VAR_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_POP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
case f: GeneralAggregateFunc if f.name() == "COVAR_POP" =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})")
case f: GeneralAggregateFunc if f.name() == "CORR" =>
assert(f.inputs().length == 2)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})")
case _ => None
}
} else {
None
}
)
}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -197,10 +197,12 @@ abstract class JdbcDialect extends Serializable with Logging{
/**
* Converts aggregate function to String representing a SQL expression.
* @param aggFunction The aggregate function to be converted.
* @param supportCompletePushDown Whether the aggregation can be completely pushed down to the data source.
* @return Converted value.
*/
@Since("3.3.0")
def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
def compileAggregate(
aggFunction: AggregateFunc, supportCompletePushDown: Boolean = true): Option[String] = {
aggFunction match {
case min: Min =>
if (min.column.fieldNames.length != 1) return None
Expand All @@ -220,10 +222,16 @@ abstract class JdbcDialect extends Serializable with Logging{
Some(s"SUM($distinct$column)")
case _: CountStar =>
Some("COUNT(*)")
case f: GeneralAggregateFunc if f.name() == "AVG" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"AVG($distinct${f.inputs().head})")
case avg: Avg =>
if (avg.column.fieldNames.length != 1) return None
val distinct = if (avg.isDistinct) "DISTINCT " else ""
val column = quoteIdentifier(avg.column.fieldNames.head)
if (supportCompletePushDown) {
Some(s"AVG($distinct$column)")
} else {
// To simplify the code, we don't reuse existing `SUM` or `COUNT` translations here.
Some(s"SUM($distinct$column), COUNT($distinct$column)")
Contributor:
It's super weird to see this translation at the data source side, not the Spark side.

Contributor:
Since partial agg pushdown needs Spark to do final agg, Spark must be fully aware of the AVG translation, to make the final agg match the data source scan with partial agg pushed.

I think the process should be

  1. Spark translates catalyst Aggregate operator to a DS V2 Aggregation.
  2. Spark calls supportCompletePushDown to check if it can completely push down agg
  3. JDBC source returns false in supportCompletePushDown if AVG is present
  4. Spark gives up complete agg push down, and starts to try partial agg push down
  5. Spark splits AVG into 2 functions: SUM and COUNT, and pushes the Aggregation to JDBC source
  6. Spark constructs the final agg and calculates AVG by SUM / COUNT.
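
A rough Scala sketch of what steps 5 and 6 might look like on the Spark side (a hedged illustration, not the actual change; sumAttr and countAttr are placeholder names for the two attributes the scan would expose for the pushed SUM and COUNT):

import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Divide, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}

// Rebuild avg(c) on top of the pushed SUM(c) and COUNT(c) outputs:
// final value = sum(partial sums) / sum(partial counts), cast to avg's result type.
def rebuildAvg(avg: Average, sumAttr: Attribute, countAttr: Attribute): Expression = {
  val finalSum = Sum(Cast(sumAttr, avg.dataType)).toAggregateExpression()
  val finalCount = Sum(Cast(countAttr, avg.dataType)).toAggregateExpression()
  Cast(Divide(finalSum, finalCount), avg.dataType)
}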

Contributor Author:
OK. Let me have a try.

}
case _ => None
}
}
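
For illustration, the new flag changes the fragment a dialect compiles for AVG roughly as follows (a hedged sketch, not part of this diff; identifier quoting follows the dialect's quoteIdentifier, so the exact strings may differ):

import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.Avg
import org.apache.spark.sql.jdbc.JdbcDialects

val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")  // URL is illustrative
val avgFunc = new Avg(FieldReference.column("SALARY"), false)

// Complete push-down: the remote database can compute the average itself.
dialect.compileAggregate(avgFunc, supportCompletePushDown = true)   // roughly Some("AVG(\"SALARY\")")

// Partial push-down: the source returns SUM and COUNT, and Spark finishes SUM / COUNT.
dialect.compileAggregate(avgFunc, supportCompletePushDown = false)  // roughly Some("SUM(\"SALARY\"), COUNT(\"SALARY\")")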
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.functions.{lit, sum, udf}
import org.apache.spark.sql.functions.{avg, count, lit, max, min, sum, udf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils

@@ -874,4 +874,60 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper
checkAnswer(df, Seq(Row(2)))
// scalastyle:on
}

test("scan with aggregate push-down: partial push-down MAX, AVG, MIN") {
val df = spark.read
.option("partitionColumn", "dept")
.option("lowerBound", "0")
.option("upperBound", "2")
.option("numPartitions", "2")
.table("h2.test.employee")
.agg(max($"SALARY").as("max"), avg($"SALARY").as("avg"), min($"SALARY").as("min"))
checkAggregateRemoved(df, false)
checkAnswer(df, Seq(Row(12000.00, 10600.000000, 9000.00)))

val df2 = spark.read
.option("partitionColumn", "dept")
.option("lowerBound", "0")
.option("upperBound", "2")
.option("numPartitions", "2")
.table("h2.test.employee")
.groupBy($"name")
.agg(max($"SALARY").as("max"), avg($"SALARY").as("avg"), min($"SALARY").as("min"))
checkAggregateRemoved(df2, false)
checkAnswer(df2, Seq(
Row("alex", 12000.00, 12000.000000, 12000.00),
Row("amy", 10000.00, 10000.000000, 10000.00),
Row("cathy", 9000.00, 9000.000000, 9000.00),
Row("david", 10000.00, 10000.000000, 10000.00),
Row("jen", 12000.00, 12000.000000, 12000.00)))
}

test("scan with aggregate push-down: partial push-down SUM, AVG, COUNT") {
val df = spark.read
.option("partitionColumn", "dept")
.option("lowerBound", "0")
.option("upperBound", "2")
.option("numPartitions", "2")
.table("h2.test.employee")
.agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
checkAggregateRemoved(df, false)
Contributor:
This only verifies the aggregates are not completely pushed down. We still need to check if the partial push down works, right?
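
One way to additionally assert that the partial push-down happened might be to check the explain output for the rewritten aggregates (a hedged sketch; checkKeywordsExistsInExplain comes from ExplainSuiteHelper, and the exact PushedAggregates rendering, including identifier quoting, is an assumption):

// Hypothetical extra assertion, not part of this diff.
checkKeywordsExistsInExplain(df, "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]")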

checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5)))

val df2 = spark.read
.option("partitionColumn", "dept")
.option("lowerBound", "0")
.option("upperBound", "2")
.option("numPartitions", "2")
.table("h2.test.employee")
.groupBy($"name")
.agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
checkAggregateRemoved(df2, false)
checkAnswer(df2, Seq(
Row("alex", 12000.00, 12000.000000, 1),
Row("amy", 10000.00, 10000.000000, 1),
Row("cathy", 9000.00, 9000.000000, 1),
Row("david", 10000.00, 10000.000000, 1),
Row("jen", 12000.00, 12000.000000, 1)))
}
}