diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index 3c9626409860..4a1769daa79c 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -211,6 +211,13 @@ the following case-insensitive options:
Specifies kerberos principal name for the JDBC client. If both keytab and principal are defined then Spark tries to do kerberos authentication.
+
+
+ pushDownAggregate |
+
+    The option to enable or disable aggregate push-down into the JDBC data source. The default value is false, in which case Spark does NOT push down aggregates to the JDBC data source. Otherwise, if set to true, aggregates are pushed down to the JDBC data source and evaluated there instead of by Spark. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Note that aggregates are pushed down if and only if all the aggregates and the related filters can be pushed down.
+ |
+
Note that kerberos authentication with keytab is not always supported by the JDBC driver.
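
The new option can be exercised like any other JDBC reader option. Below is a minimal sketch from a spark-shell session; the connection URL and table name are placeholders, not part of this change:

```scala
// Hypothetical connection details; the pushDownAggregate option is the relevant part.
val employees = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db-host:5432/hr")
  .option("dbtable", "employee")
  .option("pushDownAggregate", "true")
  .load()

// With the option enabled, and all aggregates and filters translatable, the MAX/MIN below
// may be evaluated by the database instead of by Spark.
employees.where("dept > 0").groupBy("dept").agg("salary" -> "max", "bonus" -> "min").show()
```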
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
new file mode 100644
index 000000000000..40ed146114ff
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.sources.Aggregation;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down aggregates to the data source.
+ *
+ * @since 3.2.0
+ */
+@Evolving
+public interface SupportsPushDownAggregates extends ScanBuilder {
+
+ /**
+   * Pushes down the given aggregation to the data source. The aggregation is pushed
+   * down only if all of its aggregate functions can be pushed down.
+ */
+ void pushAggregation(Aggregation aggregation);
+
+ /**
+   * Returns the aggregation that was pushed to the data source via
+   * {@link #pushAggregation(Aggregation)}.
+ */
+ Aggregation pushedAggregation();
+}
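
For context, here is a minimal sketch of how a connector's ScanBuilder might adopt the new mix-in. The class name, the `supported` helper, and the pass-through schema are illustrative assumptions, not part of this patch; the point is the all-or-nothing contract described in the Javadoc above.

```scala
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}
import org.apache.spark.sql.sources.{AggregateFunc, Aggregation, Count, Max, Min}
import org.apache.spark.sql.types.StructType

class SimpleScanBuilder(schema: StructType) extends ScanBuilder with SupportsPushDownAggregates {

  private var pushed: Aggregation = Aggregation.empty

  // Only aggregate functions this source knows how to evaluate can be accepted.
  private def supported(f: AggregateFunc): Boolean = f match {
    case _: Min | _: Max | _: Count => true
    case _ => false
  }

  override def pushAggregation(aggregation: Aggregation): Unit = {
    // Accept the aggregation only if every aggregate function is supported; otherwise
    // keep Aggregation.empty so Spark evaluates the aggregate itself.
    if (aggregation.aggregateExpressions.forall(supported)) {
      pushed = aggregation
    }
  }

  override def pushedAggregation(): Aggregation = pushed

  override def build(): Scan = new Scan {
    // A real connector would narrow this to the pushed aggregates plus group-by columns.
    override def readSchema(): StructType = schema
  }
}
```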
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 189d21603e70..7cfae63fb057 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -17,11 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types._
// scalastyle:off line.size.limit
@ExpressionDescription(
@@ -44,52 +40,7 @@ import org.apache.spark.sql.types._
group = "agg_funcs",
since = "1.0.0")
// scalastyle:on line.size.limit
-case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
-
- override def nullable: Boolean = false
-
- // Return data type.
- override def dataType: DataType = LongType
-
- override def checkInputDataTypes(): TypeCheckResult = {
- if (children.isEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_PARAMETERLESS_COUNT)) {
- TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument. " +
- s"If you have to call the function $prettyName without arguments, set the legacy " +
- s"configuration `${SQLConf.ALLOW_PARAMETERLESS_COUNT.key}` as true")
- } else {
- TypeCheckResult.TypeCheckSuccess
- }
- }
-
- protected lazy val count = AttributeReference("count", LongType, nullable = false)()
-
- override lazy val aggBufferAttributes = count :: Nil
-
- override lazy val initialValues = Seq(
- /* count = */ Literal(0L)
- )
-
- override lazy val mergeExpressions = Seq(
- /* count = */ count.left + count.right
- )
-
- override lazy val evaluateExpression = count
-
- override def defaultResult: Option[Literal] = Option(Literal(0L))
-
- override lazy val updateExpressions = {
- val nullableChildren = children.filter(_.nullable)
- if (nullableChildren.isEmpty) {
- Seq(
- /* count = */ count + 1L
- )
- } else {
- Seq(
- /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
- )
- }
- }
-}
+case class Count(children: Seq[Expression]) extends CountBase(children)
object Count {
def apply(child: Expression): Count = Count(child :: Nil)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala
new file mode 100644
index 000000000000..457c84e9655d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+abstract class CountBase(children: Seq[Expression]) extends DeclarativeAggregate {
+
+ override def nullable: Boolean = false
+
+ // Return data type.
+ override def dataType: DataType = LongType
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.isEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_PARAMETERLESS_COUNT)) {
+ TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument. " +
+ s"If you have to call the function $prettyName without arguments, set the legacy " +
+ s"configuration `${SQLConf.ALLOW_PARAMETERLESS_COUNT.key}` as true")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ protected lazy val count = AttributeReference("count", LongType, nullable = false)()
+
+ override lazy val aggBufferAttributes = count :: Nil
+
+ override lazy val initialValues = Seq(
+ /* count = */ Literal(0L)
+ )
+
+ override lazy val mergeExpressions = Seq(
+ /* count = */ count.left + count.right
+ )
+
+ override lazy val evaluateExpression = count
+
+ override def defaultResult: Option[Literal] = Option(Literal(0L))
+
+ override lazy val updateExpressions = {
+ val nullableChildren = children.filter(_.nullable)
+ if (nullableChildren.isEmpty) {
+ Seq(
+ /* count = */ count + 1L
+ )
+ } else {
+ Seq(
+ /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
+ )
+ }
+ }
+}
+
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala
new file mode 100644
index 000000000000..7fc58a14351d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.LongType
+
+case class PushDownCount(children: Seq[Expression]) extends CountBase(children) {
+
+ override protected lazy val count =
+ AttributeReference("PushDownCount", LongType, nullable = false)()
+
+ override lazy val updateExpressions = {
+ Seq(
+      // When count is pushed down to the data source layer, each incoming row already
+      // carries a partial count computed by the data source, so add that value instead
+      // of incrementing by 1 per row.
+ /* count = */ count + children.head
+ )
+ }
+}
+
+object PushDownCount {
+ def apply(child: Expression): PushDownCount =
+ PushDownCount(child :: Nil)
+}
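
The behavioural difference from the regular Count is easiest to see with plain numbers (the values below are made up for illustration):

```scala
// Without push-down, Count adds 1 for every input row. With push-down, each row returned by
// the data source already carries a partial COUNT, so PushDownCount sums those partials.
val partialCountsFromSource = Seq(3L, 5L, 4L) // hypothetical per-partition COUNT(*) results
val total = partialCountsFromSource.sum       // 12, the same answer as counting all 12 rows
```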
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala
new file mode 100644
index 000000000000..ded1b9cb52ee
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.apache.spark.sql.types.DataType
+
+case class Aggregation(aggregateExpressions: Seq[AggregateFunc],
+ groupByExpressions: Seq[String])
+
+abstract class AggregateFunc
+
+case class Avg(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc
+case class Min(column: String, dataType: DataType) extends AggregateFunc
+case class Max(column: String, dataType: DataType) extends AggregateFunc
+case class Sum(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc
+case class Count(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc
+
+object Aggregation {
+  // Returns an empty Aggregation
+ def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String])
+}
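
For orientation, this is roughly what the representation looks like for a query such as `SELECT MAX(SALARY), MIN(BONUS) FROM employee GROUP BY DEPT` after translation (the column names and data types are illustrative):

```scala
import org.apache.spark.sql.sources.{Aggregation, Max, Min}
import org.apache.spark.sql.types.{DecimalType, DoubleType}

// Two aggregate functions plus one group-by column; Aggregation.empty signals that
// nothing was pushed down.
val pushed = Aggregation(
  aggregateExpressions = Seq(Max("SALARY", DecimalType(20, 2)), Min("BONUS", DoubleType)),
  groupByExpressions = Seq("DEPT"))
```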
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 38e63d425bb2..33cc083711ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.{BaseRelation, Filter}
+import org.apache.spark.sql.sources.{Aggregation, BaseRelation, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils
@@ -102,6 +102,7 @@ case class RowDataSourceScanExec(
requiredSchema: StructType,
filters: Set[Filter],
handledFilters: Set[Filter],
+ aggregation: Aggregation,
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
tableIdentifier: Option[TableIdentifier])
@@ -132,9 +133,17 @@ case class RowDataSourceScanExec(
val markedFilters = for (filter <- filters) yield {
if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
}
+ val markedAggregates = for (aggregate <- aggregation.aggregateExpressions) yield {
+ s"*$aggregate"
+ }
+ val markedGroupby = for (groupby <- aggregation.groupByExpressions) yield {
+ s"*$groupby"
+ }
Map(
"ReadSchema" -> requiredSchema.catalogString,
- "PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
+ "PushedFilters" -> markedFilters.mkString("[", ", ", "]"),
+ "PushedAggregates" -> markedAggregates.mkString("[", ", ", "]"),
+ "PushedGroupby" -> markedGroupby.mkString("[", ", ", "]"))
}
// Don't care about `rdd` and `tableIdentifier` when canonicalizing.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a097017222b5..b5bd5b4a647a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{CacheTable, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project, UncacheTable}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -358,6 +359,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
+ Aggregation.empty,
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -431,6 +433,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
+ Aggregation.empty,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -453,6 +456,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
+ Aggregation.empty,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -700,6 +704,49 @@ object DataSourceStrategy
(nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
}
+ private def columnAsString(e: Expression): String = e match {
+ case AttributeReference(name, _, _, _) => name
+    case Cast(child, _, _) => columnAsString(child)
+ case Add(left, right, _) =>
+ columnAsString(left) + " + " + columnAsString(right)
+ case Subtract(left, right, _) =>
+ columnAsString(left) + " - " + columnAsString(right)
+ case Multiply(left, right, _) =>
+ columnAsString(left) + " * " + columnAsString(right)
+ case Divide(left, right, _) =>
+ columnAsString(left) + " / " + columnAsString(right)
+    case CheckOverflow(child, _, _) => columnAsString(child)
+    case PromotePrecision(child) => columnAsString(child)
+ case _ => ""
+ }
+
+ protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
+ aggregates.aggregateFunction match {
+ case min: aggregate.Min =>
+ val colName = columnAsString(min.child)
+ if (colName.nonEmpty) Some(Min(colName, min.dataType)) else None
+ case max: aggregate.Max =>
+ val colName = columnAsString(max.child)
+ if (colName.nonEmpty) Some(Max(colName, max.dataType)) else None
+ case avg: aggregate.Average =>
+ val colName = columnAsString(avg.child)
+ if (colName.nonEmpty) Some(Avg(colName, avg.dataType, aggregates.isDistinct)) else None
+ case sum: aggregate.Sum =>
+ val colName = columnAsString(sum.child)
+ if (colName.nonEmpty) Some(Sum(colName, sum.dataType, aggregates.isDistinct)) else None
+ case count: aggregate.Count =>
+ val columnName = count.children.head match {
+ case Literal(_, _) => "1"
+ case _ => columnAsString(count.children.head)
+ }
+      if (columnName.nonEmpty) {
+        Some(Count(columnName, count.dataType, aggregates.isDistinct))
+      } else {
+        None
+      }
+ case _ => None
+ }
+ }
+
/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/
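
As a quick illustration of the mapping performed by `translateAggregate`: a Catalyst `min` over a plain column reference translates into the source-level `Min`, while anything `columnAsString` cannot express yields `None`, so the aggregate stays in Spark. The sketch below uses Catalyst's internal API and the `protected[sql]` helper, so it only compiles from within the `org.apache.spark.sql` package (for example in a test):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate.{Min => CatalystMin}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.types.IntegerType

val salary = AttributeReference("salary", IntegerType)()
val minExpr = CatalystMin(salary).toAggregateExpression()

// Expected to be Some(Min("salary", IntegerType)) in the org.apache.spark.sql.sources model;
// an untranslatable child expression would instead produce None.
val translated = DataSourceStrategy.translateAggregate(minExpr)
```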
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 6e8b7ea67826..8c6518bb4962 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -191,6 +191,9 @@ class JDBCOptions(
// An option to allow/disallow pushing down predicate into JDBC data source
val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean
+ // An option to allow/disallow pushing down aggregate into JDBC data source
+ val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean
+
// The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
// by --files option of spark-submit or manually
val keytab = {
@@ -260,6 +263,7 @@ object JDBCOptions {
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
+ val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
val JDBC_TABLE_COMMENT = newOption("tableComment")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 87ca78db59b2..026b36cf2954 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.execution.datasources.jdbc
import java.sql.{Connection, PreparedStatement, ResultSet}
+import java.util.StringTokenizer
+import scala.collection.mutable.ArrayBuilder
import scala.util.control.NonFatal
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
@@ -133,6 +135,68 @@ object JDBCRDD extends Logging {
})
}
+ private def containsArithmeticOp(col: String): Boolean =
+ col.contains("+") || col.contains("-") || col.contains("*") || col.contains("/")
+
+ def compileAggregates(
+ aggregates: Seq[AggregateFunc],
+ dialect: JdbcDialect): (Array[String], Array[DataType]) = {
+ def quote(colName: String): String = dialect.quoteIdentifier(colName)
+ val aggBuilder = ArrayBuilder.make[String]
+ val dataTypeBuilder = ArrayBuilder.make[DataType]
+ aggregates.map {
+ case Min(column, dataType) =>
+ dataTypeBuilder += dataType
+ if (!containsArithmeticOp(column)) {
+ aggBuilder += s"MIN(${quote(column)})"
+ } else {
+ aggBuilder += s"MIN(${quoteEachCols(column, dialect)})"
+ }
+ case Max(column, dataType) =>
+ dataTypeBuilder += dataType
+ if (!containsArithmeticOp(column)) {
+ aggBuilder += s"MAX(${quote(column)})"
+ } else {
+ aggBuilder += s"MAX(${quoteEachCols(column, dialect)})"
+ }
+ case Sum(column, dataType, isDistinct) =>
+ val distinct = if (isDistinct) "DISTINCT " else ""
+ dataTypeBuilder += dataType
+ if (!containsArithmeticOp(column)) {
+ aggBuilder += s"SUM(${distinct} ${quote(column)})"
+ } else {
+ aggBuilder += s"SUM(${distinct} ${quoteEachCols(column, dialect)})"
+ }
+ case Avg(column, dataType, isDistinct) =>
+ val distinct = if (isDistinct) "DISTINCT " else ""
+ dataTypeBuilder += dataType
+ if (!containsArithmeticOp(column)) {
+ aggBuilder += s"AVG(${distinct} ${quote(column)})"
+ } else {
+ aggBuilder += s"AVG(${distinct} ${quoteEachCols(column, dialect)})"
+ }
+ case Count(column, dataType, isDistinct) =>
+ val distinct = if (isDistinct) "DISTINCT " else ""
+ dataTypeBuilder += dataType
+ val col = if (column.equals("1")) column else quote(column)
+ aggBuilder += s"COUNT(${distinct} $col)"
+ case _ =>
+ }
+ (aggBuilder.result, dataTypeBuilder.result)
+ }
+
+  private def quoteEachCols(column: String, dialect: JdbcDialect): String = {
+ def quote(colName: String): String = dialect.quoteIdentifier(colName)
+ val colsBuilder = ArrayBuilder.make[String]
+ val st = new StringTokenizer(column, "+-*/", true)
+ colsBuilder += quote(st.nextToken().trim)
+ while (st.hasMoreTokens) {
+ colsBuilder += st.nextToken
+ colsBuilder += quote(st.nextToken().trim)
+ }
+ colsBuilder.result.mkString(" ")
+ }
+
/**
* Build and return JDBCRDD from the given information.
*
@@ -152,7 +216,9 @@ object JDBCRDD extends Logging {
requiredColumns: Array[String],
filters: Array[Filter],
parts: Array[Partition],
- options: JDBCOptions): RDD[InternalRow] = {
+ options: JDBCOptions,
+      aggregation: Aggregation = Aggregation.empty): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
@@ -164,7 +230,8 @@ object JDBCRDD extends Logging {
filters,
parts,
url,
- options)
+ options,
+ aggregation)
}
}
@@ -181,7 +248,8 @@ private[jdbc] class JDBCRDD(
filters: Array[Filter],
partitions: Array[Partition],
url: String,
- options: JDBCOptions)
+ options: JDBCOptions,
+ aggregation: Aggregation = Aggregation.empty)
extends RDD[InternalRow](sc, Nil) {
/**
@@ -189,15 +257,44 @@ private[jdbc] class JDBCRDD(
*/
override def getPartitions: Array[Partition] = partitions
+ private val (updatedSchema, updatedCol): (StructType, Array[String]) =
+ if (aggregation.aggregateExpressions.isEmpty) {
+ (schema, columns)
+ } else {
+ getAggregateSchemaAndCol
+ }
+
/**
* `columns`, but as a String suitable for injection into a SQL query.
*/
private val columnList: String = {
val sb = new StringBuilder()
- columns.foreach(x => sb.append(",").append(x))
+ updatedCol.foreach(x => sb.append(",").append(x))
if (sb.isEmpty) "1" else sb.substring(1)
}
+ /**
+   * Builds the pruned schema and column list used for aggregate push-down:
+   * the compiled aggregate expressions followed by the group-by columns.
+ */
+ private def getAggregateSchemaAndCol(): (StructType, Array[String]) = {
+ var updatedSchema: StructType = new StructType()
+ val (compiledAgg, aggDataType) =
+ JDBCRDD.compileAggregates(aggregation.aggregateExpressions, JdbcDialects.get(url))
+ val colDataTypeMap: Map[String, StructField] = columns.zip(schema.fields).toMap
+ val newColsBuilder = ArrayBuilder.make[String]
+ for ((col, dataType) <- compiledAgg.zip(aggDataType)) {
+ newColsBuilder += col
+ updatedSchema = updatedSchema.add(col, dataType)
+ }
+ for (groupBy <- aggregation.groupByExpressions) {
+ val quotedGroupBy = JdbcDialects.get(url).quoteIdentifier(groupBy)
+ newColsBuilder += quotedGroupBy
+ updatedSchema = updatedSchema.add(colDataTypeMap.get(quotedGroupBy).get)
+ }
+ (updatedSchema, newColsBuilder.result)
+ }
+
/**
* `filters`, but as a WHERE clause suitable for injection into a SQL query.
*/
@@ -221,6 +318,18 @@ private[jdbc] class JDBCRDD(
}
}
+ /**
+ * A GROUP BY clause representing pushed-down grouping columns.
+ */
+ private def getGroupByClause: String = {
+ if (aggregation.groupByExpressions.length > 0) {
+ val quotedColumns = aggregation.groupByExpressions.map(JdbcDialects.get(url).quoteIdentifier)
+ s"GROUP BY ${quotedColumns.mkString(", ")}"
+ } else {
+ ""
+ }
+ }
+
/**
* Runs the SQL query against the JDBC driver.
*
@@ -296,13 +405,14 @@ private[jdbc] class JDBCRDD(
val myWhereClause = getWhereClause(part)
- val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause"
+ val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
+ s" $getGroupByClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
stmt.setQueryTimeout(options.queryTimeout)
rs = stmt.executeQuery()
- val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
+ val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, updatedSchema, inputMetrics)
CompletionIterator[InternalRow, Iterator[InternalRow]](
new InterruptibleIterator(context, rowsIterator), close())
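
To make the SQL generation concrete, here is a rough sketch of what `compileAggregates` produces for a few source-level aggregates; identifier quoting comes from the `JdbcDialect`, and the exact whitespace differs slightly from what is shown in the comment. `JDBCRDD` is internal API, so this is purely illustrative:

```scala
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources.{Max, Min, Sum}
import org.apache.spark.sql.types.{DecimalType, DoubleType}

val dialect = JdbcDialects.get("jdbc:h2:mem:testdb") // falls back to the default quoting rules
val (compiledAggs, aggTypes) = JDBCRDD.compileAggregates(
  Seq(
    Max("SALARY", DecimalType(20, 2)),
    Min("BONUS", DoubleType),
    Sum("SALARY + BONUS", DecimalType(30, 2), isDistinct = false)),
  dialect)

// compiledAggs is expected to look roughly like:
//   Array(MAX("SALARY"), MIN("BONUS"), SUM("SALARY" + "BONUS"))
// JDBCRDD then splices these, plus the quoted group-by columns, into
//   SELECT <aggregates>, <group-by columns> FROM <table> [WHERE ...] GROUP BY <group-by columns>
```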
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 2f1ee0f23d45..97a36de9f8b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -249,6 +249,7 @@ private[sql] case class JDBCRelation(
jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
extends BaseRelation
with PrunedFilteredScan
+ with PrunedFilteredAggregateScan
with InsertableRelation {
override def sqlContext: SQLContext = sparkSession.sqlContext
@@ -275,6 +276,21 @@ private[sql] case class JDBCRelation(
jdbcOptions).asInstanceOf[RDD[Row]]
}
+ override def buildScan(
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ aggregation: Aggregation): RDD[Row] = {
+ // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
+ JDBCRDD.scanTable(
+ sparkSession.sparkContext,
+ schema,
+ requiredColumns,
+ filters,
+ parts,
+ jdbcOptions,
+ aggregation).asInstanceOf[RDD[Row]]
+ }
+
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 976c7df841dd..607402f9f5b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -86,7 +86,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(project, filters,
- relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) =>
+ relation @ DataSourceV2ScanRelation(_,
+ V1ScanWrapper(scan, translated, pushed, aggregation), output)) =>
val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext)
if (v1Relation.schema != scan.readSchema()) {
throw new IllegalArgumentException(
@@ -101,6 +102,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
output.toStructType,
translated.toSet,
pushed.toSet,
+ aggregation,
unsafeRowRDD,
v1Relation,
tableIdentifier = None)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 167ba45b888a..6ba64ca5e99f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.Aggregation
import org.apache.spark.sql.types.StructType
object PushDownUtils extends PredicateHelper {
@@ -70,6 +72,43 @@ object PushDownUtils extends PredicateHelper {
}
}
+ /**
+ * Pushes down aggregates to the data source reader
+ *
+ * @return pushed aggregation.
+ */
+ def pushAggregates(
+ scanBuilder: ScanBuilder,
+ aggregates: Seq[AggregateExpression],
+ groupBy: Seq[Expression]): Aggregation = {
+
+ def columnAsString(e: Expression): String = e match {
+ case AttributeReference(name, _, _, _) => name
+ case _ => ""
+ }
+
+ scanBuilder match {
+ case r: SupportsPushDownAggregates =>
+ val translatedAggregates = mutable.ArrayBuffer.empty[sources.AggregateFunc]
+
+ for (aggregateExpr <- aggregates) {
+ val translated = DataSourceStrategy.translateAggregate(aggregateExpr)
+ if (translated.isEmpty) {
+ return Aggregation.empty
+ } else {
+ translatedAggregates += translated.get
+ }
+ }
+ val groupByCols = groupBy.map(columnAsString(_))
+ if (!groupByCols.exists(_.isEmpty)) {
+ r.pushAggregation(Aggregation(translatedAggregates, groupByCols))
+ }
+ r.pushedAggregation
+
+ case _ => Aggregation.empty
+ }
+ }
+
/**
* Applies column pruning to the data source, w.r.t. the references of the given expressions.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index d2180566790a..bbd8fccfbd38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -17,38 +17,132 @@
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
import org.apache.spark.sql.types.StructType
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
+
import DataSourceV2Implicits._
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
- case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
- val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+ case Aggregate(groupingExpressions, resultExpressions, child) =>
+ child match {
+ case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+ val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
- val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output)
- val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
- normalizedFilters.partition(SubqueryExpression.hasSubquery)
+ val aliasMap = getAliasMap(project)
+ var aggregates = resultExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression =>
+ replaceAlias(agg, aliasMap).asInstanceOf[AggregateExpression]
+ }
+ }
+ aggregates = DataSourceStrategy.normalizeExprs(aggregates, relation.output)
+ .asInstanceOf[Seq[AggregateExpression]]
- // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
- // `postScanFilters` need to be evaluated after the scan.
- // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
- val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters(
- scanBuilder, normalizedFiltersWithoutSubquery)
- val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery
+ val groupingExpressionsWithoutAlias = groupingExpressions.flatMap{ expr =>
+ expr.collect {
+ case e: Expression => replaceAlias(e, aliasMap)
+ }
+ }
+ val normalizedGroupingExpressions =
+ DataSourceStrategy.normalizeExprs(groupingExpressionsWithoutAlias, relation.output)
+
+ var newFilters = filters
+ aggregates.foreach(agg =>
+ if (agg.filter.nonEmpty) {
+ // handle agg filter the same way as other filters
+ newFilters = newFilters :+ agg.filter.get
+ }
+ )
+
+ val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, newFilters, relation)
+ if (postScanFilters.nonEmpty) {
+ Aggregate(groupingExpressions, resultExpressions, child)
+          } else { // only push down aggregates if all the filters can be pushed down
+ val aggregation = PushDownUtils.pushAggregates(scanBuilder, aggregates,
+ normalizedGroupingExpressions)
+
+ val (scan, output, normalizedProjects) =
+ processFilterAndColumn(scanBuilder, project, postScanFilters, relation)
+
+ logInfo(
+ s"""
+ |Pushing operators to ${relation.name}
+ |Pushed Filters: ${pushedFilters.mkString(", ")}
+ |Post-Scan Filters: ${postScanFilters.mkString(",")}
+ |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")}
+ |Pushed Groupby: ${aggregation.groupByExpressions.mkString(", ")}
+ |Output: ${output.mkString(", ")}
+ """.stripMargin)
+
+ val wrappedScan = scan match {
+ case v1: V1Scan =>
+ val translated = newFilters.flatMap(DataSourceStrategy.translateFilter(_, true))
+ V1ScanWrapper(v1, translated, pushedFilters, aggregation)
+ case _ => scan
+ }
+
+ if (aggregation.aggregateExpressions.isEmpty) {
+ Aggregate(groupingExpressions, resultExpressions, child)
+ } else {
+ val aggOutputBuilder = ArrayBuilder.make[AttributeReference]
+ for (i <- 0 until aggregates.length) {
+ aggOutputBuilder += AttributeReference(
+ aggregation.aggregateExpressions(i).toString, aggregates(i).dataType)()
+ }
+          groupingExpressions.foreach {
+            case a @ AttributeReference(_, _, _, _) => aggOutputBuilder += a
+ case _ =>
+ }
+ val aggOutput = aggOutputBuilder.result
+
+ val r = buildLogicalPlan(aggOutput, relation, wrappedScan, aggOutput,
+ normalizedProjects, postScanFilters)
+ val plan = Aggregate(groupingExpressions, resultExpressions, r)
+
+ var i = 0
+ plan.transformExpressions {
+ case agg: AggregateExpression =>
+ i += 1
+              val aggFunction: aggregate.AggregateFunction = agg.aggregateFunction match {
+                case _: aggregate.Max => aggregate.Max(aggOutput(i - 1))
+                case _: aggregate.Min => aggregate.Min(aggOutput(i - 1))
+                case _: aggregate.Average => aggregate.Average(aggOutput(i - 1))
+                case _: aggregate.Sum => aggregate.Sum(aggOutput(i - 1))
+                case _: aggregate.Count => aggregate.PushDownCount(aggOutput(i - 1))
+                case other => other
+              }
+              agg.copy(aggregateFunction = aggFunction, filter = None)
+ }
+ }
+ }
+
+ case _ => Aggregate(groupingExpressions, resultExpressions, child)
+ }
+ case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
+ val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
+      val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, filters, relation)
+ val (scan, output, normalizedProjects) =
+ processFilterAndColumn(scanBuilder, project, postScanFilters, relation)
- val normalizedProjects = DataSourceStrategy
- .normalizeExprs(project, relation.output)
- .asInstanceOf[Seq[NamedExpression]]
- val (scan, output) = PushDownUtils.pruneColumns(
- scanBuilder, relation, normalizedProjects, postScanFilters)
logInfo(
s"""
|Pushing operators to ${relation.name}
@@ -60,31 +154,72 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
val wrappedScan = scan match {
case v1: V1Scan =>
val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true))
- V1ScanWrapper(v1, translated, pushedFilters)
+ V1ScanWrapper(v1, translated, pushedFilters,
+ Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]))
+
case _ => scan
}
- val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output)
+ buildLogicalPlan(project, relation, wrappedScan, output, normalizedProjects, postScanFilters)
+ }
- val projectionOverSchema = ProjectionOverSchema(output.toStructType)
- val projectionFunc = (expr: Expression) => expr transformDown {
- case projectionOverSchema(newExpr) => newExpr
- }
+ private def pushDownFilter(
+ scanBuilder: ScanBuilder,
+ filters: Seq[Expression],
+ relation: DataSourceV2Relation): (Seq[sources.Filter], Seq[Expression]) = {
+ val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output)
+ val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
+ normalizedFilters.partition(SubqueryExpression.hasSubquery)
- val filterCondition = postScanFilters.reduceLeftOption(And)
- val newFilterCondition = filterCondition.map(projectionFunc)
- val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation)
-
- val withProjection = if (withFilter.output != project) {
- val newProjects = normalizedProjects
- .map(projectionFunc)
- .asInstanceOf[Seq[NamedExpression]]
- Project(newProjects, withFilter)
- } else {
- withFilter
- }
+ // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
+ // `postScanFilters` need to be evaluated after the scan.
+ // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
+ val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters(
+ scanBuilder, normalizedFiltersWithoutSubquery)
+ val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery
+ (pushedFilters, postScanFilters)
+ }
+
+ private def processFilterAndColumn(
+ scanBuilder: ScanBuilder,
+ project: Seq[NamedExpression],
+ postScanFilters: Seq[Expression],
+ relation: DataSourceV2Relation):
+ (Scan, Seq[AttributeReference], Seq[NamedExpression]) = {
+ val normalizedProjects = DataSourceStrategy
+ .normalizeExprs(project, relation.output)
+ .asInstanceOf[Seq[NamedExpression]]
+ val (scan, output) = PushDownUtils.pruneColumns(
+ scanBuilder, relation, normalizedProjects, postScanFilters)
+ (scan, output, normalizedProjects)
+ }
- withProjection
+ private def buildLogicalPlan(
+ project: Seq[NamedExpression],
+ relation: DataSourceV2Relation,
+ wrappedScan: Scan,
+ output: Seq[AttributeReference],
+ normalizedProjects: Seq[NamedExpression],
+ postScanFilters: Seq[Expression]): LogicalPlan = {
+ val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output)
+ val projectionOverSchema = ProjectionOverSchema(output.toStructType)
+ val projectionFunc = (expr: Expression) => expr transformDown {
+ case projectionOverSchema(newExpr) => newExpr
+ }
+
+ val filterCondition = postScanFilters.reduceLeftOption(And)
+ val newFilterCondition = filterCondition.map(projectionFunc)
+ val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation)
+
+ val withProjection = if (withFilter.output != project) {
+ val newProjects = normalizedProjects
+ .map(projectionFunc)
+ .asInstanceOf[Seq[NamedExpression]]
+ Project(newProjects, withFilter)
+ } else {
+ withFilter
+ }
+ withProjection
}
}
@@ -93,6 +228,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
case class V1ScanWrapper(
v1Scan: V1Scan,
translatedFilters: Seq[sources.Filter],
- handledFilters: Seq[sources.Filter]) extends Scan {
+ handledFilters: Seq[sources.Filter],
+ pushedAggregates: sources.Aggregation) extends Scan {
override def readSchema(): StructType = v1Scan.readSchema()
}
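
With a JDBC catalog configured the way the test suite below does it (the catalog name `h2` and `pushDownAggregate=true` are taken from that suite), the effect of this rule can be observed directly in the explain output, where the scan node reports the pushed aggregates and grouping columns:

```scala
// Purely illustrative, run from a session with such a catalog registered.
val q = spark.sql("SELECT DEPT, MAX(SALARY), MIN(BONUS) FROM h2.test.employee GROUP BY DEPT")
q.explain(true)
// The physical plan's scan is then expected to show something like
//   PushedAggregates: [*Max(SALARY,...), *Min(BONUS,...)], PushedGroupby: [*DEPT]
// as captured in the commented explain output in the test below.
```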
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index 860232ba84f3..d8c29aeb1921 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -20,13 +20,14 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
-import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
+import org.apache.spark.sql.sources.{Aggregation, BaseRelation, Filter, TableScan}
import org.apache.spark.sql.types.StructType
case class JDBCScan(
relation: JDBCRelation,
prunedSchema: StructType,
- pushedFilters: Array[Filter]) extends V1Scan {
+ pushedFilters: Array[Filter],
+ pushedAggregation: Aggregation) extends V1Scan {
override def readSchema(): StructType = prunedSchema
@@ -36,14 +37,15 @@ case class JDBCScan(
override def schema: StructType = prunedSchema
override def needConversion: Boolean = relation.needConversion
override def buildScan(): RDD[Row] = {
- relation.buildScan(prunedSchema.map(_.name).toArray, pushedFilters)
+ relation.buildScan(prunedSchema.map(_.name).toArray, pushedFilters, pushedAggregation)
}
}.asInstanceOf[T]
}
override def description(): String = {
super.description() + ", prunedSchema: " + seqToString(prunedSchema) +
- ", PushedFilters: " + seqToString(pushedFilters)
+ ", PushedFilters: " + seqToString(pushedFilters) +
+      ", PushedAggregates: " + seqToString(pushedAggregation.aggregateExpressions)
}
private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 270c5b6d92e3..907d57b1a4e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -17,23 +17,26 @@
package org.apache.spark.sql.execution.datasources.v2.jdbc
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
import org.apache.spark.sql.jdbc.JdbcDialects
-import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.sources.{Aggregation, Filter}
import org.apache.spark.sql.types.StructType
case class JDBCScanBuilder(
session: SparkSession,
schema: StructType,
jdbcOptions: JDBCOptions)
- extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns {
+ extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns
+ with SupportsPushDownAggregates {
private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
private var pushedFilter = Array.empty[Filter]
+ private var pushedAggregations = Aggregation.empty
+
private var prunedSchema = schema
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
@@ -49,6 +52,19 @@ case class JDBCScanBuilder(
override def pushedFilters(): Array[Filter] = pushedFilter
+ override def pushAggregation(aggregation: Aggregation): Unit = {
+ if (jdbcOptions.pushDownAggregate) {
+ val dialect = JdbcDialects.get(jdbcOptions.url)
+ // push down if all the aggregates are supported by the underlying Data Source
+ if (JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect)._1.length ==
+ aggregation.aggregateExpressions.size) {
+ pushedAggregations = aggregation
+ }
+ }
+ }
+
+ override def pushedAggregation(): Aggregation = pushedAggregations
+
override def pruneColumns(requiredSchema: StructType): Unit = {
// JDBC doesn't support nested column pruning.
// TODO (SPARK-32593): JDBC support nested column and nested column pruning.
@@ -65,6 +81,7 @@ case class JDBCScanBuilder(
val resolver = session.sessionState.conf.resolver
val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
- JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter)
+ JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session),
+ prunedSchema, pushedFilter, pushedAggregation)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 63e57c6804e1..f1a3a616595b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -273,6 +273,16 @@ trait PrunedFilteredScan {
def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row]
}
+/**
+ * @since 3.2.0
+ */
+trait PrunedFilteredAggregateScan {
+ def buildScan(
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ aggregation: Aggregation): RDD[Row]
+}
+
/**
* A BaseRelation that can be used to insert data into it through the insert method.
* If overwrite in insert method is true, the old data in the relation should be overwritten with
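
Here is a minimal sketch of a V1 relation opting into the new trait; the class name and the empty-RDD body are placeholders, not part of this patch. When an `Aggregation` is supplied, the returned rows must already have the aggregated shape, i.e. one column per aggregate followed by the group-by columns, as `JDBCRelation` does above.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{Aggregation, BaseRelation, Filter, PrunedFilteredAggregateScan}
import org.apache.spark.sql.types.StructType

class MyAggRelation(val sqlContext: SQLContext, val schema: StructType)
  extends BaseRelation with PrunedFilteredAggregateScan {

  override def buildScan(
      requiredColumns: Array[String],
      filters: Array[Filter],
      aggregation: Aggregation): RDD[Row] = {
    // Delegate to whatever engine backs the relation and return pre-aggregated rows;
    // an empty RDD stands in for that here.
    sqlContext.sparkContext.emptyRDD[Row]
  }
}
```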
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index a3a3f4728095..f08c88de3d04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{avg, lit, sum, udf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
@@ -41,6 +41,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
.set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.h2.url", url)
.set("spark.sql.catalog.h2.driver", "org.h2.Driver")
+ .set("spark.sql.catalog.h2.pushDownAggregate", "true")
private def withConnection[T](f: Connection => T): T = {
val conn = DriverManager.getConnection(url, new Properties())
@@ -64,6 +65,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred', 1)").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate()
+ conn.prepareStatement(
+ "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," +
+ " bonus DOUBLE)").executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)")
+ .executeUpdate()
}
}
@@ -109,6 +123,318 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
checkAnswer(df, Row("mary"))
}
+ test("aggregate pushdown with alias") {
+ val df1 = spark.table("h2.test.employee")
+    val query1 = df1.select($"DEPT", $"SALARY".as("value"))
+ .groupBy($"DEPT")
+ .agg(sum($"value").as("total"))
+ .filter($"total" > 1000)
+ // query1.explain(true)
+ checkAnswer(query1, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000)))
+ val decrease = udf { (x: Double, y: Double) => x - y}
+    val query2 = df1.select($"DEPT", decrease($"SALARY", $"BONUS").as("value"), $"SALARY", $"BONUS")
+ .groupBy($"DEPT")
+ .agg(sum($"value"), sum($"SALARY"), sum($"BONUS"))
+ // query2.explain(true)
+ checkAnswer(query2,
+ Seq(Row(1, 16800.00, 19000.00, 2200.00), Row(2, 19500.00, 22000.00, 2500.00),
+ Row(6, 10800, 12000, 1200)))
+
+ val cols = Seq("a", "b", "c", "d")
+ val df2 = sql("select * from h2.test.employee").toDF(cols: _*)
+ val df3 = df2.groupBy().sum("c")
+ // df3.explain(true)
+ checkAnswer(df3, Seq(Row(53000.00)))
+
+ val df4 = df2.groupBy($"a").sum("c")
+ checkAnswer(df4, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000)))
+ }
+
+ test("scan with aggregate push-down") {
+ val df1 = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
+ " group by DEPT")
+ // df1.explain(true)
+ // scalastyle:off line.size.limit
+ // == Parsed Logical Plan ==
+ // 'Aggregate ['DEPT], [unresolvedalias('MAX('SALARY), None), unresolvedalias('MIN('BONUS), None)]
+ // +- 'Filter ('dept > 0)
+ // +- 'UnresolvedRelation [h2, test, employee], [], false
+ //
+ // == Analyzed Logical Plan ==
+ // max(SALARY): decimal(20,2), min(BONUS): decimal(6,2)
+ // Aggregate [DEPT#253], [max(SALARY#255) AS max(SALARY)#259, min(BONUS#256) AS min(BONUS)#260]
+ // +- Filter (dept#253 > 0)
+ // +- SubqueryAlias h2.test.employee
+ // +- RelationV2[DEPT#253, NAME#254, SALARY#255, BONUS#256] test.employee
+ //
+ // == Optimized Logical Plan ==
+ // Aggregate [DEPT#253], [max(Max(SALARY,DecimalType(20,2))#266) AS max(SALARY)#259, min(Min(BONUS,DecimalType(6,2))#267) AS min(BONUS)#260]
+ // +- RelationV2[Max(SALARY,DecimalType(20,2))#266, Min(BONUS,DecimalType(6,2))#267, DEPT#253] test.employee
+ //
+ // == Physical Plan ==
+ // AdaptiveSparkPlan isFinalPlan=false
+ // +- HashAggregate(keys=[DEPT#253], functions=[max(Max(SALARY,DecimalType(20,2))#266), min(Min(BONUS,DecimalType(6,2))#267)], output=[max(SALARY)#259, min(BONUS)#260])
+ // +- Exchange hashpartitioning(DEPT#253, 5), ENSURE_REQUIREMENTS, [id=#397]
+ // +- HashAggregate(keys=[DEPT#253], functions=[partial_max(Max(SALARY,DecimalType(20,2))#266), partial_min(Min(BONUS,DecimalType(6,2))#267)], output=[DEPT#253, max#270, min#271])
+    //  +- Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@30437e9c [Max(SALARY,DecimalType(20,2))#266,Min(BONUS,DecimalType(6,2))#267,DEPT#253] PushedAggregates: [*Max(SALARY,DecimalType(20,2)), *Min(BONUS,DecimalType(6,2))], PushedFilters: [IsNotNull(dept), GreaterThan(dept,0)], PushedGroupby: [*DEPT], ReadSchema: struct
+    // scalastyle:on line.size.limit
+    checkAnswer(df1, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200)))
+
+    val df2 = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0")
+ // df2.explain(true)
+ // scalastyle:off line.size.limit
+ // == Parsed Logical Plan ==
+ // 'Project [unresolvedalias('MAX('ID), None), unresolvedalias('MIN('ID), None)]
+ // +- 'Filter ('id > 0)
+ // +- 'UnresolvedRelation [h2, test, people], [], false
+ //
+ // == Analyzed Logical Plan ==
+ // max(ID): int, min(ID): int
+ // Aggregate [max(ID#290) AS max(ID)#293, min(ID#290) AS min(ID)#294]
+ // +- Filter (id#290 > 0)
+ // +- SubqueryAlias h2.test.people
+ // +- RelationV2[NAME#289, ID#290] test.people
+ //
+ // == Optimized Logical Plan ==
+ // Aggregate [max(Max(ID,IntegerType)#298) AS max(ID)#293, min(Min(ID,IntegerType)#299) AS min(ID)#294]
+ // +- RelationV2[Max(ID,IntegerType)#298, Min(ID,IntegerType)#299] test.people
+ //
+ // == Physical Plan ==
+ // AdaptiveSparkPlan isFinalPlan=false
+ // +- HashAggregate(keys=[], functions=[max(Max(ID,IntegerType)#298), min(Min(ID,IntegerType)#299)], output=[max(ID)#293, min(ID)#294])
+ // +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#469]
+ // +- HashAggregate(keys=[], functions=[partial_max(Max(ID,IntegerType)#298), partial_min(Min(ID,IntegerType)#299)], output=[max#302, min#303])
+ // +- Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@1368e2f7 [Max(ID,IntegerType)#298,Min(ID,IntegerType)#299] PushedAggregates: [*Max(ID,IntegerType), *Min(ID,IntegerType)], PushedFilters: [IsNotNull(id), GreaterThan(id,0)], PushedGroupby: [], ReadSchema: struct
+ // scalastyle:on line.size.limit
+ //
+ // df2.show()
+ // +-------+-------+
+ // |max(ID)|min(ID)|
+ // +-------+-------+
+ // | 2| 1|
+ // +-------+-------+
+ checkAnswer(df2, Seq(Row(2, 1)))
+
+ val df3 = sql("select AVG(BONUS) FROM h2.test.employee")
+ // df3.explain(true)
+ // scalastyle:off line.size.limit
+ // == Parsed Logical Plan ==
+ // 'Project [unresolvedalias('AVG('BONUS), None)]
+ // +- 'UnresolvedRelation [h2, test, employee], [], false
+ //
+ // == Analyzed Logical Plan ==
+ // avg(BONUS): double
+ // Aggregate [avg(BONUS#69) AS avg(BONUS)#71]
+ // +- SubqueryAlias h2.test.employee
+ // +- RelationV2[DEPT#66, NAME#67, SALARY#68, BONUS#69] test.employee
+ //
+ // == Optimized Logical Plan ==
+ // Aggregate [avg(Avg(BONUS,DoubleType,false)#74) AS avg(BONUS)#71]
+ // +- RelationV2[Avg(BONUS,DoubleType,false)#74] test.employee
+ //
+ // == Physical Plan ==
+ // AdaptiveSparkPlan isFinalPlan=false
+ // +- HashAggregate(keys=[], functions=[avg(Avg(BONUS,DoubleType,false)#74)], output=[avg(BONUS)#71])
+ // +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#143]
+ // +- HashAggregate(keys=[], functions=[partial_avg(Avg(BONUS,DoubleType,false)#74)], output=[sum#77, count#78L])
+ // +- Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@6a1d1467 [Avg(BONUS,DoubleType,false)#74] PushedAggregates: [*Avg(BONUS,DoubleType,false)], PushedFilters: [], PushedGroupby: [], ReadSchema: struct
+ // scalastyle:on line.size.limit
+ // df3.show(false)
+ checkAnswer(df3, Seq(Row(1180.0)))
+
+ val df4 = sql("select MAX(SALARY) + 1 FROM h2.test.employee")
+ // df4.explain(true)
+ // scalastyle:off line.size.limit
+ // == Parsed Logical Plan ==
+ // 'Project [unresolvedalias(('MAX('SALARY) + 1), None)]
+ // +- 'UnresolvedRelation [h2, test, employee], [], false
+ //
+ // == Analyzed Logical Plan ==
+ // (max(SALARY) + 1): decimal(21,2)
+ // Aggregate [CheckOverflow((promote_precision(cast(max(SALARY#345) as decimal(21,2))) + promote_precision(cast(cast(1 as decimal(1,0)) as decimal(21,2)))), DecimalType(21,2), true) AS (max(SALARY) + 1)#348]
+ // +- SubqueryAlias h2.test.employee
+ // +- RelationV2[DEPT#343, NAME#344, SALARY#345, BONUS#346] test.employee
+ //
+ // == Optimized Logical Plan ==
+ // Aggregate [CheckOverflow((promote_precision(cast(max(Max(SALARY,DecimalType(20,2))#351) as decimal(21,2))) + 1.00), DecimalType(21,2), true) AS (max(SALARY) + 1)#348]
+ // +- RelationV2[Max(SALARY,DecimalType(20,2))#351] test.employee
+ //
+ // == Physical Plan ==
+ // AdaptiveSparkPlan isFinalPlan=false
+ // +- HashAggregate(keys=[], functions=[max(Max(SALARY,DecimalType(20,2))#351)], output=[(max(SALARY) + 1)#348])
+ // +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#589]
+ // +- HashAggregate(keys=[], functions=[partial_max(Max(SALARY,DecimalType(20,2))#351)], output=[max#353])
+ // +- Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@453439e [Max(SALARY,DecimalType(20,2))#351] PushedAggregates: [*Max(SALARY,DecimalType(20,2))], PushedFilters: [], PushedGroupby: [], ReadSchema: struct
+ // scalastyle:on line.size.limit
+ checkAnswer(df4, Seq(Row(12001)))
+
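+ // MIN(SALARY) * MIN(BONUS) is evaluated by Spark on top of the two MIN aggregates,
+ // which are expected to be pushed down (as with MAX/MIN in df2 above).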
+ val df5 = sql("select MIN(SALARY), MIN(BONUS), MIN(SALARY) * MIN(BONUS) FROM h2.test.employee")
+ // df5.explain(true)
+ checkAnswer(df5, Seq(Row(9000, 1000, 9000000)))
+
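+ // SUM over an arithmetic expression of two columns (SALARY * BONUS) alongside plain MIN aggregates.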
+ val df6 = sql("select MIN(salary), MIN(bonus), SUM(SALARY * BONUS) FROM h2.test.employee")
+ // df6.explain(true)
+ checkAnswer(df6, Seq(Row(9000, 1000, 62600000)))
+
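+ // GROUP BY on two columns (SALARY, BONUS) with the grouping columns also in the select list.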
+ val df7 = sql("select BONUS, SUM(SALARY+BONUS), SALARY FROM h2.test.employee" +
+ " GROUP BY SALARY, BONUS")
+ // df7.explain(true)
+ checkAnswer(df7, Seq(Row(1000, 11000, 10000), Row(1200, 26400, 12000),
+ Row(1200, 10200, 9000), Row(1300, 11300, 10000)))
+
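+ // A filter on a UDF result cannot be pushed to the JDBC source, so the aggregate is
+ // expected to be evaluated by Spark itself; the answer must still be correct.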
+ val df8 = spark.table("h2.test.employee")
+ val sub2 = udf { (x: String) => x.substring(0, 3) }
+ val name = udf { (x: String) => x.matches("cat|dav|amy") }
+ val df9 = df8.select($"SALARY", $"BONUS", sub2($"NAME").as("nsub2"))
+ .filter("SALARY > 100")
+ .filter(name($"nsub2"))
+ .agg(avg($"SALARY").as("avg_salary"))
+ // df9.explain(true)
+ checkAnswer(df9, Seq(Row(9666.666667)))
+
+ val df10 = sql("select SUM(SALARY+BONUS*SALARY+SALARY/BONUS), DEPT FROM h2.test.employee" +
+ " GROUP BY DEPT")
+ // df10.explain(true)
+ // scalastyle:off line.size.limit
+ // == Parsed Logical Plan ==
+ // 'Aggregate ['DEPT], [unresolvedalias('SUM((('SALARY + ('BONUS * 'SALARY)) + ('SALARY / 'BONUS))), None), 'DEPT]
+ // +- 'UnresolvedRelation [h2, test, employee], [], false
+ //
+ // == Analyzed Logical Plan ==
+ // sum(((SALARY + (BONUS * SALARY)) + (SALARY / BONUS))): decimal(38,9), DEPT: int
+ // Aggregate [DEPT#551], [sum(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(SALARY#553 as decimal(28,4))) + promote_precision(cast(CheckOverflow((promote_precision(cast(BONUS#554 as decimal(20,2))) * promote_precision(cast(SALARY#553 as decimal(20,2)))), DecimalType(27,4), true) as decimal(28,4)))), DecimalType(28,4), true) as decimal(34,9))) + promote_precision(cast(CheckOverflow((promote_precision(cast(SALARY#553 as decimal(20,2))) / promote_precision(cast(BONUS#554 as decimal(20,2)))), DecimalType(29,9), true) as decimal(34,9)))), DecimalType(34,9), true)) AS sum(((SALARY + (BONUS * SALARY)) + (SALARY / BONUS)))#556, DEPT#551]
+ // +- SubqueryAlias h2.test.employee
+ // +- RelationV2[DEPT#551, NAME#552, SALARY#553, BONUS#554] test.employee
+ //
+ // == Optimized Logical Plan ==
+ // Aggregate [DEPT#551], [sum(Sum(SALARY + BONUS * SALARY + SALARY / BONUS,DecimalType(38,9),false)#562) AS sum(((SALARY + (BONUS * SALARY)) + (SALARY / BONUS)))#556, DEPT#551]
+ // +- RelationV2[Sum(SALARY + BONUS * SALARY + SALARY / BONUS,DecimalType(38,9),false)#562, DEPT#551] test.employee
+ //
+ // == Physical Plan ==
+ // AdaptiveSparkPlan isFinalPlan=false
+ // +- HashAggregate(keys=[DEPT#551], functions=[sum(Sum(SALARY + BONUS * SALARY + SALARY / BONUS,DecimalType(38,9),false)#562)], output=[sum(((SALARY + (BONUS * SALARY)) + (SALARY / BONUS)))#556, DEPT#551])
+ // +- Exchange hashpartitioning(DEPT#551, 5), ENSURE_REQUIREMENTS, [id=#917]
+ // +- HashAggregate(keys=[DEPT#551], functions=[partial_sum(Sum(SALARY + BONUS * SALARY + SALARY / BONUS,DecimalType(38,9),false)#562)], output=[DEPT#551, sum#565, isEmpty#566])
+ // +- Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@7692c0e9 [Sum(SALARY + BONUS * SALARY + SALARY / BONUS,DecimalType(38,9),false)#562,DEPT#551] PushedAggregates: [*Sum(SALARY + BONUS * SALARY + SALARY / BONUS,DecimalType(38,9),false)], PushedFilters: [], PushedGroupby: [*DEPT], ReadSchema: struct
+ // scalastyle:on line.size.limit
+ // df10.show(true)
+ // +-----------------------------------------------------+----+
+ // |sum(((SALARY + (BONUS * SALARY)) + (SALARY / BONUS)))|DEPT|
+ // +-----------------------------------------------------+----+
+ // | 20819017.500000000| 1|
+ // | 27422017.692307692| 2|
+ // | 14412010.000000000| 6|
+ // +-----------------------------------------------------+----+
+ checkAnswer(df10, Seq(Row(20819017.500000000, 1), Row(27422017.692307692, 2),
+ Row(14412010.000000000, 6)))
+
+ val df11 = sql("select COUNT(*), DEPT FROM h2.test.employee group by DEPT")
+ // df11.explain(true)
+ // scalastyle:off line.size.limit
+ // == Parsed Logical Plan ==
+ // 'Aggregate ['DEPT], [unresolvedalias('COUNT(1), None), 'DEPT]
+ // +- 'UnresolvedRelation [h2, test, employee], [], false
+ //
+ // == Analyzed Logical Plan ==
+ // count(1): bigint, DEPT: int
+ // Aggregate [DEPT#602], [count(1) AS count(1)#607L, DEPT#602]
+ // +- SubqueryAlias h2.test.employee
+ // +- RelationV2[DEPT#602, NAME#603, SALARY#604, BONUS#605] test.employee
+ //
+ // == Optimized Logical Plan ==
+ // Aggregate [DEPT#602], [count(Count(1,LongType,false)#611L) AS count(1)#607L, DEPT#602]
+ // +- RelationV2[Count(1,LongType,false)#611L, DEPT#602] test.employee
+ //
+ // == Physical Plan ==
+ // AdaptiveSparkPlan isFinalPlan=false
+ // +- HashAggregate(keys=[DEPT#602], functions=[count(Count(1,LongType,false)#611L)], output=[count(1)#607L, DEPT#602])
+ // +- Exchange hashpartitioning(DEPT#602, 5), ENSURE_REQUIREMENTS, [id=#1029]
+ // +- HashAggregate(keys=[DEPT#602], functions=[partial_count(Count(1,LongType,false)#611L)], output=[DEPT#602, count#613L])
+ // +- Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@5653429e [Count(1,LongType,false)#611L,DEPT#602] PushedAggregates: [*Count(1,LongType,false)], PushedFilters: [], PushedGroupby: [*DEPT], ReadSchema: struct
+ // scalastyle:on line.size.limit
+ // df11.show(true)
+ // +--------+----+
+ // |count(1)|DEPT|
+ // +--------+----+
+ // | 2| 1|
+ // | 2| 2|
+ // | 1| 6|
+ // +--------+----+
+ checkAnswer(df11, Seq(Row(2, 1), Row(2, 2), Row(1, 6)))
+
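+ // COUNT(*) grouped by DEPT without the grouping column in the select list.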
+ val df12 = sql("select COUNT(*) FROM h2.test.employee group by DEPT")
+ // df12.explain(true)
+ checkAnswer(df12, Seq(Row(2), Row(2), Row(1)))
+
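+ // Global COUNT(*) without GROUP BY.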
+ val df13 = sql("select COUNT(*) FROM h2.test.employee")
+ // df13.explain(true)
+ checkAnswer(df13, Seq(Row(5)))
+
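+ // COUNT(NAME) variants (df14-df16): with GROUP BY only, without GROUP BY, and with the
+ // grouping column in the output.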
+ val df14 = sql("select COUNT(NAME) FROM h2.test.employee group by DEPT")
+ // df14.explain(true)
+ checkAnswer(df14, Seq(Row(2), Row(2), Row(1)))
+
+ val df15 = sql("select COUNT(NAME) FROM h2.test.employee")
+ // df15.explain(true)
+ checkAnswer(df15, Seq(Row(5)))
+
+ val df16 = sql("select COUNT(NAME), DEPT FROM h2.test.employee group by DEPT")
+ // df16.explain(true)
+ checkAnswer(df16, Seq(Row(2, 1), Row(2, 2), Row(1, 6)))
+
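+ // Aggregate with a FILTER clause plus a WHERE predicate; the plan below shows the WHERE
+ // filter and the aggregate's FILTER condition both appearing in PushedFilters.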
+ val df17 = sql("select MAX(SALARY) FILTER (WHERE SALARY > 1000), MIN(BONUS) " +
+ "FROM h2.test.employee where dept > 0 group by DEPT")
+ // df17.explain(true)
+ // scalastyle:off line.size.limit
+ // == Parsed Logical Plan ==
+ // 'Aggregate ['DEPT], [unresolvedalias('MAX('SALARY, ('SALARY > 1000)), None), unresolvedalias('MIN('BONUS), None)]
+ // +- 'Filter ('dept > 0)
+ // +- 'UnresolvedRelation [h2, test, employee], [], false
+ //
+ // == Analyzed Logical Plan ==
+ // max(SALARY) FILTER (WHERE (SALARY > 1000)): decimal(20,2), min(BONUS): decimal(6,2)
+ // Aggregate [DEPT#797], [max(SALARY#799) FILTER (WHERE (cast(SALARY#799 as decimal(20,2)) > cast(cast(1000 as decimal(4,0)) as decimal(20,2)))) AS max(SALARY) FILTER (WHERE (SALARY > 1000))#804, min(BONUS#800) AS min(BONUS)#802]
+ // +- Filter (dept#797 > 0)
+ // +- SubqueryAlias h2.test.employee
+ // +- RelationV2[DEPT#797, NAME#798, SALARY#799, BONUS#800] test.employee
+ //
+ // == Optimized Logical Plan ==
+ // Aggregate [DEPT#797], [max(Max(SALARY,DecimalType(20,2))#810) AS max(SALARY) FILTER (WHERE (SALARY > 1000))#804, min(Min(BONUS,DecimalType(6,2))#811) AS min(BONUS)#802]
+ // +- RelationV2[Max(SALARY,DecimalType(20,2))#810, Min(BONUS,DecimalType(6,2))#811, DEPT#797] test.employee
+ //
+ // == Physical Plan ==
+ // AdaptiveSparkPlan isFinalPlan=false
+ // +- HashAggregate(keys=[DEPT#797], functions=[max(Max(SALARY,DecimalType(20,2))#810), min(Min(BONUS,DecimalType(6,2))#811)], output=[max(SALARY) FILTER (WHERE (SALARY > 1000))#804, min(BONUS)#802])
+ // +- Exchange hashpartitioning(DEPT#797, 5), ENSURE_REQUIREMENTS, [id=#1647]
+ // +- HashAggregate(keys=[DEPT#797], functions=[partial_max(Max(SALARY,DecimalType(20,2))#810), partial_min(Min(BONUS,DecimalType(6,2))#811)], output=[DEPT#797, max#814, min#815])
+ // +- Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@239fdf8f [Max(SALARY,DecimalType(20,2))#810,Min(BONUS,DecimalType(6,2))#811,DEPT#797] PushedAggregates: [*Max(SALARY,DecimalType(20,2)), *Min(BONUS,DecimalType(6,2))], PushedFilters: [IsNotNull(dept), GreaterThan(dept,0), *GreaterThan(SALARY,1000.00)], PushedGroupby: [*DEPT], ReadSchema: struct
+ // scalastyle:on line.size.limit
+ // +------------------------------------------+----------+
+ // |max(SALARY) FILTER (WHERE (SALARY > 1000))|min(BONUS)|
+ // +------------------------------------------+----------+
+ // | 10000.00| 1000.00|
+ // | 12000.00| 1200.00|
+ // | 12000.00| 1200.00|
+ // +------------------------------------------+----------+
+ checkAnswer(df17, Seq(Row(10000.00, 1000.00), Row(12000.00, 1200.00), Row(12000.00, 1200.00)))
+ }
+
+ test("scan with aggregate distinct push-down") {
+ checkAnswer(sql("SELECT SUM(SALARY) FROM h2.test.employee"), Seq(Row(53000)))
+ checkAnswer(sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee"), Seq(Row(31000)))
+ checkAnswer(sql("SELECT AVG(DEPT) FROM h2.test.employee"), Seq(Row(2)))
+ checkAnswer(sql("SELECT AVG(DISTINCT DEPT) FROM h2.test.employee"), Seq(Row(3)))
+ }
+
test("read/write with partition info") {
withTable("h2.test.abc") {
sql("CREATE TABLE h2.test.abc AS SELECT * FROM h2.test.people")
@@ -145,7 +471,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
test("show tables") {
checkAnswer(sql("SHOW TABLES IN h2.test"),
- Seq(Row("test", "people", false), Row("test", "empty_table", false)))
+ Seq(Row("test", "people", false), Row("test", "empty_table", false),
+ Row("test", "employee", false)))
}
test("SQL API: create table as select") {