
Commit 44a5583

simplification
1 parent 5523402 commit 44a5583

File tree

4 files changed, +10 -14 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.connector.expressions;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.types.DataType;
 
 /**
  * An aggregate function that returns the summation of all the values in a group.

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 6 additions & 1 deletion
@@ -54,9 +54,14 @@ object JDBCRDD extends Logging {
     val url = options.url
     val table = options.tableOrQuery
     val dialect = JdbcDialects.get(url)
+    getQueryOutputSchema(dialect.getSchemaQuery(table), options, dialect)
+  }
+
+  def getQueryOutputSchema(
+      query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = {
     val conn: Connection = JdbcUtils.createConnectionFactory(options)()
     try {
-      val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
+      val statement = conn.prepareStatement(query)
       try {
         statement.setQueryTimeout(options.queryTimeout)
         val rs = statement.executeQuery()
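
With the schema-resolution logic factored out of resolveTable, the output schema of an arbitrary query can now be resolved directly. A minimal usage sketch, assuming an existing JDBCOptions instance named options; the query string and names below are made up for illustration, and since these are internal Spark classes the snippet only compiles inside Spark's own sql module:

    // Hypothetical caller resolving the schema of an ad-hoc query.
    import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD}
    import org.apache.spark.sql.jdbc.JdbcDialects
    import org.apache.spark.sql.types.StructType

    val dialect = JdbcDialects.get(options.url)
    val schema: StructType = JDBCRDD.getQueryOutputSchema(
      "SELECT MAX(SALARY), DEPT FROM employees GROUP BY DEPT",  // assumed example query
      options,
      dialect)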

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     expr.collect {
       // Do not push down duplicated aggregate expressions. For example,
       // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
-      // `sum(a)` to the data source.
+      // `max(a)` to the data source.
       case agg: AggregateExpression
           if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
         aggExprToOutputOrdinal(agg.canonicalized) = ordinal
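
The comment fix aside, the deduplication here works by keying a map on the canonicalized form of each aggregate expression, so repeated occurrences of the same aggregate share one output ordinal. A self-contained sketch of that idea (not the Catalyst code; Agg and its canonicalized method are stand-ins):

    import scala.collection.mutable

    // Stand-in for a Catalyst AggregateExpression with a canonical form.
    case class Agg(func: String, column: String) {
      def canonicalized: Agg = Agg(func.toLowerCase, column.toLowerCase)
    }

    val aggExprToOutputOrdinal = mutable.HashMap.empty[Agg, Int]
    val aggregates = Seq(Agg("max", "a"), Agg("MAX", "a"))  // e.g. max(a) + 1, max(a) + 2
    aggregates.zipWithIndex.foreach { case (agg, ordinal) =>
      if (!aggExprToOutputOrdinal.contains(agg.canonicalized)) {
        aggExprToOutputOrdinal(agg.canonicalized) = ordinal
      }
    }
    // The map ends up with a single entry, so only one max(a) is pushed to the source.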

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala

Lines changed: 3 additions & 11 deletions
@@ -84,18 +84,10 @@ case class JDBCScanBuilder(
       "GROUP BY " + groupByCols.mkString(",")
     }
 
-    val aggQuery = s"SELECT ${selectList.mkString(",")} FROM " +
-      s"${jdbcOptions.tableOrQuery} $groupByClause"
-    val jdbcOptionsWithAggQuery = new JDBCOptions(
-      jdbcOptions.parameters
-      - JDBCOptions.JDBC_TABLE_NAME
-      - JDBCOptions.JDBC_PARTITION_COLUMN
-      - JDBCOptions.JDBC_NUM_PARTITIONS
-      - JDBCOptions.JDBC_LOWER_BOUND
-      - JDBCOptions.JDBC_UPPER_BOUND +
-      (JDBCOptions.JDBC_QUERY_STRING -> aggQuery))
+    val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " +
+      s"WHERE 1=0 $groupByClause"
     try {
-      finalSchema = JDBCRDD.resolveTable(jdbcOptionsWithAggQuery)
+      finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect)
       pushedAggregateList = selectList
       pushedGroupByCols = Some(groupByCols)
       true
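
The simplification drops the cloned JDBCOptions entirely: instead of routing the aggregate query back through resolveTable via JDBC_QUERY_STRING, the query is handed straight to the new getQueryOutputSchema, and the added WHERE 1=0 predicate ensures the statement returns no rows and is executed only for its result-set metadata. A sketch of the string this builds, with hypothetical column and table names standing in for the builder's locals:

    // Hypothetical inputs mirroring JDBCScanBuilder's selectList / groupByCols.
    val selectList = Seq("MAX(SALARY)", "MIN(BONUS)", "DEPT")
    val groupByCols = Seq("DEPT")
    val groupByClause = "GROUP BY " + groupByCols.mkString(",")
    val tableOrQuery = "employees"  // stands in for jdbcOptions.tableOrQuery

    val aggQuery = s"SELECT ${selectList.mkString(",")} FROM $tableOrQuery " +
      s"WHERE 1=0 $groupByClause"
    // => SELECT MAX(SALARY),MIN(BONUS),DEPT FROM employees WHERE 1=0 GROUP BY DEPT
    // Passing this to JDBCRDD.getQueryOutputSchema yields the StructType stored in finalSchema.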
