From 75f8fdc4f70cf60101c511dc62059e42a1d3f04e Mon Sep 17 00:00:00 2001 From: "chang.chen" Date: Tue, 6 Apr 2021 17:49:57 +0800 Subject: [PATCH 1/2] [SPARK-32833][SQL] JDBC V2 Datasource aggregate push down --- .../read/sqlpushdown/SQLStatement.java | 24 ++ .../read/sqlpushdown/SupportsSQL.java | 32 ++ .../read/sqlpushdown/SupportsSQLPushDown.java | 59 ++++ .../execution/datasources/jdbc/JDBCRDD.scala | 57 +--- .../datasources/jdbc/JDBCRelation.scala | 16 + .../datasources/v2/PushDownUtils.scala | 49 ++- .../v2/V2ScanRelationPushDown.scala | 68 +--- .../datasources/v2/jdbc/JDBCScan.scala | 8 +- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 41 ++- .../v2/jdbc/JDBCTableCatalog.scala | 3 +- .../datasources/v2/pushdown/PushQuery.scala | 312 ++++++++++++++++++ .../v2/pushdown/sql/PushDownAggUtils.scala | 252 ++++++++++++++ .../v2/pushdown/sql/SQLBuilder.scala | 178 ++++++++++ .../v2/pushdown/sql/SingleSQLStatement.scala | 156 +++++++++ .../v2/PushDownOptimizeSuite.scala | 206 ++++++++++++ .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 205 +++++++++++- 16 files changed, 1538 insertions(+), 128 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SQLStatement.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SupportsSQL.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SupportsSQLPushDown.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/PushQuery.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/PushDownAggUtils.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/SQLBuilder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/SingleSQLStatement.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOptimizeSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SQLStatement.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SQLStatement.java new file mode 100644 index 0000000000000..f5d05de786eaa --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SQLStatement.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connector.read.sqlpushdown; + +import org.apache.spark.annotation.Evolving; + +@Evolving +public interface SQLStatement { + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SupportsSQL.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SupportsSQL.java new file mode 100644 index 0000000000000..49a63b62d4222 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SupportsSQL.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector.read.sqlpushdown; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.TableCatalog; + +/** + * A mix-in interface for {@link TableCatalog} to indicate that Data sources whether support SQL or + * not + * + * @since 3.x.x + */ + +@Evolving +public interface SupportsSQL extends TableCatalog { + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SupportsSQLPushDown.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SupportsSQLPushDown.java new file mode 100644 index 0000000000000..14f83ffd28377 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/sqlpushdown/SupportsSQLPushDown.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector.read.sqlpushdown; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.SupportsPushDownFilters; +import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources which support SQL can implement this + * interface to push down SQL to backend and reduce the size of the data to be read. 
+ * + * @since 3.x.x + */ + +@Evolving +public interface SupportsSQLPushDown extends ScanBuilder, + SupportsPushDownRequiredColumns, + SupportsPushDownFilters { + + /** + * Returns true if executing the pushed query would result in queries issued to multiple partitions. + * Returns false if it would result in a query to a single partition and therefore produce a global + * result. + */ + boolean isMultiplePartitionExecution(); + + /** + * Pushes down a {@link SQLStatement} to the data source and returns the filters that need to be evaluated + * after scanning. + *
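+ * The outputSchema describes the schema of the rows produced for the pushed statement (for
+ * example, partial aggregates); in this patch it may be null when the output has already been
+ * set through {@link SupportsPushDownRequiredColumns#pruneColumns}.
+ *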
+ * Rows should be returned from the data source if and only if all of the filters match. That is, + * filters must be interpreted as ANDed together. + */ + Filter[] pushStatement(SQLStatement statement, StructType outputSchema); + + /** + * Returns the statement that are pushed to the data source via + * {@link #pushStatement(SQLStatement statement, StructType outputSchema)} + */ + SQLStatement pushedStatement(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 87ca78db59b29..c32308a2565d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.v2.pushdown.sql.SingleSQLStatement import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -152,19 +153,19 @@ object JDBCRDD extends Logging { requiredColumns: Array[String], filters: Array[Filter], parts: Array[Partition], - options: JDBCOptions): RDD[InternalRow] = { - val url = options.url - val dialect = JdbcDialects.get(url) - val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) + options: JDBCOptions, + statement: Option[SingleSQLStatement] = None) + : RDD[InternalRow] = { + val pushedStatement = + statement.getOrElse(SingleSQLStatement(requiredColumns, filters, options)) new JDBCRDD( sc, JdbcUtils.createConnectionFactory(options), pruneSchema(schema, requiredColumns), - quotedColumns, - filters, parts, - url, - options) + options.url, + options, + pushedStatement) } } @@ -177,11 +178,10 @@ private[jdbc] class JDBCRDD( sc: SparkContext, getConnection: () => Connection, schema: StructType, - columns: Array[String], - filters: Array[Filter], partitions: Array[Partition], url: String, - options: JDBCOptions) + options: JDBCOptions, + statement: SingleSQLStatement) extends RDD[InternalRow](sc, Nil) { /** @@ -189,38 +189,6 @@ private[jdbc] class JDBCRDD( */ override def getPartitions: Array[Partition] = partitions - /** - * `columns`, but as a String suitable for injection into a SQL query. - */ - private val columnList: String = { - val sb = new StringBuilder() - columns.foreach(x => sb.append(",").append(x)) - if (sb.isEmpty) "1" else sb.substring(1) - } - - /** - * `filters`, but as a WHERE clause suitable for injection into a SQL query. - */ - private val filterWhereClause: String = - filters - .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url))) - .map(p => s"($p)").mkString(" AND ") - - /** - * A WHERE clause representing both `filters`, if any, and the current partition. - */ - private def getWhereClause(part: JDBCPartition): String = { - if (part.whereClause != null && filterWhereClause.length > 0) { - "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})" - } else if (part.whereClause != null) { - "WHERE " + part.whereClause - } else if (filterWhereClause.length > 0) { - "WHERE " + filterWhereClause - } else { - "" - } - } - /** * Runs the SQL query against the JDBC driver. 
* @@ -294,9 +262,8 @@ private[jdbc] class JDBCRDD( // fully-qualified table name in the SELECT statement. I don't know how to // talk about a table in a completely portable way. - val myWhereClause = getWhereClause(part) + val sqlText = statement.toSQL(part.whereClause) - val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index e6d8819ac29f3..c91d25775511a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkS import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} +import org.apache.spark.sql.execution.datasources.v2.pushdown.sql.SingleSQLStatement import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ @@ -291,6 +292,21 @@ private[sql] case class JDBCRelation( jdbcOptions).asInstanceOf[RDD[Row]] } + def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + statement: SingleSQLStatement): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] + JDBCRDD.scanTable( + sparkSession.sparkContext, + schema, + requiredColumns, + filters, + parts, + jdbcOptions, + Some(statement)).asInstanceOf[RDD[Row]] + } + override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 167ba45b888a3..b6e2e44eb315a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning, SubqueryExpression} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.DataSourceStrategy @@ -81,33 +81,58 @@ object PushDownUtils extends PredicateHelper { relation: DataSourceV2Relation, projects: Seq[NamedExpression], filters: Seq[Expression]): (Scan, Seq[AttributeReference]) = { + prunedColumns(scanBuilder, relation, projects, filters) + .map { prunedSchema => + scanBuilder.asInstanceOf[SupportsPushDownRequiredColumns] + .pruneColumns(prunedSchema) + val scan = 
scanBuilder.build() + scan -> toOutputAttrs(scan.readSchema(), relation)} + .getOrElse(scanBuilder.build() -> relation.output) + } + + def prunedColumns( + scanBuilder: ScanBuilder, + relation: DataSourceV2Relation, + projects: Seq[NamedExpression], + filters: Seq[Expression]): Option[StructType] = { scanBuilder match { - case r: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled => + case _: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled => val rootFields = SchemaPruning.identifyRootFields(projects, filters) val prunedSchema = if (rootFields.nonEmpty) { SchemaPruning.pruneDataSchema(relation.schema, rootFields) } else { new StructType() } - r.pruneColumns(prunedSchema) - val scan = r.build() - scan -> toOutputAttrs(scan.readSchema(), relation) + Some(prunedSchema) - case r: SupportsPushDownRequiredColumns => + case _: SupportsPushDownRequiredColumns => val exprs = projects ++ filters val requiredColumns = AttributeSet(exprs.flatMap(_.references)) val neededOutput = relation.output.filter(requiredColumns.contains) - r.pruneColumns(neededOutput.toStructType) - val scan = r.build() // always project, in case the relation's output has been updated and doesn't match // the underlying table schema - scan -> toOutputAttrs(scan.readSchema(), relation) - - case _ => scanBuilder.build() -> relation.output + Some(neededOutput.toStructType) + case _ => None } } + def pushDownFilter( + scanBuilder: ScanBuilder, + filters: Seq[Expression], + relation: DataSourceV2Relation): (Seq[sources.Filter], Seq[Expression]) = { + val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) + val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = + normalizedFilters.partition(SubqueryExpression.hasSubquery) + + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. + val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( + scanBuilder, normalizedFiltersWithoutSubquery) + val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery + (pushedFilters, postScanFilters) + } - private def toOutputAttrs( + def toOutputAttrs( schema: StructType, relation: DataSourceV2Relation): Seq[AttributeReference] = { val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index d2180566790ac..05b5a95e1cc36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -14,77 +14,19 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression} -import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.read.{Scan, V1Scan} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.v2.pushdown.PushQuery import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType object V2ScanRelationPushDown extends Rule[LogicalPlan] { - import DataSourceV2Implicits._ - - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case ScanOperation(project, filters, relation: DataSourceV2Relation) => - val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) - - val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) - val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = - normalizedFilters.partition(SubqueryExpression.hasSubquery) - - // `pushedFilters` will be pushed down and evaluated in the underlying data sources. - // `postScanFilters` need to be evaluated after the scan. - // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. - val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( - scanBuilder, normalizedFiltersWithoutSubquery) - val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery - - val normalizedProjects = DataSourceStrategy - .normalizeExprs(project, relation.output) - .asInstanceOf[Seq[NamedExpression]] - val (scan, output) = PushDownUtils.pruneColumns( - scanBuilder, relation, normalizedProjects, postScanFilters) - logInfo( - s""" - |Pushing operators to ${relation.name} - |Pushed Filters: ${pushedFilters.mkString(", ")} - |Post-Scan Filters: ${postScanFilters.mkString(",")} - |Output: ${output.mkString(", ")} - """.stripMargin) - - val wrappedScan = scan match { - case v1: V1Scan => - val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true)) - V1ScanWrapper(v1, translated, pushedFilters) - case _ => scan - } - - val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) - - val projectionOverSchema = ProjectionOverSchema(output.toStructType) - val projectionFunc = (expr: Expression) => expr transformDown { - case projectionOverSchema(newExpr) => newExpr - } - - val filterCondition = postScanFilters.reduceLeftOption(And) - val newFilterCondition = filterCondition.map(projectionFunc) - val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation) - - val withProjection = if (withFilter.output != project) { - val newProjects = normalizedProjects - .map(projectionFunc) - .asInstanceOf[Seq[NamedExpression]] - Project(newProjects, withFilter) - } else { - withFilter - } - - withProjection + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case PushQuery(pushed: PushQuery) => + pushed.push() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index 860232ba84f39..a0f09e49ea71e 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -20,13 +20,15 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation +import org.apache.spark.sql.execution.datasources.v2.pushdown.sql.SingleSQLStatement import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan} import org.apache.spark.sql.types.StructType case class JDBCScan( relation: JDBCRelation, prunedSchema: StructType, - pushedFilters: Array[Filter]) extends V1Scan { + pushedFilters: Array[Filter], + pushedStatement: SingleSQLStatement) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -36,12 +38,14 @@ case class JDBCScan( override def schema: StructType = prunedSchema override def needConversion: Boolean = relation.needConversion override def buildScan(): RDD[Row] = { - relation.buildScan(prunedSchema.map(_.name).toArray, pushedFilters) + relation.buildScan(prunedSchema.map(_.name).toArray, pushedFilters, + pushedStatement) } }.asInstanceOf[T] } override def description(): String = { + // TODO: fix description() super.description() + ", prunedSchema: " + seqToString(prunedSchema) + ", PushedFilters: " + seqToString(pushedFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 270c5b6d92e32..76d17fa488ccc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} +import org.apache.spark.sql.connector.read.sqlpushdown.{SQLStatement, SupportsSQLPushDown} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} +import org.apache.spark.sql.execution.datasources.v2.pushdown.sql.{SingleCatalystStatement, SingleSQLStatement, SQLBuilder} import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -28,12 +30,14 @@ case class JDBCScanBuilder( session: SparkSession, schema: StructType, jdbcOptions: JDBCOptions) - extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns { + extends ScanBuilder with SupportsSQLPushDown { private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis private var pushedFilter = Array.empty[Filter] + private var statement: SingleSQLStatement = _ + private var prunedSchema = schema override def pushFilters(filters: Array[Filter]): Array[Filter] = { @@ -65,6 +69,37 @@ case class JDBCScanBuilder( val resolver = session.sessionState.conf.resolver val timeZoneId = session.sessionState.conf.sessionLocalTimeZone val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions) - JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, 
pushedFilter) + val relationSchema = if (statement != null) { + prunedSchema + } else { + schema + } + JDBCScan(JDBCRelation(relationSchema, parts, jdbcOptions)(session), + prunedSchema, pushedFilter, statement) + } + + private def toSQLStatement(catalystStatement: SingleCatalystStatement): SingleSQLStatement = { + val projects = catalystStatement.projects + val filters = catalystStatement.filters + val groupBy = catalystStatement.groupBy + SingleSQLStatement ( + relation = jdbcOptions.tableOrQuery, + projects = Some(projects.map(SQLBuilder.expressionToSql(_))), + filters = if (filters.isEmpty) None else Some(filters), + groupBy = if (groupBy.isEmpty) None else Some(groupBy.map(SQLBuilder.expressionToSql(_))), + url = Some(jdbcOptions.url) + ) } + + override def isMultiplePartitionExecution: Boolean = true + + override def pushStatement(push: SQLStatement, outputSchema: StructType): Array[Filter] = { + statement = toSQLStatement(push.asInstanceOf[SingleCatalystStatement]) + if (outputSchema != null) { + prunedSchema = outputSchema + } + statement.filters.map(f => pushFilters(f.toArray)).getOrElse(Array.empty) + } + + override def pushedStatement(): SQLStatement = statement } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala index a90ab564ddb50..c94c47562ed7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuilder import org.apache.spark.internal.Logging import org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange, SupportsNamespaces, Table, TableCatalog, TableChange} import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.read.sqlpushdown.SupportsSQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite, JDBCRDD, JdbcUtils} import org.apache.spark.sql.internal.SQLConf @@ -33,7 +34,7 @@ import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap -class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging { +class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with SupportsSQL with Logging { private var catalogName: String = null private var options: JDBCOptions = _ private var dialect: JdbcDialect = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/PushQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/PushQuery.scala new file mode 100644 index 0000000000000..cd2520655c2ed --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/PushQuery.scala @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.pushdown + +import java.util.Locale + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{AliasHelper, And, AttributeReference, Expression, NamedExpression, ProjectionOverSchema, ScalaUDF} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, Count, Max, Min, Sum} +import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} +import org.apache.spark.sql.connector.read.sqlpushdown.{SupportsSQL, SupportsSQLPushDown} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation, PushDownUtils, V1ScanWrapper} +import org.apache.spark.sql.execution.datasources.v2.pushdown.sql.{PushDownAggUtils, SingleCatalystStatement} +import org.apache.spark.sql.sources +import org.apache.spark.sql.types.StructType + +abstract sealed class PushQuery extends Logging { + def push(): LogicalPlan +} + +class OldPush ( + project: Seq[NamedExpression], + filters: Seq[Expression], + relation: DataSourceV2Relation, + scanBuilder: ScanBuilder) extends PushQuery { + + private[pushdown] lazy val(pushedFilters, postScanFilters) = + PushDownUtils.pushDownFilter(scanBuilder, filters, relation) + + private[pushdown] lazy val normalizedProjects = DataSourceStrategy + .normalizeExprs(project, relation.output) + .asInstanceOf[Seq[NamedExpression]] + + /** + * Applies column pruning to the data source, w.r.t. the references of the given expressions. + * + * @return the `Scan` instance (since column pruning is the last step of operator pushdown), + * and new output attributes after column pruning. 
+ */ + def pruningColumns(): (Scan, Seq[AttributeReference]) = { + PushDownUtils.pruneColumns( + scanBuilder, relation, normalizedProjects, postScanFilters) + } + + def newScanRelation(): DataSourceV2ScanRelation = { + val (scan, output) = pruningColumns() + + logInfo( + s""" + |Pushing operators to ${relation.name} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + val wrappedScan = scan match { + case v1: V1Scan => + val translated = filters.flatMap( + DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown = true)) + V1ScanWrapper(v1, translated, pushedFilters) + case _ => scan + } + DataSourceV2ScanRelation(relation, wrappedScan, output) + } + + private def buildPushedPlan(scanRelation: DataSourceV2ScanRelation): LogicalPlan = { + val projectionOverSchema = ProjectionOverSchema(scanRelation.output.toStructType) + val projectionFunc = (expr: Expression) => expr transformDown { + case projectionOverSchema(newExpr) => newExpr + } + val filterCondition = postScanFilters.reduceLeftOption(And) + val newFilterCondition = filterCondition.map(projectionFunc) + val withFilter = newFilterCondition.map(logical.Filter(_, scanRelation)).getOrElse(scanRelation) + + val withProjection = if (withFilter.output != project) { + val newProjects = normalizedProjects + .map(projectionFunc) + .asInstanceOf[Seq[NamedExpression]] + Project(newProjects, withFilter) + } else { + withFilter + } + withProjection + } + + override def push(): LogicalPlan = { + buildPushedPlan(newScanRelation()) + } +} + +case class PushScanQuery( + project: Seq[NamedExpression], + filters: Seq[Expression], + relation: DataSourceV2Relation, + scanBuilder: SupportsSQLPushDown) extends OldPush(project, filters, relation, scanBuilder) { + + override def pruningColumns(): (Scan, Seq[AttributeReference]) = { + val prunedSchema = PushDownUtils.prunedColumns( + scanBuilder, relation, normalizedProjects, postScanFilters) + + prunedSchema.map { prunedSchema => + scanBuilder.pruneColumns(prunedSchema) + val output = PushDownUtils.toOutputAttrs(prunedSchema, relation) + val pushStatement = SingleCatalystStatement.of(relation, output, pushedFilters, Seq.empty) + /** + * output schema set by `SupportsPushDownRequiredColumns#pruneColumns` + */ + scanBuilder.pushStatement(pushStatement, null) + scanBuilder.build() -> output + }.getOrElse( scanBuilder.build() -> relation.output) + } +} + +case class PushAggregateQuery( + groupingExpressions: Seq[Expression], + resultExpressions: Seq[NamedExpression], + child: PushQuery) extends PushQuery with AliasHelper { + + private val scanChild = child.asInstanceOf[PushScanQuery] + + /** This iterator automatically increments every time it is used, + * and is for aliasing subqueries. 
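+   * For example, successive calls to `alias.next()` yield `pag_0`, `pag_1`, `pag_2`, ..., which
+   * are used (directly, or with a `sum_`/`count_` prefix for averages) as alias names for the
+   * pushed partial aggregates.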
+ */ + private final lazy val alias = Iterator.from(0).map(n => s"pag_$n") + private final lazy val aliasMap = getAliasMap(scanChild.project) + + override def push(): LogicalPlan = { + + val aggregateExpressionsToAliases = + PushDownAggUtils.getAggregationToPushedAliasMap(resultExpressions, Some(() => alias.next())) + + val namedGroupingExpressions = + PushDownAggUtils.getNamedGroupingExpressions(groupingExpressions, resultExpressions) + + // make expression out of the tuples + val pushedPartialGroupings = namedGroupingExpressions.map(_._2) + + val pushedPartialAggregates = + PushDownAggUtils.getPushDownNameExpression(resultExpressions ++ pushedPartialGroupings, + aggregateExpressionsToAliases) + + /** This step is separate to keep the input order of the groupingExpressions */ + val rewrittenResultExpressions = PushDownAggUtils.rewriteResultExpressions(resultExpressions, + aggregateExpressionsToAliases, namedGroupingExpressions.toMap) + + val output = pushedPartialAggregates + .map(_.toAttribute) + .asInstanceOf[Seq[AttributeReference]] + + val scanRelation = newScanRelation(pushedPartialAggregates, pushedPartialGroupings, output) + + Aggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = rewrittenResultExpressions, + child = scanRelation) + } + + def newScanRelation( + aggregations: Seq[NamedExpression], + groupBy: Seq[NamedExpression], + output: Seq[AttributeReference]): LogicalPlan = { + val SQLPushDown = scanChild.scanBuilder + + val outputAndProjectMap = output.map{ attr => + if (aliasMap.contains(attr)) { + val newAttr = aliasMap(attr) + newAttr.collectFirst { case a: AttributeReference => a }.get -> newAttr + } else { + attr -> attr + } + } + val newProjects = outputAndProjectMap.map(_._2) + val newOutput = outputAndProjectMap.map(_._1) + + val aggregationsWithoutAlias = aggregations.map { + e => e.transformDown { + case agg: AggregateExpression => replaceAlias(agg, aliasMap) + case reference: AttributeReference => replaceAlias(reference, aliasMap) + } + }.asInstanceOf[Seq[NamedExpression]] + + val groupByWithoutAlias = groupBy.map { + e => e.transformDown { + case reference: AttributeReference => replaceAlias(reference, aliasMap) + } + }.asInstanceOf[Seq[NamedExpression]] + + val pushStatement = SingleCatalystStatement.of(scanChild.relation, + aggregationsWithoutAlias, + scanChild.pushedFilters, + groupByWithoutAlias) + SQLPushDown.pushStatement(pushStatement, StructType.fromAttributes(newOutput)) + + val scan = SQLPushDown.build() match { + case v1: V1Scan => + V1ScanWrapper(v1, Seq.empty[sources.Filter], Seq.empty[sources.Filter]) + case scan => scan + } + + val scanRelation = DataSourceV2ScanRelation(scanChild.relation, scan, newOutput) + + if(newOutput == output) { + scanRelation + } else { + Project(newProjects, scanRelation) + } + } +} + +/** + * [[PushQuery]] currently finds the [[LogicalPlan]] which can be partially executed in an + * individual partition. Extractor for basic SQL queries (not counting subqueries). + * + * The output type is a tuple with the values corresponding to + * `SELECT`, `FROM`, `WHERE`, `GROUP BY` + * + * We inspect the given [[logical.LogicalPlan]] top-down and stop + * at any point where a sub-query would be introduced or if nodes + * need any re-ordering. + * + * The expected order of nodes is: + * - Project / Aggregate + * - Filter + * - Any logical plan as the source relation. 
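+ *
+ * For example (illustrative table and column names), a query such as
+ * {{{
+ *   SELECT dept, SUM(salary) FROM jdbc_table WHERE salary > 0 GROUP BY dept
+ * }}}
+ * yields an [[Aggregate]] over a [[logical.Filter]] over a [[DataSourceV2Relation]]; the filter
+ * and a partial aggregation are pushed down as a single SQL statement, while a final
+ * [[Aggregate]] is kept on the Spark side to merge the per-partition partial results.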
+ * + * TODO: support push down sql as can as possible in single partition + */ +object PushQuery extends Logging { + + /** + * Determine if the given function is eligible for partial aggregation. + * + * @param aggregateFunction The aggregate function. + * @return `true` if the given aggregate function is not supported for partial aggregation, + * `false` otherwise. + */ + private def nonSupportedAggregateFunction(aggregateFunction: AggregateFunction): Boolean = + aggregateFunction match { + case _: Count => false + case _: Sum => false + case _: Min => false + case _: Max => false + case _: Average => false + case _ => + logWarning("Found an aggregate function" + + s"(${aggregateFunction.prettyName.toUpperCase(Locale.getDefault)})" + + "that could not be pushed down - falling back to normal behavior") + true + } + + private def containNonSupportedAggregateFunction( + aggregateExpressions: Seq[NamedExpression]): Boolean = + aggregateExpressions + .flatMap(expr => expr.collect { case agg: AggregateExpression => agg }) + .exists(agg => agg.isDistinct || nonSupportedAggregateFunction(agg.aggregateFunction)) + + private def containNonSupportProjects( + projects: Seq[NamedExpression]): Boolean = { + projects.flatMap { expr => expr.collect { case u: ScalaUDF => u }} + .nonEmpty + } + private def subqueryPlan(op: LogicalPlan): Boolean = + op match { + case _: logical.Aggregate => true + case _ => false + } + + def unapply(plan: logical.LogicalPlan): Option[PushQuery] = { + import DataSourceV2Implicits._ + + plan match { + + case ScanOperation(project, filters, relation: DataSourceV2Relation) => + relation.table.asReadable.newScanBuilder(relation.options) match { + case down: SupportsSQLPushDown if relation.catalog.exists(_.isInstanceOf[SupportsSQL]) => + Some(PushScanQuery(project, filters, relation, down)) + case builder: ScanBuilder => + Some(new OldPush(project, filters, relation, builder)) + } + + case Aggregate(groupBy, aggExpressions, child) => + unapply(child).flatMap { + case s@PushScanQuery(_, _, _, _) if !subqueryPlan(child) && + !containNonSupportedAggregateFunction(aggExpressions) && + !containNonSupportProjects(s.project) && + s.postScanFilters.isEmpty => + Some(PushAggregateQuery(groupBy, aggExpressions, s)) + case _ => None + } + + case _ => None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/PushDownAggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/PushDownAggUtils.scala new file mode 100644 index 0000000000000..7bd1de70cd41a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/PushDownAggUtils.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.v2.pushdown.sql + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.{DecimalType, DoubleType} + + +object PushDownAggUtils { + /** + * A single aggregate expression might appear multiple times in resultExpressions. + *
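+ * For example, in `SELECT AVG(x), AVG(x) + 1 FROM t` (an illustrative query) the same `AVG(x)`
+ * appears in two result expressions.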
+ * In order to avoid evaluating an individual aggregate function multiple times, we'll + * build a seq of the distinct aggregate expressions and build a function which can + * be used to re-write expressions so that they reference the single copy of the + * aggregate function which actually gets computed. + */ + def getAggregationToPushedAliasMap( + resultExpressions: Seq[NamedExpression], + createAliasName: Option[() => String] = None) + : Map[AggregateExpression, Seq[Alias]] = { + val (nonDistinctAggregateExpressions, aggregateExpressionsToNamedExpressions) = + resultExpressions.foldLeft( + Seq.empty[AggregateExpression] -> Map.empty[AggregateExpression, Set[NamedExpression]]) { + case ((aggExpressions, aggExpressionsToNamedExpressions), current) => + val aggregateExpressionsInCurrent = + current.collect { case a: AggregateExpression => a } + /** + * We keep track of the outermost [[NamedExpression]] referencing the + * [[AggregateExpression]] to always have distinct names for the named pushdown + * expressions. + */ + val updatedMapping = + aggregateExpressionsInCurrent.foldLeft(aggExpressionsToNamedExpressions) { + case (mapping, expression) => + val currentMapping = mapping.getOrElse(expression, Set.empty) + mapping.updated(expression, currentMapping + current) + } + (aggExpressions ++ aggregateExpressionsInCurrent) -> updatedMapping + } + val aggregateExpressions = nonDistinctAggregateExpressions.distinct + + /** + * We split and rewrite the given aggregate expressions to partial aggregate expressions + * and keep track of the original aggregate expression for later referencing. + */ + val aggregateExpressionsToAliases: Map[AggregateExpression, Seq[Alias]] = + aggregateExpressions.map { agg => + agg -> rewriteAggregateExpressionsToPartial(agg, + aggregateExpressionsToNamedExpressions(agg), createAliasName) + }.toMap + + aggregateExpressionsToAliases + } + + def getNamedGroupingExpressions( + groupingExpressions: Seq[Expression], + resultExpressions: Seq[NamedExpression]): Seq[(Expression, NamedExpression)] = { + groupingExpressions.map { + case ne: NamedExpression => ne -> ne + /** If the expression is not a NamedExpressions, we add an alias. + * So, when we generate the result of the operator, the Aggregate Operator + * can directly get the Seq of attributes representing the grouping expressions. + */ + case other => + val existingAlias = resultExpressions.find({ + case Alias(aliasChild, aliasName) => aliasChild == other + case _ => false + }) + + // it could be that there is already an alias, so do not "double alias" + val mappedExpression = existingAlias match { + case Some(alias) => alias.toAttribute + case None => Alias(other, other.toString)() + } + other -> mappedExpression + } + } + + /** + * Since for pushdown only [[NamedExpression]]s are allowed, we do the following: + * + * - We extract [[Attribute]]s that can be pushed down straight. + * + * - For [[Expression]]s with [[AggregateExpression]]s in them, we can only push down the + * [[AggregateExpression]]s and leave the [[Expression]] for evaluation on spark side. + * If the [[Expression]] does not contain any [[AggregateExpression]], we can also directly + * push it down. 
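+ * For example (hypothetical column names), for `SUM(a) + 1 AS total` only `SUM(a)` is pushed
+ * down and the `+ 1` is applied by Spark afterwards, while `a + 1 AS b` contains no aggregate
+ * and can be pushed down as a whole.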
+ */ + def getPushDownNameExpression( + resultExpressions: Seq[NamedExpression], + aggregateExpressionsToAliases: Map[AggregateExpression, Seq[Alias]]) + : Seq[NamedExpression] = { + + val pushdownExpressions = resultExpressions.flatMap { + case attr: Attribute => Seq(attr) + case Alias(attr: Attribute, _) => Seq(attr) + case alias@Alias(expression, _) => + /** + * If the collected sequence of [[AggregateExpression]]s is empty then there + * is no dependency of a regular [[Expression]] to a distributed value computed by + * an [[AggregateExpression]] and as such we can push it down. Otherwise, we just push + * down the [[AggregateExpression]]s and apply the other [[Expression]]s via the + * resultExpressions. + */ + val aggs = expression.collect { + case agg: AggregateExpression => agg + } + val nonEmptyExprs = if (aggs.isEmpty) { + Seq(alias) + } else { + aggs + } + nonEmptyExprs + case _ => Seq.empty + }.distinct + + /** + * With this step, we replace the pushdownExpressions with the corresponding + * [[NamedExpression]]s that we can continue to work on. Regular [[NamedExpression]]s are + * 'replaced' by themselves whereas the [[AggregateExpression]]s are replaced by their + * partial versions hidden behind [[Alias]]es. + */ + pushdownExpressions.flatMap { + case agg: AggregateExpression => aggregateExpressionsToAliases(agg) + case namedExpression: NamedExpression => Seq(namedExpression) + } + } + + /** + * This method rewrites an [[AggregateExpression]] to the corresponding partial ones. + * For instance an [[Average]] is rewritten to a [[Sum]] and a [[Count]]. + * + * @param aggregateExpression [[AggregateExpression]] to rewrite. + * @return A sequence of [[Alias]]es that represent the split up [[AggregateExpression]]. + */ + def rewriteAggregateExpressionsToPartial( + aggregateExpression: AggregateExpression, + outerNamedExpressions: Set[NamedExpression], + createAliasName: Option[() => String] = None): Seq[Alias] = { + val outerName = outerNamedExpressions.map(_.name).toSeq.sorted.mkString("", "_", "_") + val inputBuffers = aggregateExpression.aggregateFunction.inputAggBufferAttributes + aggregateExpression.aggregateFunction match { + case avg: Average => + // two: sum and count + val Seq(sumAlias, countAlias, _*) = inputBuffers + val typedChild = avg.child.dataType match { + case DoubleType | DecimalType.Fixed(_, _) => avg.child + case _ => Cast(avg.child, DoubleType) + } + val sumExpression = // sum + AggregateExpression(Sum(typedChild), mode = Partial, aggregateExpression.isDistinct) + val countExpression = { // count + AggregateExpression(Count(avg.child), mode = Partial, aggregateExpression.isDistinct) + } + val suffix = createAliasName.map(_.apply()) + val sumName = suffix.map("sum_" + _).getOrElse(outerName + sumAlias.name) + val cntName = suffix.map("count_" + _ ).getOrElse(outerName + countAlias.name) + Seq( + referenceAs(sumName, sumAlias, sumExpression), + referenceAs(cntName, countAlias, countExpression)) + + case Count(_) | Sum(_) | Max(_) | Min(_) => + inputBuffers.map { ref => + val aliasName = createAliasName.map(_.apply()).getOrElse(outerName + ref.name) + referenceAs(aliasName, ref, aggregateExpression.copy(mode = Partial)) + } + + case _ => throw new RuntimeException("Approached rewrite with unsupported expression") + } + } + + /** + * References a given [[Expression]] as an [[Alias]] of the given [[Attribute]]. + * + * @param attribute The [[Attribute]] to create the [[Alias]] reference of. 
+ * @param expression The [[Expression]] to reference as the given [[Attribute]]. + * @return An [[Alias]] of the [[Expression]] referenced as the [[Attribute]]. + */ + private def referenceAs(name: String, attribute: Attribute, expression: Expression): Alias = { + Alias(expression, name)(attribute.exprId, attribute.qualifier, Some(attribute.metadata)) + } + + def rewriteResultExpressions( + resultExpressions: Seq[NamedExpression], + pushedAggregateMap: Map[AggregateExpression, Seq[Alias]], + groupExpressionMap: Map[Expression, NamedExpression]): Seq[NamedExpression] = + resultExpressions.map { + rewriteResultExpression(_, pushedAggregateMap, groupExpressionMap) + } + + private def rewriteResultExpression( + resultExpression: NamedExpression, + pushedAggregateMap: Map[AggregateExpression, Seq[Alias]], + groupExpressionMap: Map[Expression, NamedExpression]): NamedExpression = { + resultExpression.transformDown { + case old@Alias(l@AggregateExpression(avg: Average, _, _, _, _), _) => + val as = pushedAggregateMap(l) + val sum = as.head.toAttribute + val count = as.last.toAttribute + val average = avg.evaluateExpression.transformDown { + case a: AttributeReference => + a.name match { + case "sum" => Sum(sum).toAggregateExpression() + case "count" => Sum(count).toAggregateExpression() + case _ => a + } + } + old.copy(child = average)( + exprId = old.exprId, + qualifier = old.qualifier, + explicitMetadata = old.explicitMetadata, + nonInheritableMetadataKeys = old.nonInheritableMetadataKeys) + case a@AggregateExpression(agg, _, _, _, _) if pushedAggregateMap.contains(a) => + val x = pushedAggregateMap(a).head + val newAgg = agg match { + case _: Max => Max(x.toAttribute) + case _: Min => Min(x.toAttribute) + case _: Sum => Sum(x.toAttribute) + case _: Count => Sum(x.toAttribute) + case _ => throw new UnsupportedOperationException() + } + a.copy(aggregateFunction = newAgg) + case expression => + + /** + * Since we're using `namedGroupingAttributes` to extract the grouping key + * columns, we need to replace grouping key expressions with their corresponding + * attributes. We do not rely on the equality check at here since attributes may + * differ cosmetically. Instead, we use semanticEquals. + */ + groupExpressionMap.collectFirst { + case (grpExpr, ne) if grpExpr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/SQLBuilder.scala new file mode 100644 index 0000000000000..25a02ff746028 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/SQLBuilder.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.pushdown.sql + +import java.math.BigInteger +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.{analysis, expressions => expr} +import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, BinaryComparison, BinaryExpression, CheckOverflow, Expression, Literal, PromotePrecision} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * SQL builder class. + */ +object SQLBuilder { + + private def longToReadableTimestamp(t: Long): String = { + throw new UnsupportedOperationException + // DateTimeUtils.timestampToString(t) + "." + + // "%07d".format(DateTimeUtils.toJavaTimestamp(t).getNanos()/100) + } + + protected def formatAttributeWithQualifiers(qualifiers: Seq[String], name: String): String = + (qualifiers :+ name).map({ s => s""""$s"""" }).mkString(".") + + protected def literalToSql(value: Any): String = value match { + case s: String => s"'$s'" + case s: UTF8String => s"'$s'" + case b: Byte => s"$b" + case i: Int => s"$i" + case l: Long => s"$l" + case f: Float => s"$f" + case d: Double => s"$d" + case b: Boolean => s"$b" + case bi: BigInteger => s"$bi" + case t: Timestamp => s"TO_TIMESTAMP('$t')" + case d: Date => s"TO_DATE('$d')" + case null => "NULL" + case other => other.toString + } + + def typeToSql(sparkType: DataType): String = sparkType match { + case `StringType` => "VARCHAR(*)" + case `IntegerType` => "INTEGER" + case `ByteType` => "TINYINT" + case `ShortType` => "SMALLINT" + case `LongType` => "BIGINT" + case `FloatType` => "FLOAT" + case `DoubleType` => "DOUBLE" + case DecimalType.Fixed(precision, scale) => s"DECIMAL($precision,$scale)" + case `DateType` => "DATE" + case `BooleanType` => "BOOLEAN" + case `TimestampType` => "TIMESTAMP" + case _ => + throw new IllegalArgumentException(s"Type $sparkType cannot be converted to SQL type") + } + + private def toUnderscoreUpper(str: String): String = { + var result: String = str(0).toUpper.toString + for (i <- 1 until str.length) { + if (str(i-1).isLower && str(i).isUpper) { + result += '_' + } + result += str(i).toUpper + } + result + } + + private def generalExpressionToSql(expression: expr.Expression): String = { + val clazz = expression.getClass + val name = try { + clazz.getDeclaredMethod("prettyName").invoke(expression).asInstanceOf[String] + } catch { + case _: NoSuchMethodException => + toUnderscoreUpper(clazz.getSimpleName) + } + val children = expression.children + val childStr = children.map(expressionToSql).mkString(", ") + s"$name($childStr)" + } + + /** + * Convenience functions to take several expressions + */ + def expressionsToSql(expressions: Seq[expr.Expression], delimiter: String = " "): String = { + expressions.map(expressionToSql).reduceLeft((x, y) => x + delimiter + y) + } + + def expressionToSql(expression: expr.Expression): String = + expression match { + case expr.And(left, right) => s"(${expressionToSql(left)} AND ${expressionToSql(right)})" + case expr.Or(left, right) => s"(${expressionToSql(left)} OR ${expressionToSql(right)})" + case expr.Remainder(child, div, _) => + s"MOD(${expressionToSql(child)}, ${expressionToSql(div)})" + case expr.UnaryMinus(child, _) => s"-(${expressionToSql(child)})" + case expr.IsNull(child) => s"${expressionToSql(child)} IS NULL" + case expr.IsNotNull(child) => s"${expressionToSql(child)} IS NOT NULL" + case expr.Like(left, right, _) => 
s"${expressionToSql(left)} LIKE ${expressionToSql(right)}" + // TODO: case expr.SortOrder(child,direction) => + // val sortDirection = if (direction == Ascending) "ASC" else "DESC" + // s"${expressionToSql(child)} $sortDirection" + // in Spark 1.5 timestamps are longs and processed internally, however we have to + // convert that to TO_TIMESTAMP() + case t@Literal(_, dataType) if dataType.equals(TimestampType) => + s"TO_TIMESTAMP('${longToReadableTimestamp(t.value.asInstanceOf[Long])}')" + case expr.Literal(value, _) => literalToSql(value) + case expr.Cast(child, dataType, _) => + s"CAST(${expressionToSql(child)} AS ${typeToSql(dataType)})" + // TODO work on that, for SPark 1.6 + // case expr.CountDistinct(children) => s"COUNT(DISTINCT ${expressionsToSql(children, ",")})" + case expr.aggregate.AggregateExpression(aggFunc, _, _, _, _) + => s"${aggFunc.prettyName}(${expressionsToSql(aggFunc.children, ",")})" + case expr.Coalesce(children) => s"COALESCE(${expressionsToSql(children, ",")})" + case expr.DayOfMonth(date) => s"EXTRACT(DAY FROM ${expressionToSql(date)})" + case expr.Month(date) => s"EXTRACT(MONTH FROM ${expressionToSql(date)})" + case expr.Year(date) => s"EXTRACT(YEAR FROM ${expressionToSql(date)})" + case expr.Hour(date, _) => s"EXTRACT(HOUR FROM ${expressionToSql(date)})" + case expr.Minute(date, _) => s"EXTRACT(MINUTE FROM ${expressionToSql(date)})" + case expr.Second(date, _) => s"EXTRACT(SECOND FROM ${expressionToSql(date)})" + case expr.CurrentDate(_) => s"CURRENT_DATE()" + case expr.Pow(left, right) => s"POWER(${expressionToSql(left)}, ${expressionToSql(right)})" + case expr.Substring(str, pos, len) => + s"SUBSTRING(${expressionToSql(str)}, $pos, $len)" + // TODO work on that, for SPark 1.6 + // case expr.Average(child) => s"AVG(${expressionToSql(child)})" + case expr.In(value, list) => + s"${expressionToSql(value)} IN (${list.map(expressionToSql).mkString(", ")})" + case expr.InSet(value, hset) => + s"${expressionToSql(value)} IN (${hset.map(literalToSql).mkString(", ")})" + case a@expr.Alias(child, name) => + s"""${expressionToSql(child)} AS "$name"""" + case a@expr.AttributeReference(name, _, _, _) => s""""$name"""" + // formatAttributeWithQualifiers(a.qualifier, name) + case analysis.UnresolvedAttribute(name) => + formatAttributeWithQualifiers(name.reverse.tail.reverse, name.last) + case _: analysis.Star => "*" + case BinarySymbolExpression(left, symbol, right) => + s"(${expressionToSql(left)} $symbol ${expressionToSql(right)})" + case CheckOverflow(child, _, _) => expressionToSql(child) + case PromotePrecision(child) => expressionToSql(child) + case x => + generalExpressionToSql(x) + } +} + +// TODO optimize this. maybe we can substitute it completely with its logic. 
+object BinarySymbolExpression {
+  def isBinaryExpressionWithSymbol(be: BinaryExpression): Boolean =
+    be.isInstanceOf[BinaryArithmetic] || be.isInstanceOf[BinaryComparison]
+
+  def getBinaryExpressionSymbol(be: BinaryExpression): String =
+    be match {
+      case be: BinaryComparison => be.symbol
+      case be: BinaryArithmetic => be.symbol
+      case _ => sys.error(s"${be.getClass.getName} has no symbol attribute")
+    }
+
+  def unapply(any: Any): Option[(Expression, String, Expression)] = any match {
+    case be: BinaryExpression if isBinaryExpressionWithSymbol(be) =>
+      Some(be.left, getBinaryExpressionSymbol(be), be.right)
+    case _ => None
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/SingleSQLStatement.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/SingleSQLStatement.scala
new file mode 100644
index 0000000000000..5314e2c4b1fa4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/pushdown/sql/SingleSQLStatement.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.pushdown.sql
+
+import java.util.Locale
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
+import org.apache.spark.sql.connector.read.sqlpushdown.SQLStatement
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.Filter
+
+/**
+ * Builds a pushed `SELECT` query with optional WHERE and GROUP BY clauses.
+ *
+ * @param relation Table name, join clause or subquery for the FROM clause.
+ * @param projects List of fields for projection as strings.
+ * @param filters List of filters for the WHERE clause (can be empty).
+ * @param groupBy List of expressions for the GROUP BY clause (can be empty).
+ * @param url Optional JDBC URL, used to pick the dialect that compiles the filters.
+ */
+case class SingleSQLStatement(
+    relation: String,
+    projects: Option[Seq[String]],
+    filters: Option[Seq[sources.Filter]],
+    groupBy: Option[Seq[String]],
+    url: Option[String] = None) extends SQLStatement {
+
+  /**
+   * The projected columns (`projects`), as a String suitable for injection into a SQL query.
+   *
+   * The optimizer sometimes does not report any fields, because no specific field is required
+   * by the query (usually a nested select); in that case we fall back to the GROUP BY
+   * expressions, or to the constant `1` if there is no GROUP BY either.
+   */
+  lazy val columnList: String = projects
+    .map(_.mkString(", "))
+    .getOrElse(groupBy.map(_.mkString(", ")).getOrElse("1"))
+
+  /**
+   * `filters`, but as a WHERE clause suitable for injection into a SQL query.
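+   *
+   * Filters are compiled through [[JDBCRDD.compileFilter]] with the dialect derived from `url`;
+   * filters the dialect cannot compile are left out of the clause.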
+ */ + lazy val filterWhereClause: String = + filters.getOrElse(Seq.empty) + .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url.getOrElse("Unknown URL")))) + .map(p => s"($p)").mkString(" AND ") + + lazy val groupByStr: String = groupBy.map(g => s" GROUP BY ${g.mkString(", ")}").getOrElse("") + + /** + * A WHERE clause representing both `filters`, if any, and the current partition. + */ + private def getWhereClause(extraFilter: String): String = { + if (extraFilter != null && filterWhereClause.nonEmpty) { + "WHERE " + s"($filterWhereClause)" + " AND " + s"($extraFilter)" + } else if (extraFilter != null) { + "WHERE " + extraFilter + } else if (filterWhereClause.nonEmpty) { + "WHERE " + filterWhereClause + } else { + "" + } + } + + def toSQL(extraFilter: String = null): String = { + val myWhereClause = getWhereClause(extraFilter) + s"SELECT $columnList FROM $relation $myWhereClause $groupByStr" + } +} + +object SingleSQLStatement { + def apply( + projects: Array[String], + filters: Array[Filter], + jdbcOptions: JDBCOptions): SingleSQLStatement = { + val url = jdbcOptions.url + val dialect = JdbcDialects.get(url) + val newProjects: Option[Seq[String]] = if (projects.isEmpty) { + None + } else { + Some(projects.map(colName => dialect.quoteIdentifier(colName))) + } + + SingleSQLStatement( + relation = jdbcOptions.tableOrQuery, + projects = newProjects, + filters = if (filters.isEmpty) None else Some(filters), + groupBy = None, + url = Some(url) + ) + } +} + +/** + * It's `ScanBuilder`'s duty to translate [[SingleCatalystStatement]] into [[SingleSQLStatement]] + */ +case class SingleCatalystStatement( + relation: DataSourceV2Relation, + projects: Seq[NamedExpression], + filters: Seq[sources.Filter], + groupBy: Seq[NamedExpression]) extends SQLStatement { +} + +object SingleCatalystStatement { + // TODO: get from configuration + val isCaseSensitive = false + + private def getColumnName(columnName: String, caseSensitive: Boolean): String = { + if (caseSensitive) { + columnName + } else { + columnName.toUpperCase(Locale.ROOT) + } + } + + private def verifyPushedColumn( + nameExpr: NamedExpression, + colSet: Set[String]): Unit = { + nameExpr + .collect{ case att@AttributeReference(_, _, _, _) => att } + .foreach{ attr => + if (!colSet.contains(getColumnName(attr.name, isCaseSensitive))) { + // TODO: report exception + throw new UnsupportedOperationException + } + } + } + + def of( + relation: DataSourceV2Relation, + projects: Seq[NamedExpression], + filters: Seq[sources.Filter], + groupBy: Seq[NamedExpression] + ): SingleCatalystStatement = { + val schemaSet = relation.schema + .map(s => getColumnName(s.name, isCaseSensitive)) + .toSet + projects.foreach(verifyPushedColumn(_, schemaSet)) + groupBy.foreach(verifyPushedColumn(_, schemaSet)) + SingleCatalystStatement(relation, projects, filters, groupBy) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOptimizeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOptimizeSuite.scala new file mode 100644 index 0000000000000..6337cfd91c927 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOptimizeSuite.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, when}
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{If, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, SupportsRead, Table, TableCatalog}
+import org.apache.spark.sql.connector.read.{ScanBuilder, V1Scan}
+import org.apache.spark.sql.connector.read.sqlpushdown.{SupportsSQL, SupportsSQLPushDown}
+import org.apache.spark.sql.execution.datasources.v2.pushdown.PushQuery
+import org.apache.spark.sql.execution.datasources.v2.pushdown.sql.{SingleCatalystStatement, SingleSQLStatement, SQLBuilder}
+import org.apache.spark.sql.sources
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * Test-only trait for retrieving the internally pushed statement from a mocked scan.
+ */
+trait UTOnlyHasPushedStatement {
+  def pushedStatement(): SingleSQLStatement
+}
+
+abstract class MockCatalog extends TableCatalog with SupportsNamespaces with SupportsSQL
+abstract class MockSQLTable extends Table with SupportsRead
+abstract class MockScanBuilder extends ScanBuilder with SupportsSQLPushDown
+
+abstract class MockScan extends V1Scan with UTOnlyHasPushedStatement
+
+class PushDownOptimizeSuite extends PlanTest {
+
+  private val emptyMap = CaseInsensitiveStringMap.empty
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches: Seq[Batch] = Batch("Push-Down", Once,
+      V2ScanRelationPushDown) :: Nil
+  }
+
+  private def toSQLStatement(catalystStatement: SingleCatalystStatement): SingleSQLStatement = {
+    val projects = catalystStatement.projects
+    val filters = catalystStatement.filters
+    val groupBy = catalystStatement.groupBy
+    SingleSQLStatement(
+      relation = catalystStatement.relation.name,
+      projects = Some(projects.map(SQLBuilder.expressionToSql(_))),
+      filters = if (filters.isEmpty) None else Some(filters),
+      groupBy = if (groupBy.isEmpty) None else Some(groupBy.map(SQLBuilder.expressionToSql(_))),
+      url = None
+    )
+  }
+  private def prepareV2Relation(): DataSourceV2Relation = {
+    val catalog = mock(classOf[MockCatalog])
+    val scan = mock(classOf[MockScan])
+    val scanBuilder = mock(classOf[MockScanBuilder])
+    when(scanBuilder.build()).thenReturn(scan)
+    when(scanBuilder.pushStatement(any[SingleSQLStatement], any[StructType]))
+      .thenAnswer { invocationOnMock =>
+        val singleSQLStatement = {
+          toSQLStatement(invocationOnMock.getArguments()(0).asInstanceOf[SingleCatalystStatement])
+        }
+        when(scan.pushedStatement()).thenReturn(singleSQLStatement)
+        when(scanBuilder.pushedStatement()).thenReturn(singleSQLStatement)
+        null
+      }
+
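+    // The stubs below mimic the DSv2 push-down contract: pushFilters records the pushed filters
+    // and returns an empty array (no filters are left for Spark to evaluate after the scan), and
+    // pruneColumns records the pruned schema as the scan's readSchema.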
when(scanBuilder.pushFilters(any[Array[sources.Filter]])).thenAnswer { invocationOnMock => + val filters = invocationOnMock.getArguments()(0).asInstanceOf[Array[sources.Filter]] + when(scanBuilder.pushedFilters()).thenReturn(filters) + Array.empty[sources.Filter] + } + when(scanBuilder.pruneColumns(any[StructType])).thenAnswer { invocationOnMock => + val schema = + invocationOnMock.getArguments()(0).asInstanceOf[StructType] + when(scan.readSchema()).thenReturn(schema) + } + + lazy val schema = StructType.fromAttributes('a.int::'b.int::'c.int::'t1.string :: Nil) + val table = mock(classOf[MockSQLTable]) + when(table.schema()).thenReturn(schema) + when(table.name()).thenReturn("DB1.Table1") + when(table.newScanBuilder(any[CaseInsensitiveStringMap])).thenReturn(scanBuilder) + + DataSourceV2Relation.create(table, Some(catalog), None, emptyMap) + } + + private def check(plan: LogicalPlan): Unit = { + val dsv2 = plan.find(_.isInstanceOf[DataSourceV2ScanRelation]) + .map(_.asInstanceOf[DataSourceV2ScanRelation]) + .orNull + + assert(dsv2 != null) + assert(dsv2.scan != null) + assert(dsv2.scan.isInstanceOf[V1ScanWrapper]) + + val statement = dsv2.scan.asInstanceOf[V1ScanWrapper] + .v1Scan.asInstanceOf[MockScan].pushedStatement() + assert(statement != null) + assertResult(dsv2.output.size)(statement.projects.get.size) + logWarning(plan.toString) + logWarning(statement.toSQL()) + assert(plan.missingInput.isEmpty) + assert(plan.resolved) + assert(LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)) + } + + test("mock with only aggregation") { + val relationV2 = prepareV2Relation() + val sumIf = If('a > Literal(10), Literal(1), Literal(0)) + val query = relationV2.groupBy('t1.attr)('t1, sum(sumIf), avg('a)).analyze + + val plan = query match { + case PushQuery(a: PushQuery) => + Some(a.push()) + case _ => None + } + assert(plan.isDefined) + check(plan.get) + } + + test("mock with only filter") { + val relationV2 = prepareV2Relation() + val query = + relationV2.where('t1.attr =!= "xxx" && 'a.attr < 10 ).select('b.attr, 'c.attr).analyze + + val plan = query match { + case PushQuery(a: PushQuery) => + a.push() + case _ => null + } + assert(plan != null) + check(plan) + + } + + test("mock with group by and filter") { + val relationV2 = prepareV2Relation() + val query = relationV2.where('t1.attr =!= "xxx" && 'a.attr < 10 ) + .groupBy('t1.attr)('t1, sum('b), avg('a)).analyze + + val plan = query match { + case PushQuery(a: PushQuery) => + a.push() + case _ => null + } + assert(plan != null) + check(plan) + } + + test("mock with group by expressions not in aggregation expression ") { + val relationV2 = prepareV2Relation() + val query = relationV2.groupBy('t1.attr)(max('b), sum('c), avg('a)).analyze + + val plan = query match { + case PushQuery(a: PushQuery) => + a.push() + } + check(plan) + } + + test("mock with max(a) + 1") { + val relationV2 = prepareV2Relation() + val query = relationV2.groupBy('t1.attr)(max('b) + 1).analyze + val plan = query match { + case PushQuery(a: PushQuery) => + a.push() + } + check(plan) + } + + test("mock with aggregation with alias") { + val relationV2 = prepareV2Relation() + val query = relationV2 + .select('t1.as("g1"), 'b.as("m1")) + .groupBy('g1.attr)(max('m1) + 1).analyze + val plan = query match { + case PushQuery(a: PushQuery) => + a.push() + } + check(plan) + } + + ignore("mock with Non-correlated subquery") { + val relationV2 = prepareV2Relation() + relationV2.groupBy()(max('b)) + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index a3a3f47280952..87fcdd0751a97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.{avg, lit, sum, udf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -64,6 +64,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { .executeUpdate() conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred', 1)").executeUpdate() conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate() + conn.prepareStatement( + "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," + + " bonus NUMERIC(6, 2))").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)") + .executeUpdate() } } @@ -109,6 +122,193 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { checkAnswer(df, Row("mary")) } + test("aggregate pushdown with alias") { + val df1 = spark.table("h2.test.employee") + val query1 = df1.select($"DEPT", $"SALARY".as("value")) + .groupBy($"DEPT") + .agg(sum($"value").as("total")) + .filter($"total" > 1000) + // query1.explain(true) + checkAnswer(query1, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000))) + val decrease = udf { (x: Double, y: Double) => x - y} + val query2 = df1.select($"DEPT", decrease($"SALARY", $"BONUS").as("value"), $"SALARY", $"BONUS") + .groupBy($"DEPT") + .agg(sum($"value"), sum($"SALARY"), sum($"BONUS")) + // query2.explain(true) + checkAnswer(query2, + Seq(Row(1, 16800.00, 19000.00, 2200.00), Row(2, 19500.00, 22000.00, 2500.00), + Row(6, 10800, 12000, 1200))) + + val cols = Seq("a", "b", "c", "d") + val df2 = sql("select * from h2.test.employee").toDF(cols: _*) + val df3 = df2.groupBy().sum("c") + // df3.explain(true) + checkAnswer(df3, Seq(Row(53000.00))) + + val df4 = df2.groupBy($"a").sum("c") + checkAnswer(df4, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000))) + } + + test("scan with aggregate push-down") { + val df1 = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + + " group by DEPT") + // df1.explain(true) + // scalastyle:off line.size.limit + // == Parsed Logical Plan == + // 'Aggregate ['DEPT], [unresolvedalias('MAX('SALARY), None), unresolvedalias('MIN('BONUS), None)] + // +- 'Filter ('dept > 0) + // +- 'UnresolvedRelation [h2, test, employee], [] + // + // == Analyzed Logical Plan == + // max(SALARY): int, min(BONUS): int + // Aggregate [DEPT#0], [max(SALARY#2) AS max(SALARY)#6, 
min(BONUS#3) AS min(BONUS)#7] + // +- Filter (dept#0 > 0) + // +- SubqueryAlias h2.test.employee + // +- RelationV2[DEPT#0, NAME#1, SALARY#2, BONUS#3] test.employee + // + // == Optimized Logical Plan == + // Aggregate [DEPT#0], [max(max(SALARY)#13) AS max(SALARY)#6, min(min(BONUS)#14) AS min(BONUS)#7] + // +- RelationV2[DEPT#0, max(SALARY)#13, min(BONUS)#14] test.employee + // + // == Physical Plan == + // *(2) HashAggregate(keys=[DEPT#0], functions=[max(max(SALARY)#13), min(min(BONUS)#14)], output=[max(SALARY)#6, min(BONUS)#7]) + // +- Exchange hashpartitioning(DEPT#0, 5), true, [id=#10] + // +- *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@3d9f0a5 [DEPT#0,max(SALARY)#13,min(BONUS)#14] PushedAggregates: [*Max(SALARY,false,None), *Min(BONUS,false,None)], PushedFilters: [IsNotNull(dept), GreaterThan(dept,0)], PushedGroupby: [*DEPT], ReadSchema: struct// scalastyle:on line.size.limit + // + // df1.show + // +-----------+----------+ + // |max(SALARY)|min(BONUS)| + // +-----------+----------+ + // | 10000| 1000| + // | 12000| 1200| + // | 12000| 1200| + // +-----------+----------+ + checkAnswer(df1, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200))) + + val df2 = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0") + // df2.explain(true) + // scalastyle:off line.size.limit + // == Parsed Logical Plan == + // 'Project [unresolvedalias('MAX('ID), None), unresolvedalias('MIN('ID), None)] + // +- 'Filter ('id > 0) + // +- 'UnresolvedRelation [h2, test, people], [] + // + // == Analyzed Logical Plan == + // max(ID): int, min(ID): int + // Aggregate [max(ID#29) AS max(ID)#32, min(ID#29) AS min(ID)#33] + // +- Filter (id#29 > 0) + // +- SubqueryAlias h2.test.people + // +- RelationV2[NAME#28, ID#29] test.people + // + // == Optimized Logical Plan == + // Aggregate [max(max(ID)#37) AS max(ID)#32, min(min(ID)#38) AS min(ID)#33] + // +- RelationV2[max(ID)#37, min(ID)#38] test.people + // + // == Physical Plan == + // *(2) HashAggregate(keys=[], functions=[max(max(ID)#37), min(min(ID)#38)], output=[max(ID)#32, min(ID)#33]) + // +- Exchange SinglePartition, true, [id=#44] + // +- *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@5ed31735 [max(ID)#37,min(ID)#38] PushedAggregates: [*Max(ID,false,None), *Min(ID,false,None)], PushedFilters: [IsNotNull(id), GreaterThan(id,0)], PushedGroupby: [], ReadSchema: struct + // scalastyle:on line.size.limit + + // df2.show() + // +-------+-------+ + // |max(ID)|min(ID)| + // +-------+-------+ + // | 2| 1| + // +-------+-------+ + checkAnswer(df2, Seq(Row(2, 1))) + + val df3 = sql("select AVG(ID) FROM h2.test.people where id > 0") + checkAnswer(df3, Seq(Row(1.5))) + + val df4 = sql("select MAX(SALARY) + 1 FROM h2.test.employee") + // df4.explain(true) + // scalastyle:off line.size.limit + // == Parsed Logical Plan == + // 'Project [unresolvedalias(('MAX('SALARY) + 1), None)] + // +- 'UnresolvedRelation [h2, test, employee], [] + // + // == Analyzed Logical Plan == + // (max(SALARY) + 1): int + // Aggregate [(max(SALARY#68) + 1) AS (max(SALARY) + 1)#71] + // +- SubqueryAlias h2.test.employee + // +- RelationV2[DEPT#66, NAME#67, SALARY#68, BONUS#69] test.employee + // + // == Optimized Logical Plan == + // Aggregate [(max((max(SALARY) + 1)#74) + 1) AS (max(SALARY) + 1)#71] + // +- RelationV2[(max(SALARY) + 1)#74] test.employee + // + // == Physical Plan == + // *(2) HashAggregate(keys=[], functions=[max((max(SALARY) + 1)#74)], output=[(max(SALARY) + 1)#71]) + // +- Exchange SinglePartition, 
true, [id=#112] + // +- *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@20864cd1 [(max(SALARY) + 1)#74] PushedAggregates: [*Max(SALARY,false,None)], PushedFilters: [], PushedGroupby: [], ReadSchema: struct<(max(SALARY) + 1):int> + // scalastyle:on line.size.limit + checkAnswer(df4, Seq(Row(12001))) + + // COUNT push down is not supported yet + val df5 = sql("select COUNT(*) FROM h2.test.employee") + // df5.explain(true) + // scalastyle:off line.size.limit + // == Parsed Logical Plan == + // 'Project [unresolvedalias('COUNT(1), None)] + // +- 'UnresolvedRelation [h2, test, employee], [] + // + // == Analyzed Logical Plan == + // count(1): bigint + // Aggregate [count(1) AS count(1)#87L] + // +- SubqueryAlias h2.test.employee + // +- RelationV2[DEPT#82, NAME#83, SALARY#84, BONUS#85] test.employee + // + // == Optimized Logical Plan == + // Aggregate [count(1) AS count(1)#87L] + // +- RelationV2[] test.employee + // + // == Physical Plan == + // *(2) HashAggregate(keys=[], functions=[count(1)], output=[count(1)#87L]) + // *(2) HashAggregate(keys=[], functions=[count(1)], output=[count(1)#87L]) + // +- Exchange SinglePartition, true, [id=#149] + // +- *(1) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#90L]) + // +- *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@63262071 [] PushedAggregates: [], PushedFilters: [], PushedGroupby: [], ReadSchema: struct<> + // scalastyle:on line.size.limit + checkAnswer(df5, Seq(Row(5))) + + val df6 = sql("select MIN(SALARY), MIN(BONUS), MIN(SALARY) * MIN(BONUS) FROM h2.test.employee") + // df6.explain(true) + checkAnswer(df6, Seq(Row(9000, 1000, 9000000))) + + val df7 = sql("select MIN(SALARY), MIN(BONUS), SUM(SALARY * BONUS) FROM h2.test.employee") + // df7.explain(true) + checkAnswer(df7, Seq(Row(9000, 1000, 62600000))) + + val df8 = sql("select BONUS, SUM(SALARY+BONUS), SALARY FROM h2.test.employee" + + " GROUP BY SALARY, BONUS") + // df8.explain(true) + checkAnswer(df8, Seq(Row(1000, 11000, 10000), Row(1200, 26400, 12000), + Row(1200, 10200, 9000), Row(1300, 11300, 10000))) + + val df9 = spark.table("h2.test.employee") + val sub2 = udf { (x: String) => x.substring(0, 3) } + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val df10 = df9.select($"SALARY", $"BONUS", sub2($"NAME").as("nsub2")) + .filter("SALARY > 100") + .filter(name($"nsub2")) + .agg(avg($"SALARY").as("avg_salary")) + // df10.explain(true) + checkAnswer(df10, Seq(Row(9666.666667))) + + val df11 = sql("select SUM(SALARY+BONUS*SALARY+SALARY/BONUS), SALARY FROM h2.test.employee" + + " GROUP BY SALARY, BONUS") + checkAnswer(df11, Seq(Row(10010010.000000000, 10000.00), Row(28824020.000000000, 12000.00), + Row(10809007.500000000, 9000.00), Row(13010007.692307692, 10000.00))) + } + + test("scan with aggregate distinct push-down") { + checkAnswer(sql("SELECT SUM(SALARY) FROM h2.test.employee"), Seq(Row(53000))) + checkAnswer(sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee"), Seq(Row(31000))) + checkAnswer(sql("SELECT AVG(DEPT) FROM h2.test.employee"), Seq(Row(2.4))) + checkAnswer(sql("SELECT AVG(DISTINCT DEPT) FROM h2.test.employee"), Seq(Row(3))) + } + test("read/write with partition info") { withTable("h2.test.abc") { sql("CREATE TABLE h2.test.abc AS SELECT * FROM h2.test.people") @@ -145,7 +345,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { test("show tables") { checkAnswer(sql("SHOW TABLES IN h2.test"), - Seq(Row("test", "people", false), Row("test", "empty_table", false))) + 
Seq(Row("test", "people", false), Row("test", "empty_table", false), + Row("test", "employee", false))) } test("SQL API: create table as select") { From dcc16a1cc6aa7d5f7323f77f44a766a5f6e785bd Mon Sep 17 00:00:00 2001 From: "chang.chen" Date: Wed, 7 Apr 2021 12:33:07 +0800 Subject: [PATCH 2/2] fix typo --- .../src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 87fcdd0751a97..a9a52791f77c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -122,7 +122,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { checkAnswer(df, Row("mary")) } - test("aggregate pushdown with alias") { + test("aggregate push-down with alias") { val df1 = spark.table("h2.test.employee") val query1 = df1.select($"DEPT", $"SALARY".as("value")) .groupBy($"DEPT")