7 changes: 7 additions & 0 deletions docs/sql-data-sources-jdbc.md
@@ -211,6 +211,13 @@ the following case-insensitive options:
Specifies kerberos principal name for the JDBC client. If both <code>keytab</code> and <code>principal</code> are defined then Spark tries to do kerberos authentication.
</td>
</tr>

<tr>
<td><code>pushDownAggregate</code></td>
<td>
The option to enable or disable aggregate push-down into the JDBC data source. The default value is false, in which case Spark does NOT push down aggregates to the JDBC data source. If set to true, aggregates are pushed down to the JDBC data source and handled by the data source instead of Spark. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Please note that aggregates are pushed down if and only if all the aggregates and the related filters can be pushed down. See the usage example after the table.
</td>
</tr>
</table>
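
For illustration, a minimal sketch of enabling the option on a JDBC read; the connection URL and table name below are hypothetical, and a SparkSession named `spark` is assumed:

```scala
import org.apache.spark.sql.functions.max

// Hypothetical connection URL and table name.
val sales = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://dbhost:5432/shop")
  .option("dbtable", "sales")
  .option("pushDownAggregate", "true")
  .load()

// If all aggregates and related filters can be pushed down, MAX(price) is
// computed by the database; otherwise Spark computes it as usual.
sales.groupBy("region").agg(max("price")).show()
```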

Note that kerberos authentication with keytab is not always supported by the JDBC driver.<br>
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.Aggregation;

/**
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down aggregates to the data source.
*
* @since 3.2.0
*/
@Evolving
public interface SupportsPushDownAggregates extends ScanBuilder {

/**
 * Pushes down the aggregation to the data source. The aggregation can be pushed down
 * only if all of its aggregate functions can be pushed down.
*/
void pushAggregation(Aggregation aggregation);

/**
 * Returns the aggregation that is pushed to the data source via
 * {@link #pushAggregation(Aggregation)}.
*/
Aggregation pushedAggregation();
}
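
For illustration only (not part of this diff), a minimal Scala sketch of how a connector's ScanBuilder might implement this mix-in; MyScanBuilder is hypothetical and the actual scan construction is elided:

```scala
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}
import org.apache.spark.sql.sources.Aggregation

// Hypothetical connector-side ScanBuilder; only the aggregate push-down plumbing is shown.
class MyScanBuilder extends SupportsPushDownAggregates {

  private var pushedAgg: Aggregation = Aggregation.empty

  // Spark calls this when the whole aggregate can be handled by the source.
  override def pushAggregation(aggregation: Aggregation): Unit = {
    pushedAgg = aggregation
  }

  // Spark reads this back when planning the scan.
  override def pushedAggregation(): Aggregation = pushedAgg

  override def build(): Scan = ??? // build a Scan that evaluates pushedAgg at the source
}
```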
@@ -17,11 +17,7 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

// scalastyle:off line.size.limit
@ExpressionDescription(
@@ -44,52 +40,7 @@ import org.apache.spark.sql.types._
group = "agg_funcs",
since = "1.0.0")
// scalastyle:on line.size.limit
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {

override def nullable: Boolean = false

// Return data type.
override def dataType: DataType = LongType

override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_PARAMETERLESS_COUNT)) {
TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument. " +
s"If you have to call the function $prettyName without arguments, set the legacy " +
s"configuration `${SQLConf.ALLOW_PARAMETERLESS_COUNT.key}` as true")
} else {
TypeCheckResult.TypeCheckSuccess
}
}

protected lazy val count = AttributeReference("count", LongType, nullable = false)()

override lazy val aggBufferAttributes = count :: Nil

override lazy val initialValues = Seq(
/* count = */ Literal(0L)
)

override lazy val mergeExpressions = Seq(
/* count = */ count.left + count.right
)

override lazy val evaluateExpression = count

override def defaultResult: Option[Literal] = Option(Literal(0L))

override lazy val updateExpressions = {
val nullableChildren = children.filter(_.nullable)
if (nullableChildren.isEmpty) {
Seq(
/* count = */ count + 1L
)
} else {
Seq(
/* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
)
}
}
}
case class Count(children: Seq[Expression]) extends CountBase(children)

object Count {
def apply(child: Expression): Count = Count(child :: Nil)
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

abstract class CountBase(children: Seq[Expression]) extends DeclarativeAggregate {

override def nullable: Boolean = false

// Return data type.
override def dataType: DataType = LongType

override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_PARAMETERLESS_COUNT)) {
TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument. " +
s"If you have to call the function $prettyName without arguments, set the legacy " +
s"configuration `${SQLConf.ALLOW_PARAMETERLESS_COUNT.key}` as true")
} else {
TypeCheckResult.TypeCheckSuccess
}
}

protected lazy val count = AttributeReference("count", LongType, nullable = false)()

override lazy val aggBufferAttributes = count :: Nil

override lazy val initialValues = Seq(
/* count = */ Literal(0L)
)

override lazy val mergeExpressions = Seq(
/* count = */ count.left + count.right
)

override lazy val evaluateExpression = count

override def defaultResult: Option[Literal] = Option(Literal(0L))

override lazy val updateExpressions = {
val nullableChildren = children.filter(_.nullable)
if (nullableChildren.isEmpty) {
Seq(
/* count = */ count + 1L
)
} else {
Seq(
/* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
)
}
}
}


@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.LongType

case class PushDownCount(children: Seq[Expression]) extends CountBase(children) {

override protected lazy val count =
AttributeReference("PushDownCount", LongType, nullable = false)()

override lazy val updateExpressions = {
Seq(
// if count is pushed down to Data Source layer, add the count result retrieved from
// Data Source
/* count = */ count + children.head
)
}
}

object PushDownCount {
def apply(child: Expression): PushDownCount =
PushDownCount(child :: Nil)
}
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.sources

import org.apache.spark.sql.types.DataType

case class Aggregation(
    aggregateExpressions: Seq[AggregateFunc],
    groupByExpressions: Seq[String])

Reviewer: I think these need some docs since they are user-facing? And maybe some examples of how to handle aggregateExpressions and groupByExpressions. For the latter, should we also name it groupByColumns?

abstract class AggregateFunc

case class Avg(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc
case class Min(column: String, dataType: DataType) extends AggregateFunc
case class Max(column: String, dataType: DataType) extends AggregateFunc
case class Sum(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc
case class Count(column: String, dataType: DataType, isDistinct: Boolean) extends AggregateFunc

object Aggregation {
// Returns an empty Aggregate
def empty: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String])
}
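
To illustrate the shape the reviewer is asking about, a hedged sketch of the Aggregation that a query like SELECT MAX(price), COUNT(DISTINCT id) FROM t GROUP BY region would map to under these definitions (column names are hypothetical):

```scala
import org.apache.spark.sql.sources.{Aggregation, Count, Max}
import org.apache.spark.sql.types.{DoubleType, LongType}

// Hypothetical columns; roughly SELECT MAX(price), COUNT(DISTINCT id) FROM t GROUP BY region
val agg = Aggregation(
  aggregateExpressions = Seq(
    Max("price", DoubleType),
    Count("id", LongType, isDistinct = true)),
  groupByExpressions = Seq("region"))
```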
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, Filter}
import org.apache.spark.sql.sources.{Aggregation, BaseRelation, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils
@@ -102,6 +102,7 @@ case class RowDataSourceScanExec(
requiredSchema: StructType,
filters: Set[Filter],
handledFilters: Set[Filter],
aggregation: Aggregation,
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
tableIdentifier: Option[TableIdentifier])
@@ -132,9 +133,17 @@ case class RowDataSourceScanExec(
val markedFilters = for (filter <- filters) yield {
if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
}
val markedAggregates = for (aggregate <- aggregation.aggregateExpressions) yield {
s"*$aggregate"
}
val markedGroupby = for (groupby <- aggregation.groupByExpressions) yield {
s"*$groupby"
}
Map(
"ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
"PushedFilters" -> markedFilters.mkString("[", ", ", "]"),
"PushedAggregates" -> markedAggregates.mkString("[", ", ", "]"),
"PushedGroupby" -> markedGroupby.mkString("[", ", ", "]"))
}

// Don't care about `rdd` and `tableIdentifier` when canonicalizing.
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{CacheTable, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project, UncacheTable}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -358,6 +359,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
Aggregation.empty,
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -431,6 +433,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
Aggregation.empty,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -453,6 +456,7 @@
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
Aggregation.empty,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -700,6 +704,49 @@ object DataSourceStrategy
(nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
}

private def columnAsString(e: Expression): String = e match {
Reviewer: For predicate pushdown, it seems we simplify the cases to handle by only looking at the column name. This covers a lot of cases but also makes it easy to break. We can begin with the simplest case and add more support later.

Reviewer: Let's wait for others and see if there are any other voices.

Reviewer: +1 to "This covers a lot of cases but also makes it easy to break. We can begin with the simplest case and add more support later."

Reviewer: +1. It also seems strange to convert a binary expression into a "magic" string form that is (or seems) special to JDBC data sources. I also wonder if we should handle nested columns the same way as PushableColumnBase.

case AttributeReference(name, _, _, _) => name
case Cast(child, _, _) => columnAsString (child)
Reviewer: nit: extra space after columnAsString.

case Add(left, right, _) =>
columnAsString(left) + " + " + columnAsString(right)
case Subtract(left, right, _) =>
columnAsString(left) + " - " + columnAsString(right)
case Multiply(left, right, _) =>
columnAsString(left) + " * " + columnAsString(right)
case Divide(left, right, _) =>
columnAsString(left) + " / " + columnAsString(right)
case CheckOverflow(child, _, _) => columnAsString (child)
Reviewer: nit: extra space after columnAsString.

case PromotePrecision(child) => columnAsString (child)
case _ => ""
}
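
To make the discussion above concrete, a few illustrative mappings implied by the cases in columnAsString (attribute names are hypothetical); anything falling through to the default case becomes an empty string, which callers treat as "not pushable":

```scala
// Illustrative only, based on the cases above (hypothetical attributes a, b, price):
//   AttributeReference("price", ...)                                  -> "price"
//   Add(AttributeReference("a", ...), AttributeReference("b", ...))   -> "a + b"
//   Cast(AttributeReference("price", ...), LongType)                  -> "price"
//   an unsupported expression (e.g. a UDF call)                       -> ""   (not pushable)
```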

protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
aggregates.aggregateFunction match {
Reviewer: What if aggregates has isDistinct=true or a filter?

Author: I will need to change the following to add isDistinct and filter, and change translateAggregate accordingly. When pushing down the aggregates, we also need to check that the filter can be pushed down:

case class Avg(column: String, isDistinct: Boolean, filter: Option[Filter]) extends AggregateFunc
case class Min(column: String, isDistinct: Boolean, filter: Option[Filter]) extends AggregateFunc
case class Max(column: String, isDistinct: Boolean, filter: Option[Filter]) extends AggregateFunc
case class Sum(column: String, isDistinct: Boolean, filter: Option[Filter]) extends AggregateFunc

case min: aggregate.Min =>
val colName = columnAsString(min.child)
if (colName.nonEmpty) Some(Min(colName, min.dataType)) else None
case max: aggregate.Max =>
val colName = columnAsString(max.child)
if (colName.nonEmpty) Some(Max(colName, max.dataType)) else None
case avg: aggregate.Average =>
val colName = columnAsString(avg.child)
if (colName.nonEmpty) Some(Avg(colName, avg.dataType, aggregates.isDistinct)) else None
case sum: aggregate.Sum =>
val colName = columnAsString(sum.child)
if (colName.nonEmpty) Some(Sum(colName, sum.dataType, aggregates.isDistinct)) else None
case count: aggregate.Count =>
val columnName = count.children.head match {
case Literal(_, _) => "1"
Reviewer: Why is this "1"? Also, should we check whether there is more than one element in children?

case _ => columnAsString(count.children.head)
}
if (columnName.nonEmpty) {
  Some(Count(columnName, count.dataType, aggregates.isDistinct))
} else {
  None
}
case _ => None
}
}
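
A brief, illustrative summary of what the translation above yields for a few Catalyst aggregates (column names are hypothetical); None means the aggregate cannot be pushed down:

```scala
// Illustrative only, based on the cases above:
//   MAX(price)           -> Some(Max("price", DoubleType))
//   AVG(DISTINCT amount) -> Some(Avg("amount", DoubleType, isDistinct = true))
//   COUNT(1)             -> Some(Count("1", LongType, isDistinct = false))  // Literal child maps to "1"
//   COUNT(udf(x))        -> None  // child does not reduce to a plain column string
```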

/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/