[SPARK-22390][SPARK-32833][SQL] JDBC V2 Datasource aggregate push down #29695
New file: SupportsPushDownAggregates.java (package org.apache.spark.sql.connector.read)

@@ -0,0 +1,44 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.Aggregation;

/**
 * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
 * push down aggregates to the data source.
 *
 * @since 3.1.0
 */
@Evolving
public interface SupportsPushDownAggregates extends ScanBuilder {

  /**
   * Pushes down Aggregation to datasource.
   * The Aggregation can be pushed down only if all the Aggregate Functions can
   * be pushed down.
   */
  void pushAggregation(Aggregation aggregation);

  /**
   * Returns the aggregates that are pushed to the data source via
   * {@link #pushAggregation(Aggregation aggregation)}.
   */
  Aggregation pushedAggregation();
}
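As a rough illustration, not part of this patch, of how a data source might mix the new interface into its ScanBuilder; the class name MyScanBuilder and the createScan helper are hypothetical:

```scala
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}
import org.apache.spark.sql.sources.Aggregation

// Hypothetical scan builder; `createScan` stands in for whatever the source uses to build its Scan.
class MyScanBuilder(createScan: Aggregation => Scan)
  extends ScanBuilder with SupportsPushDownAggregates {

  // Aggregation accepted so far; starts out empty (nothing pushed down).
  private var aggregation: Aggregation = Aggregation.empty

  // Record the pushed aggregation; the scan is then expected to return pre-aggregated rows.
  override def pushAggregation(aggregation: Aggregation): Unit = {
    this.aggregation = aggregation
  }

  // Report back what was pushed so the planner can adjust the rest of the plan.
  override def pushedAggregation(): Aggregation = aggregation

  override def build(): Scan = createScan(aggregation)
}
```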
New file (package org.apache.spark.sql.sources):

@@ -0,0 +1,37 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.sources

import org.apache.spark.sql.types.DataType

case class Aggregation(aggregateExpressions: Seq[AggregateFunc],
    groupByExpressions: Seq[String])

abstract class AggregateFunc

// Todo: add Count

case class Avg(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc
case class Min(column: String, dataType: DataType) extends AggregateFunc
case class Max(column: String, dataType: DataType) extends AggregateFunc
case class Sum(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc

object Aggregation {
  // Returns an empty Aggregate
  def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String])
}

Review comment (Member) on the Aggregation case class: I think these need some docs since they are user-facing? and maybe some examples on how to handle
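A small usage sketch, assuming a hypothetical JDBC table emp with columns DEPT (string) and SALARY (double): the Aggregation below describes SELECT MAX(SALARY), SUM(DISTINCT SALARY) FROM emp GROUP BY DEPT, which is the shape of value handed to SupportsPushDownAggregates.pushAggregation for such a query.

```scala
import org.apache.spark.sql.sources.{Aggregation, Max, Sum}
import org.apache.spark.sql.types.DoubleType

// Hypothetical: SELECT MAX(SALARY), SUM(DISTINCT SALARY) FROM emp GROUP BY DEPT
val agg = Aggregation(
  aggregateExpressions = Seq(
    Max("SALARY", DoubleType),
    Sum("SALARY", DoubleType, isDistinct = true)),
  groupByExpressions = Seq("DEPT"))
```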
Changes to DataSourceStrategy.scala:

@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{CacheTable, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project, UncacheTable}
import org.apache.spark.sql.catalyst.rules.Rule

@@ -358,6 +359,7 @@ object DataSourceStrategy
        l.output.toStructType,
        Set.empty,
        Set.empty,
        Aggregation.empty,
        toCatalystRDD(l, baseRelation.buildScan()),
        baseRelation,
        None) :: Nil

@@ -431,6 +433,7 @@ object DataSourceStrategy
        requestedColumns.toStructType,
        pushedFilters.toSet,
        handledFilters,
        Aggregation.empty,
        scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
        relation.relation,
        relation.catalogTable.map(_.identifier))

@@ -453,6 +456,7 @@ object DataSourceStrategy
        requestedColumns.toStructType,
        pushedFilters.toSet,
        handledFilters,
        Aggregation.empty,
        scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
        relation.relation,
        relation.catalogTable.map(_.identifier))

@@ -700,6 +704,41 @@ object DataSourceStrategy
    (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
  }
  private def columnAsString(e: Expression): String = e match {
    case AttributeReference(name, _, _, _) => name
    case Cast(child, _, _) => columnAsString (child)
    case Add(left, right, _) =>
      columnAsString(left) + " + " + columnAsString(right)
    case Subtract(left, right, _) =>
      columnAsString(left) + " - " + columnAsString(right)
    case Multiply(left, right, _) =>
      columnAsString(left) + " * " + columnAsString(right)
    case Divide(left, right, _) =>
      columnAsString(left) + " / " + columnAsString(right)
    case CheckOverflow(child, _, _) => columnAsString (child)
    case PromotePrecision(child) => columnAsString (child)
    case _ => ""
  }

Review comments on columnAsString:
- Member: For predicate pushdown, it seems we simplify the cases to handle by only looking at the column name. This covers a lot of cases but also makes it easy to break. We can begin with the simplest case and add more support later.
- Member: Let's wait for others and see if there are any other voices.
- Member: +1
- Member: +1. It also seems strange to convert a binary expression into a "magic" string form that is (or seems) special to JDBC data sources. I also wonder if we should handle nested columns the same way as
- Member (nit): extra space after `columnAsString` in the `Cast` and `CheckOverflow` cases.
  protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
    aggregates.aggregateFunction match {
      case min: aggregate.Min =>
        val colName = columnAsString(min.child)
        if (colName.nonEmpty) Some(Min(colName, min.dataType)) else None
      case max: aggregate.Max =>
        val colName = columnAsString(max.child)
        if (colName.nonEmpty) Some(Max(colName, max.dataType)) else None
      case avg: aggregate.Average =>
        val colName = columnAsString(avg.child)
        if (colName.nonEmpty) Some(Avg(colName, avg.dataType, aggregates.isDistinct)) else None
      case sum: aggregate.Sum =>
        val colName = columnAsString(sum.child)
        if (colName.nonEmpty) Some(Sum(colName, sum.dataType, aggregates.isDistinct)) else None
      case _ => None
    }
  }

  /**
   * Convert RDD of Row into RDD of InternalRow with objects in catalyst types
   */

Review comments on translateAggregate:
- Member (on the `aggregates.aggregateFunction match` line): What if
- Contributor (author): I will need to change the following to add
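As a rough sketch of what translateAggregate is expected to produce for the hypothetical emp table above (illustrative only, not part of the patch; `price` and `bonus` stand in for DOUBLE columns):

```scala
import org.apache.spark.sql.sources.{AggregateFunc, Avg, Max, Sum}
import org.apache.spark.sql.types.DoubleType

// Hypothetical translation results for aggregates over DOUBLE columns `price` and `bonus`:
val maxPrice: AggregateFunc = Max("price", DoubleType)                              // MAX(price)
val sumDistinctPrice: AggregateFunc = Sum("price", DoubleType, isDistinct = true)   // SUM(DISTINCT price)
val avgOfExpr: AggregateFunc = Avg("price + bonus", DoubleType, isDistinct = false) // AVG(price + bonus), flattened by columnAsString
// Anything columnAsString cannot render yields an empty column name, so translateAggregate
// returns None and the aggregate is simply not pushed down. COUNT(*) has no case yet
// ("Todo: add Count"), so it is not translated either.
```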
Changes to JDBCRDD.scala:

@@ -18,7 +18,9 @@
package org.apache.spark.sql.execution.datasources.jdbc

import java.sql.{Connection, PreparedStatement, ResultSet}
import java.util.StringTokenizer

import scala.collection.mutable.ArrayBuilder
import scala.util.control.NonFatal

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}

@@ -133,6 +135,63 @@ object JDBCRDD extends Logging {
    })
  }

  private def containsArithmeticOp(col: String): Boolean =
    col.contains("+") || col.contains("-") || col.contains("*") || col.contains("/")
  def compileAggregates(
      aggregates: Seq[AggregateFunc],
      dialect: JdbcDialect): (Array[String], Array[DataType]) = {
    def quote(colName: String): String = dialect.quoteIdentifier(colName)
    val aggBuilder = ArrayBuilder.make[String]
    val dataTypeBuilder = ArrayBuilder.make[DataType]
    aggregates.map {
      case Min(column, dataType) =>
        dataTypeBuilder += dataType
        if (!containsArithmeticOp(column)) {
          aggBuilder += s"MIN(${quote(column)})"
        } else {
          aggBuilder += s"MIN(${quoteEachCols(column, dialect)})"
        }
      case Max(column, dataType) =>
        dataTypeBuilder += dataType
        if (!containsArithmeticOp(column)) {
          aggBuilder += s"MAX(${quote(column)})"
        } else {
          aggBuilder += s"MAX(${quoteEachCols(column, dialect)})"
        }
      case Sum(column, dataType, isDistinct) =>
        val distinct = if (isDistinct) "DISTINCT " else ""
        dataTypeBuilder += dataType
        if (!containsArithmeticOp(column)) {
          aggBuilder += s"SUM(${distinct} ${quote(column)})"
        } else {
          aggBuilder += s"SUM(${distinct} ${quoteEachCols(column, dialect)})"
        }
      case Avg(column, dataType, isDistinct) =>
        val distinct = if (isDistinct) "DISTINCT " else ""
        dataTypeBuilder += dataType
        if (!containsArithmeticOp(column)) {
          aggBuilder += s"AVG(${distinct} ${quote(column)})"
        } else {
          aggBuilder += s"AVG(${distinct} ${quoteEachCols(column, dialect)})"
        }
      case _ =>
    }
    (aggBuilder.result, dataTypeBuilder.result)
  }
  private def quoteEachCols (column: String, dialect: JdbcDialect): String = {
    def quote(colName: String): String = dialect.quoteIdentifier(colName)
    val colsBuilder = ArrayBuilder.make[String]
    val st = new StringTokenizer(column, "+-*/", true)
    colsBuilder += quote(st.nextToken().trim)
    while (st.hasMoreTokens) {
      colsBuilder += st.nextToken
      colsBuilder += quote(st.nextToken().trim)
    }
    colsBuilder.result.mkString(" ")
  }

Review comment (Member) on the quoteEachCols definition (nit): extra space after `quoteEachCols`.
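A quick sketch of what compileAggregates generates; the JDBC URL and column names are made up, and the exact spacing follows the string templates above:

```scala
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources.{Max, Sum}
import org.apache.spark.sql.types.{DoubleType, LongType}

val dialect = JdbcDialects.get("jdbc:postgresql://localhost/test") // hypothetical URL
val (sqlFragments, dataTypes) = JDBCRDD.compileAggregates(
  Seq(Max("price", DoubleType), Sum("a + b", LongType, isDistinct = false)),
  dialect)

// sqlFragments is expected to be roughly:
//   Array("MAX(\"price\")", "SUM( \"a\" + \"b\")")
// A plain column is quoted as a whole, while an arithmetic expression is tokenized by
// quoteEachCols and each operand is quoted individually.
// dataTypes keeps the matching Catalyst types: Array(DoubleType, LongType)
```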
|
|
  /**
   * Build and return JDBCRDD from the given information.
   *

@@ -152,7 +211,9 @@ object JDBCRDD extends Logging {
      requiredColumns: Array[String],
      filters: Array[Filter],
      parts: Array[Partition],
      options: JDBCOptions,
      aggregation: Aggregation = Aggregation.empty)
    : RDD[InternalRow] = {
    val url = options.url
    val dialect = JdbcDialects.get(url)
    val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))

@@ -164,7 +225,8 @@ object JDBCRDD extends Logging {
      filters,
      parts,
      url,
      options,
      aggregation)
  }
}
@@ -181,21 +243,49 @@ private[jdbc] class JDBCRDD(
    filters: Array[Filter],
    partitions: Array[Partition],
    url: String,
    options: JDBCOptions,
    aggregation: Aggregation = Aggregation.empty)
  extends RDD[InternalRow](sc, Nil) {

  /**
   * Retrieve the list of partitions corresponding to this RDD.
   */
  override def getPartitions: Array[Partition] = partitions

  private var updatedSchema: StructType = new StructType()

  /**
   * `columns`, but as a String suitable for injection into a SQL query.
   */
  private val columnList: String = {
    val (compiledAgg, aggDataType) =
      JDBCRDD.compileAggregates(aggregation.aggregateExpressions, JdbcDialects.get(url))
    val sb = new StringBuilder()
    if (compiledAgg.length == 0) {
      updatedSchema = schema
      columns.foreach(x => sb.append(",").append(x))
    } else {
      getAggregateColumnsList(sb, compiledAgg, aggDataType)
    }
    if (sb.length == 0) "1" else sb.substring(1)
  }

  private def getAggregateColumnsList(
      sb: StringBuilder,
      compiledAgg: Array[String],
      aggDataType: Array[DataType]): Unit = {
    val colDataTypeMap: Map[String, StructField] = columns.zip(schema.fields).toMap
    val newColsBuilder = ArrayBuilder.make[String]
    for ((col, dataType) <- compiledAgg.zip(aggDataType)) {
      newColsBuilder += col
      updatedSchema = updatedSchema.add(col, dataType)
    }
    for (groupBy <- aggregation.groupByExpressions) {
      val quotedGroupBy = JdbcDialects.get(url).quoteIdentifier(groupBy)
      newColsBuilder += quotedGroupBy
      updatedSchema = updatedSchema.add(colDataTypeMap.get(quotedGroupBy).get)
    }
    sb.append(", ").append(newColsBuilder.result.mkString(", "))
  }
@@ -221,6 +311,18 @@ private[jdbc] class JDBCRDD(
    }
  }

  /**
   * A GROUP BY clause representing pushed-down grouping columns.
   */
  private def getGroupByClause: String = {
    if (aggregation.groupByExpressions.length > 0) {
      val quotedColumns = aggregation.groupByExpressions.map(JdbcDialects.get(url).quoteIdentifier)
      s"GROUP BY ${quotedColumns.mkString(", ")}"
    } else {
      ""
    }
  }

  /**
   * Runs the SQL query against the JDBC driver.
   *
@@ -296,13 +398,15 @@ private[jdbc] class JDBCRDD(
    val myWhereClause = getWhereClause(part)

    val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
      s" $getGroupByClause"
    stmt = conn.prepareStatement(sqlText,
        ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
    stmt.setFetchSize(options.fetchSize)
    stmt.setQueryTimeout(options.queryTimeout)
    rs = stmt.executeQuery()
    val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, updatedSchema, inputMetrics)

    CompletionIterator[InternalRow, Iterator[InternalRow]](
      new InterruptibleIterator(context, rowsIterator), close())
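Putting the pieces together, a sketch (continuing the hypothetical emp table with DEPT and SALARY columns; exact whitespace differs slightly) of the statement this RDD would send to the database once an aggregation has been pushed:

```scala
// Illustrative only: with aggregation = Aggregation(Seq(Max("SALARY", DoubleType)), Seq("DEPT"))
// and no pushed filters or partition predicate (empty WHERE clause), columnList compiles to the
// aggregate expression plus the quoted group-by column, and getGroupByClause appends the GROUP BY:
val expectedSql = "SELECT MAX(\"SALARY\"), \"DEPT\" FROM emp  GROUP BY \"DEPT\""

// The result set is already aggregated, so it is read back with updatedSchema
// (fields MAX("SALARY") and "DEPT") instead of the table's original schema.
```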
Review comment: Maybe 3.2.0 now.