Commit 5523402

simplify JDBC aggregate pushdown
1 parent 86c4227 commit 5523402

13 files changed: 123 additions & 114 deletions

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java

Lines changed: 7 additions & 1 deletion
@@ -38,7 +38,13 @@ public Count(FieldReference column, boolean isDistinct) {
   public boolean isDistinct() { return isDistinct; }
 
   @Override
-  public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
+  public String toString() {
+    if (isDistinct) {
+      return "COUNT(DISTINCT " + column.describe() + ")";
+    } else {
+      return "COUNT(" + column.describe() + ")";
+    }
+  }
 
   @Override
   public String describe() { return this.toString(); }
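
With this change the string forms read as SQL rather than as constructor dumps. A minimal sketch of the effect (the column name is illustrative; FieldReference(...) is used the same way as in DataSourceStrategy further down):

    import org.apache.spark.sql.connector.expressions.{Count, CountStar, FieldReference}

    val col = FieldReference("price").asInstanceOf[FieldReference]
    new Count(col, false).toString  // "COUNT(price)"
    new Count(col, true).toString   // "COUNT(DISTINCT price)"
    new CountStar().toString        // "COUNT(*)"

The same renaming pattern applies to CountStar, Max, Min, and Sum below.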

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ public CountStar() {
   }
 
   @Override
-  public String toString() { return "CountStar()"; }
+  public String toString() { return "COUNT(*)"; }
 
   @Override
   public String describe() { return this.toString(); }

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ public final class Max implements AggregateFunc {
   public FieldReference column() { return column; }
 
   @Override
-  public String toString() { return "Max(" + column.describe() + ")"; }
+  public String toString() { return "MAX(" + column.describe() + ")"; }
 
   @Override
   public String describe() { return this.toString(); }

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ public final class Min implements AggregateFunc {
   public FieldReference column() { return column; }
 
   @Override
-  public String toString() { return "Min(" + column.describe() + ")"; }
+  public String toString() { return "MIN(" + column.describe() + ")"; }
 
   @Override
   public String describe() { return this.toString(); }

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java

Lines changed: 6 additions & 5 deletions
@@ -28,22 +28,23 @@
 @Evolving
 public final class Sum implements AggregateFunc {
   private final FieldReference column;
-  private final DataType dataType;
   private final boolean isDistinct;
 
-  public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
+  public Sum(FieldReference column, boolean isDistinct) {
     this.column = column;
-    this.dataType = dataType;
     this.isDistinct = isDistinct;
   }
 
   public FieldReference column() { return column; }
-  public DataType dataType() { return dataType; }
   public boolean isDistinct() { return isDistinct; }
 
   @Override
   public String toString() {
-    return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
+    if (isDistinct) {
+      return "SUM(DISTINCT " + column.describe() + ")";
+    } else {
+      return "SUM(" + column.describe() + ")";
+    }
   }
 
   @Override

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 1 addition & 2 deletions
@@ -714,8 +714,7 @@ object DataSourceStrategy
           case _ => None
         }
       case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
-        Some(new Sum(FieldReference(name).asInstanceOf[FieldReference],
-          sum.dataType, aggregates.isDistinct))
+        Some(new Sum(FieldReference(name).asInstanceOf[FieldReference], aggregates.isDistinct))
       case _ => None
     }
   } else {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 18 additions & 20 deletions
@@ -25,7 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -136,30 +136,30 @@ object JDBCRDD extends Logging {
 
   def compileAggregates(
       aggregates: Seq[AggregateFunc],
-      dialect: JdbcDialect): Seq[String] = {
+      dialect: JdbcDialect): Option[Seq[String]] = {
     def quote(colName: String): String = dialect.quoteIdentifier(colName)
 
-    aggregates.map {
+    Some(aggregates.map {
       case min: Min =>
-        assert(min.column.fieldNames.length == 1)
+        if (min.column.fieldNames.length != 1) return None
         s"MIN(${quote(min.column.fieldNames.head)})"
       case max: Max =>
-        assert(max.column.fieldNames.length == 1)
+        if (max.column.fieldNames.length != 1) return None
         s"MAX(${quote(max.column.fieldNames.head)})"
       case count: Count =>
-        assert(count.column.fieldNames.length == 1)
-        val distinct = if (count.isDistinct) "DISTINCT" else ""
+        if (count.column.fieldNames.length != 1) return None
+        val distinct = if (count.isDistinct) "DISTINCT " else ""
         val column = quote(count.column.fieldNames.head)
-        s"COUNT($distinct $column)"
+        s"COUNT($distinct$column)"
       case sum: Sum =>
-        assert(sum.column.fieldNames.length == 1)
-        val distinct = if (sum.isDistinct) "DISTINCT" else ""
+        if (sum.column.fieldNames.length != 1) return None
+        val distinct = if (sum.isDistinct) "DISTINCT " else ""
         val column = quote(sum.column.fieldNames.head)
-        s"SUM($distinct $column)"
+        s"SUM($distinct$column)"
       case _: CountStar =>
-        s"COUNT(1)"
-      case _ => ""
-    }
+        s"COUNT(*)"
+      case _ => return None
+    })
   }
 
   /**
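
Returning Option[Seq[String]] instead of Seq[String] turns the old runtime asserts into a clean bail-out: if any single aggregate cannot be compiled, the whole pushdown is abandoned and Spark evaluates the aggregation itself. A rough sketch of the contract (the JDBC URL and column names are made up; exact quoting depends on the dialect):

    import org.apache.spark.sql.connector.expressions.{FieldReference, Min}
    import org.apache.spark.sql.jdbc.JdbcDialects

    val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")
    val flat = new Min(FieldReference("price").asInstanceOf[FieldReference])
    JDBCRDD.compileAggregates(Seq(flat), dialect)
    // Some(List(MIN("price")))

    // "a.b" parses to two field names, so this aggregate is not compilable
    // and the entire result collapses to None.
    val nested = new Min(FieldReference("a.b").asInstanceOf[FieldReference])
    JDBCRDD.compileAggregates(Seq(flat, nested), dialect)
    // None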
@@ -185,7 +185,7 @@ object JDBCRDD extends Logging {
       parts: Array[Partition],
       options: JDBCOptions,
       outputSchema: Option[StructType] = None,
-      groupByColumns: Option[Array[FieldReference]] = None): RDD[InternalRow] = {
+      groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = {
     val url = options.url
     val dialect = JdbcDialects.get(url)
     val quotedColumns = if (groupByColumns.isEmpty) {
@@ -221,7 +221,7 @@ private[jdbc] class JDBCRDD(
     partitions: Array[Partition],
     url: String,
     options: JDBCOptions,
-    groupByColumns: Option[Array[FieldReference]])
+    groupByColumns: Option[Array[String]])
   extends RDD[InternalRow](sc, Nil) {
 
   /**
@@ -266,10 +266,8 @@
    */
   private def getGroupByClause: String = {
     if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
-      assert(groupByColumns.get.forall(_.fieldNames.length == 1))
-      val dialect = JdbcDialects.get(url)
-      val quotedColumns = groupByColumns.get.map(c => dialect.quoteIdentifier(c.fieldNames.head))
-      s"GROUP BY ${quotedColumns.mkString(", ")}"
+      // The GROUP BY columns should already be quoted by the caller side.
+      s"GROUP BY ${groupByColumns.get.mkString(", ")}"
     } else {
       ""
     }
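
Since getGroupByClause no longer quotes, the quoting responsibility moves to whoever builds the groupByColumns array. A caller-side sketch (the URL and column names are illustrative):

    import org.apache.spark.sql.jdbc.JdbcDialects

    val dialect = JdbcDialects.get("jdbc:postgresql://localhost/testdb")
    // Quote once, up front; JDBCRDD then only joins the pre-quoted names.
    val groupByCols = Array("dept", "city").map(dialect.quoteIdentifier)
    println(groupByCols.mkString("GROUP BY ", ", ", ""))
    // GROUP BY "dept", "city"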

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

Lines changed: 3 additions & 4 deletions
@@ -27,7 +27,6 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
-import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.jdbc.JdbcDialects
@@ -291,9 +290,9 @@ private[sql] case class JDBCRelation(
 
   def buildScan(
       requiredColumns: Array[String],
-      requireSchema: Option[StructType],
+      finalSchema: StructType,
       filters: Array[Filter],
-      groupByColumns: Option[Array[FieldReference]]): RDD[Row] = {
+      groupByColumns: Option[Array[String]]): RDD[Row] = {
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sparkSession.sparkContext,
@@ -302,7 +301,7 @@ private[sql] case class JDBCRelation(
       filters,
       parts,
       jdbcOptions,
-      requireSchema,
+      Some(finalSchema),
       groupByColumns).asInstanceOf[RDD[Row]]
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ object PushDownUtils extends PredicateHelper {
     }
 
     scanBuilder match {
-      case r: SupportsPushDownAggregates =>
+      case r: SupportsPushDownAggregates if aggregates.nonEmpty =>
        val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
        val translatedGroupBys = groupBy.flatMap(columnAsString)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 17 additions & 7 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -76,9 +78,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
         if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
       sHolder.builder match {
         case _: SupportsPushDownAggregates =>
+          val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+          var ordinal = 0
           val aggregates = resultExpressions.flatMap { expr =>
             expr.collect {
-              case agg: AggregateExpression => agg
+              // Do not push down duplicated aggregate expressions. For example,
+              // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
+              // `max(a)` to the data source.
+              case agg: AggregateExpression
+                  if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
+                aggExprToOutputOrdinal(agg.canonicalized) = ordinal
+                ordinal += 1
+                agg
             }
           }
           val pushedAggregates = PushDownUtils
@@ -144,19 +155,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
           // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
           // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
           // scalastyle:on
-          var i = 0
           val aggOutput = output.drop(groupAttrs.length)
           plan.transformExpressions {
             case agg: AggregateExpression =>
+              val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
               val aggFunction: aggregate.AggregateFunction =
                 agg.aggregateFunction match {
-                  case max: aggregate.Max => max.copy(child = aggOutput(i))
-                  case min: aggregate.Min => min.copy(child = aggOutput(i))
-                  case sum: aggregate.Sum => sum.copy(child = aggOutput(i))
-                  case _: aggregate.Count => aggregate.Sum(aggOutput(i))
+                  case max: aggregate.Max => max.copy(child = aggOutput(ordinal))
+                  case min: aggregate.Min => min.copy(child = aggOutput(ordinal))
+                  case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal))
+                  case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal))
                   case other => other
                 }
-              i += 1
               agg.copy(aggregateFunction = aggFunction)
           }
         }
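
The ordinal map is what makes the deduplication safe: the first occurrence of each canonicalized aggregate claims an output slot, and every later duplicate resolves to that same slot. A standalone sketch of the pattern (a plain case class stands in for Catalyst's AggregateExpression):

    import scala.collection.mutable

    // Stand-in for AggregateExpression; canonicalized is the identity here.
    case class Agg(func: String, column: String) { def canonicalized: Agg = this }

    // e.g. SELECT max(a) + 1, max(a) + 2, min(b) FROM t
    val resultExprs = Seq(Agg("max", "a"), Agg("max", "a"), Agg("min", "b"))

    val aggExprToOutputOrdinal = mutable.HashMap.empty[Agg, Int]
    var ordinal = 0
    val pushed = resultExprs.collect {
      case agg if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
        aggExprToOutputOrdinal(agg.canonicalized) = ordinal
        ordinal += 1
        agg
    }
    // pushed contains one max(a) and one min(b); both max(a) occurrences
    // later look up the same ordinal (0) when rewriting the plan.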
