@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.Aggregation;

/**
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down aggregates to the data source.
*
* @since 3.1.0
Review comment (Member): Maybe 3.2.0 now.

*/
@Evolving
public interface SupportsPushDownAggregates extends ScanBuilder {

/**
* Pushes down the Aggregation to the data source.
* The Aggregation can be pushed down only if all of its aggregate functions
* can be pushed down.
*/
void pushAggregation(Aggregation aggregation);

/**
* Returns the aggregates that are pushed to the data source via
Review comment (Member): I think this returns Aggregation?

* {@link #pushAggregation(Aggregation aggregation)}.
*/
Aggregation pushedAggregation();
}
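For context, a minimal sketch (not part of this PR) of how a v2 source's ScanBuilder might adopt the mix-in; MyScanBuilder, the isSupported check, and the silent fall-back to Aggregation.empty are assumptions made for illustration only.

import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}
import org.apache.spark.sql.sources.{AggregateFunc, Aggregation, Avg, Max, Min, Sum}

class MyScanBuilder extends SupportsPushDownAggregates {
  // Aggregation pushed by Spark; stays empty if this source cannot handle it.
  private var pushedAgg: Aggregation = Aggregation.empty

  override def pushAggregation(aggregation: Aggregation): Unit = {
    // Accept the pushdown only if every aggregate function is supported.
    if (aggregation.aggregateExpressions.forall(isSupported)) {
      pushedAgg = aggregation
    }
  }

  override def pushedAggregation(): Aggregation = pushedAgg

  // Hypothetical capability check for this sketch.
  private def isSupported(f: AggregateFunc): Boolean = f match {
    case _: Min | _: Max | _: Sum | _: Avg => true
    case _ => false
  }

  // Build the Scan using pushedAgg (omitted here).
  override def build(): Scan = ???
}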
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.sources

import org.apache.spark.sql.types.DataType

case class Aggregation(aggregateExpressions: Seq[AggregateFunc],
Review comment (Member): I think these need some docs since they are user-facing? And maybe some examples on how to handle aggregateExpressions and groupByExpressions. For the latter, should we also name it groupByColumns?

groupByExpressions: Seq[String])

abstract class AggregateFunc

// Todo: add Count

case class Avg(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc
case class Min(column: String, dataType: DataType) extends AggregateFunc
case class Max(column: String, dataType: DataType) extends AggregateFunc
case class Sum(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc

object Aggregation {
// Returns an empty Aggregation
def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String])
}
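As a rough illustration (column names and types are assumed, not from this PR), a query shaped like SELECT MAX(salary), MIN(bonus) FROM emp GROUP BY dept would be described to a source as:

import org.apache.spark.sql.sources.{Aggregation, Max, Min}
import org.apache.spark.sql.types.LongType

val agg = Aggregation(
  aggregateExpressions = Seq(Max("salary", LongType), Min("bonus", LongType)),
  groupByExpressions = Seq("dept"))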
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, Filter}
import org.apache.spark.sql.sources.{Aggregation, BaseRelation, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils
@@ -102,6 +102,7 @@ case class RowDataSourceScanExec(
requiredSchema: StructType,
filters: Set[Filter],
handledFilters: Set[Filter],
aggregation: Aggregation,
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
tableIdentifier: Option[TableIdentifier])
@@ -132,9 +133,17 @@ case class RowDataSourceScanExec(
val markedFilters = for (filter <- filters) yield {
if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
}
val markedAggregates = for (aggregate <- aggregation.aggregateExpressions) yield {
s"*$aggregate"
}
val markedGroupby = for (groupby <- aggregation.groupByExpressions) yield {
s"*$groupby"
}
Map(
"ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
"PushedFilters" -> markedFilters.mkString("[", ", ", "]"),
"PushedAggregates" -> markedAggregates.mkString("[", ", ", "]"),
"PushedGroupby" -> markedGroupby.mkString("[", ", ", "]"))
}

// Don't care about `rdd` and `tableIdentifier` when canonicalizing.
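With these entries, the scan node's metadata for a pushed-down query would be expected to show lines roughly like the following (illustrative only; the exact text depends on each AggregateFunc's toString and on the column names involved):

PushedAggregates: [*Max(salary,LongType)]
PushedGroupby: [*dept]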
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{CacheTable, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project, UncacheTable}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -358,6 +359,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
Aggregation.empty,
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -431,6 +433,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
Aggregation.empty,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -453,6 +456,7 @@
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
Aggregation.empty,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -700,6 +704,41 @@ object DataSourceStrategy
(nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
}

private def columnAsString(e: Expression): String = e match {
Review comment (Member): For predicate pushdown, it seems we simplify the cases to handle by only looking at the column name. This covers a lot of cases but also makes it easy to break. We can begin with the simplest case and add more support later.

Review comment (Member): Let's wait for others and see if there are any other voices.

Review comment (Member): +1 to "This covers a lot of cases but also makes it easy to break. We can begin with the simplest case and add more support later."

Review comment (Member): +1. It also seems strange to convert a binary expression into a "magic" string form that is (or seems) specific to JDBC data sources. I also wonder if we should handle nested columns the same way as PushableColumnBase.

case AttributeReference(name, _, _, _) => name
case Cast(child, _, _) => columnAsString (child)
Review comment (Member): nit: extra space after columnAsString.

case Add(left, right, _) =>
columnAsString(left) + " + " + columnAsString(right)
case Subtract(left, right, _) =>
columnAsString(left) + " - " + columnAsString(right)
case Multiply(left, right, _) =>
columnAsString(left) + " * " + columnAsString(right)
case Divide(left, right, _) =>
columnAsString(left) + " / " + columnAsString(right)
case CheckOverflow(child, _, _) => columnAsString (child)
Review comment (Member): nit: extra space after columnAsString.

case PromotePrecision(child) => columnAsString (child)
case _ => ""
}

protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {

Review comment (Member): nit: unnecessary blank.

aggregates.aggregateFunction match {
Review comment (Member): What if aggregates has isDistinct=true or a filter?

Review comment (Contributor, author): I will need to change the following to add isDistinct and filter, and change translateAggregate accordingly. When pushing down the aggregates, we also need to check the filter to make sure it can be pushed down too.

case class Avg(column: String, isDistinct: Boolean, filter: Option[Filter]) extends AggregateFunc
case class Min(column: String, isDistinct: Boolean, filter: Option[Filter]) extends AggregateFunc
case class Max(column: String, isDistinct: Boolean, filter: Option[Filter]) extends AggregateFunc
case class Sum(column: String, isDistinct: Boolean, filter: Option[Filter]) extends AggregateFunc

case min: aggregate.Min =>
val colName = columnAsString(min.child)
if (colName.nonEmpty) Some(Min(colName, min.dataType)) else None
case max: aggregate.Max =>
val colName = columnAsString(max.child)
if (colName.nonEmpty) Some(Max(colName, max.dataType)) else None
case avg: aggregate.Average =>
val colName = columnAsString(avg.child)
if (colName.nonEmpty) Some(Avg(colName, avg.dataType, aggregates.isDistinct)) else None
case sum: aggregate.Sum =>
val colName = columnAsString(sum.child)
if (colName.nonEmpty) Some(Sum(colName, sum.dataType, aggregates.isDistinct)) else None
case _ => None
}
}
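To make the translation concrete, a hedged sketch with made-up example expressions (translateAggregate is internal to the sql package, so the expected result is shown as a comment):

import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.types.IntegerType

val salary = AttributeReference("salary", IntegerType)()
val bonus = AttributeReference("bonus", IntegerType)()
val catalystMax = aggregate.Max(Add(salary, bonus)).toAggregateExpression()
// columnAsString flattens the child into "salary + bonus", so
// translateAggregate(catalystMax) is expected to return
// Some(Max("salary + bonus", IntegerType))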

/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/
@@ -191,6 +191,9 @@ class JDBCOptions(
// An option to allow/disallow pushing down predicate into JDBC data source
val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean

// An option to allow/disallow pushing down aggregate into JDBC data source
val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "true").toBoolean

// The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
// by --files option of spark-submit or manually
val keytab = {
@@ -260,6 +263,7 @@ object JDBCOptions {
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
val JDBC_TABLE_COMMENT = newOption("tableComment")
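Usage is presumably the same as for pushDownPredicate; an illustrative read that disables the new aggregate pushdown (connection details are placeholders, and an active SparkSession named spark is assumed):

val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://host:5432/db")  // placeholder URL
  .option("dbtable", "emp")                          // placeholder table
  .option("pushDownAggregate", "false")              // turn aggregate pushdown off
  .load()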
@@ -18,7 +18,9 @@
package org.apache.spark.sql.execution.datasources.jdbc

import java.sql.{Connection, PreparedStatement, ResultSet}
import java.util.StringTokenizer

import scala.collection.mutable.ArrayBuilder
import scala.util.control.NonFatal

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
@@ -133,6 +135,63 @@ object JDBCRDD extends Logging {
})
}

private def containsArithmeticOp(col: String): Boolean =
col.contains("+") || col.contains("-") || col.contains("*") || col.contains("/")

def compileAggregates(
aggregates: Seq[AggregateFunc],
dialect: JdbcDialect): (Array[String], Array[DataType]) = {
def quote(colName: String): String = dialect.quoteIdentifier(colName)
val aggBuilder = ArrayBuilder.make[String]
val dataTypeBuilder = ArrayBuilder.make[DataType]
aggregates.map {
case Min(column, dataType) =>
dataTypeBuilder += dataType
if (!containsArithmeticOp(column)) {
aggBuilder += s"MIN(${quote(column)})"
} else {
aggBuilder += s"MIN(${quoteEachCols(column, dialect)})"
}
case Max(column, dataType) =>
dataTypeBuilder += dataType
if (!containsArithmeticOp(column)) {
aggBuilder += s"MAX(${quote(column)})"
} else {
aggBuilder += s"MAX(${quoteEachCols(column, dialect)})"
}
case Sum(column, dataType, isDistinct) =>
val distinct = if (isDistinct) "DISTINCT " else ""
dataTypeBuilder += dataType
if (!containsArithmeticOp(column)) {
aggBuilder += s"SUM(${distinct} ${quote(column)})"
} else {
aggBuilder += s"SUM(${distinct} ${quoteEachCols(column, dialect)})"
}
case Avg(column, dataType, isDistinct) =>
val distinct = if (isDistinct) "DISTINCT " else ""
dataTypeBuilder += dataType
if (!containsArithmeticOp(column)) {
aggBuilder += s"AVG(${distinct} ${quote(column)})"
} else {
aggBuilder += s"AVG(${distinct} ${quoteEachCols(column, dialect)})"
}
case _ =>
}
(aggBuilder.result, dataTypeBuilder.result)
}

private def quoteEachCols (column: String, dialect: JdbcDialect): String = {
Review comment (Member): nit: extra space.

def quote(colName: String): String = dialect.quoteIdentifier(colName)
val colsBuilder = ArrayBuilder.make[String]
val st = new StringTokenizer(column, "+-*/", true)
colsBuilder += quote(st.nextToken().trim)
while (st.hasMoreTokens) {
colsBuilder += st.nextToken
colsBuilder += quote(st.nextToken().trim)
}
colsBuilder.result.mkString(" ")
}
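For a sense of the output, a hedged example of what compileAggregates is expected to produce (this is an internal helper; the exact quoting depends on the JdbcDialect, with double quotes assumed here and a placeholder URL):

import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources.{Max, Sum}
import org.apache.spark.sql.types.LongType

val dialect = JdbcDialects.get("jdbc:postgresql://host/db")  // placeholder URL
val (aggSql, aggTypes) = JDBCRDD.compileAggregates(
  Seq(Max("salary", LongType), Sum("salary + bonus", LongType, isDistinct = false)), dialect)
// aggSql is expected to be roughly:
//   Array("MAX(\"salary\")", "SUM( \"salary\" + \"bonus\")")
// aggTypes: Array(LongType, LongType)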

/**
* Build and return JDBCRDD from the given information.
*
@@ -152,7 +211,9 @@
requiredColumns: Array[String],
filters: Array[Filter],
parts: Array[Partition],
options: JDBCOptions): RDD[InternalRow] = {
options: JDBCOptions,
aggregation: Aggregation = Aggregation.empty)
: RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
Expand All @@ -164,7 +225,8 @@ object JDBCRDD extends Logging {
filters,
parts,
url,
options)
options,
aggregation)
}
}

@@ -181,21 +243,49 @@ private[jdbc] class JDBCRDD(
filters: Array[Filter],
partitions: Array[Partition],
url: String,
options: JDBCOptions)
options: JDBCOptions,
aggregation: Aggregation = Aggregation.empty)
extends RDD[InternalRow](sc, Nil) {

/**
* Retrieve the list of partitions corresponding to this RDD.
*/
override def getPartitions: Array[Partition] = partitions

private var updatedSchema: StructType = new StructType()
Review comment (Member): Could we avoid using a var here?


/**
* `columns`, but as a String suitable for injection into a SQL query.
*/
private val columnList: String = {
val (compiledAgg, aggDataType) =
JDBCRDD.compileAggregates(aggregation.aggregateExpressions, JdbcDialects.get(url))
val sb = new StringBuilder()
columns.foreach(x => sb.append(",").append(x))
if (sb.isEmpty) "1" else sb.substring(1)
if (compiledAgg.length == 0) {
updatedSchema = schema
columns.foreach(x => sb.append(",").append(x))
} else {
getAggregateColumnsList(sb, compiledAgg, aggDataType)
Review comment (Member): columns is empty for this case?

}
if (sb.length == 0) "1" else sb.substring(1)
}

private def getAggregateColumnsList(
Review comment (Member): Shall we add a comment here to explain what getAggregateColumnsList does and why it is needed?

sb: StringBuilder,
compiledAgg: Array[String],
aggDataType: Array[DataType]): Unit = {
val colDataTypeMap: Map[String, StructField] = columns.zip(schema.fields).toMap
val newColsBuilder = ArrayBuilder.make[String]
for ((col, dataType) <- compiledAgg.zip(aggDataType)) {
newColsBuilder += col
updatedSchema = updatedSchema.add(col, dataType)
}
for (groupBy <- aggregation.groupByExpressions) {
val quotedGroupBy = JdbcDialects.get(url).quoteIdentifier(groupBy)
newColsBuilder += quotedGroupBy
updatedSchema = updatedSchema.add(colDataTypeMap.get(quotedGroupBy).get)
}
sb.append(", ").append(newColsBuilder.result.mkString(", "))
}
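Responding to the question above: getAggregateColumnsList is understood to rebuild the SELECT list and the result schema for the aggregated query. Each compiled aggregate becomes a select-list entry and a new field in updatedSchema (typed from aggDataType), and each group-by column is quoted, appended to the list, and added to updatedSchema with its original field, so updatedSchema matches the shape of the rows the aggregated query actually returns. Illustratively:

// Hypothetical example: Max("salary", LongType) grouped by "dept"
//   columnList    -> MAX("salary"), "dept"
//   updatedSchema -> a LongType field named MAX("salary"),
//                    followed by the original "dept" field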

/**
@@ -221,6 +311,18 @@
}
}

/**
* A GROUP BY clause representing pushed-down grouping columns.
*/
private def getGroupByClause: String = {
if (aggregation.groupByExpressions.length > 0) {
val quotedColumns = aggregation.groupByExpressions.map(JdbcDialects.get(url).quoteIdentifier)
s"GROUP BY ${quotedColumns.mkString(", ")}"
} else {
""
}
}

/**
* Runs the SQL query against the JDBC driver.
*
@@ -296,13 +398,15 @@

val myWhereClause = getWhereClause(part)

val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause"
val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
s" $getGroupByClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
stmt.setQueryTimeout(options.queryTimeout)
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)

Review comment (Member): unnecessary change?

val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, updatedSchema, inputMetrics)

CompletionIterator[InternalRow, Iterator[InternalRow]](
new InterruptibleIterator(context, rowsIterator), close())
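Putting the pieces together, for a pushed-down MAX(salary) grouped by dept against a table named emp (table name, columns, and quoting are illustrative, with no pushed filters), sqlText would presumably come out as something like:

SELECT MAX("salary"), "dept" FROM emp  GROUP BY "dept"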