Skip to content

Commit 724fb08

Browse files
chenzhx and cloud-fan
authored and committed
[SPARK-38897][SQL] DS V2 supports push down string functions
### What changes were proposed in this pull request? Currently, Spark has some string functions of the ANSI standard. Please refer to https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L503 These functions are shown below: `SUBSTRING`, `UPPER`, `LOWER`, `TRANSLATE`, `TRIM`, `OVERLAY` The mainstream databases that support these functions are shown below. Function | PostgreSQL | ClickHouse | H2 | MySQL | Oracle | Redshift | Presto | Teradata | Snowflake | DB2 | Vertica | Exasol | SqlServer | Yellowbrick | Impala | Mariadb | Druid | Pig | SQLite | Influxdata | Singlestore | ElasticSearch -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- `SUBSTRING` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes `UPPER` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes `LOWER` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes `TRIM` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes `TRANSLATE` | Yes | No | Yes | No | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | No | No | No | No | No | No `OVERLAY` | Yes | No | No | No | Yes | No | No | No | No | Yes | Yes | No | No | No | No | No | No | No | No | No | No | No DS V2 should support pushing down these string functions. ### Why are the changes needed? DS V2 supports pushing down string functions. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New tests. Closes #36330 from chenzhx/spark-master. 
Authored-by: chenzhx <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent a0decfc commit 724fb08

File tree

6 files changed

+211
-2
lines changed

6 files changed

+211
-2
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,54 @@
148148
* <li>Since version: 3.3.0</li>
149149
* </ul>
150150
* </li>
151+
* <li>Name: <code>SUBSTRING</code>
152+
* <ul>
153+
* <li>SQL semantic: <code>SUBSTRING(str, pos[, len])</code></li>
154+
* <li>Since version: 3.4.0</li>
155+
* </ul>
156+
* </li>
157+
* <li>Name: <code>UPPER</code>
158+
* <ul>
159+
* <li>SQL semantic: <code>UPPER(expr)</code></li>
160+
* <li>Since version: 3.4.0</li>
161+
* </ul>
162+
* </li>
163+
* <li>Name: <code>LOWER</code>
164+
* <ul>
165+
* <li>SQL semantic: <code>LOWER(expr)</code></li>
166+
* <li>Since version: 3.4.0</li>
167+
* </ul>
168+
* </li>
169+
* <li>Name: <code>TRANSLATE</code>
170+
* <ul>
171+
* <li>SQL semantic: <code>TRANSLATE(input, from, to)</code></li>
172+
* <li>Since version: 3.4.0</li>
173+
* </ul>
174+
* </li>
175+
* <li>Name: <code>TRIM</code>
176+
* <ul>
177+
* <li>SQL semantic: <code>TRIM(src, trim)</code></li>
178+
* <li>Since version: 3.4.0</li>
179+
* </ul>
180+
* </li>
181+
* <li>Name: <code>LTRIM</code>
182+
* <ul>
183+
* <li>SQL semantic: <code>LTRIM(src, trim)</code></li>
184+
* <li>Since version: 3.4.0</li>
185+
* </ul>
186+
* </li>
187+
* <li>Name: <code>RTRIM</code>
188+
* <ul>
189+
* <li>SQL semantic: <code>RTRIM(src, trim)</code></li>
190+
* <li>Since version: 3.4.0</li>
191+
* </ul>
192+
* </li>
193+
* <li>Name: <code>OVERLAY</code>
194+
* <ul>
195+
* <li>SQL semantic: <code>OVERLAY(string, replace, position[, length])</code></li>
196+
* <li>Since version: 3.4.0</li>
197+
* </ul>
198+
* </li>
151199
* </ol>
152200
* Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off,
153201
* including: add, subtract, multiply, divide, remainder, pmod.

sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,29 @@ public String build(Expression expr) {
102102
case "FLOOR":
103103
case "CEIL":
104104
case "WIDTH_BUCKET":
105+
case "SUBSTRING":
106+
case "UPPER":
107+
case "LOWER":
108+
case "TRANSLATE":
105109
return visitSQLFunction(name,
106110
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
107111
case "CASE_WHEN": {
108112
List<String> children =
109113
Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
110114
return visitCaseWhen(children.toArray(new String[e.children().length]));
111115
}
116+
case "TRIM":
117+
return visitTrim("BOTH",
118+
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
119+
case "LTRIM":
120+
return visitTrim("LEADING",
121+
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
122+
case "RTRIM":
123+
return visitTrim("TRAILING",
124+
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
125+
case "OVERLAY":
126+
return visitOverlay(
127+
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
112128
// TODO supports other expressions
113129
default:
114130
return visitUnexpectedExpr(expr);
@@ -228,4 +244,23 @@ protected String visitSQLFunction(String funcName, String[] inputs) {
228244
protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
229245
throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
230246
}
247+
248+
protected String visitOverlay(String[] inputs) {
249+
assert(inputs.length == 3 || inputs.length == 4);
250+
if (inputs.length == 3) {
251+
return "OVERLAY(" + inputs[0] + " PLACING " + inputs[1] + " FROM " + inputs[2] + ")";
252+
} else {
253+
return "OVERLAY(" + inputs[0] + " PLACING " + inputs[1] + " FROM " + inputs[2] +
254+
" FOR " + inputs[3]+ ")";
255+
}
256+
}
257+
258+
/**
 * Builds the SQL text for a {@code TRIM} expression in the ANSI form
 * {@code TRIM(direction [trimStr] FROM src)}.
 *
 * @param direction one of {@code "BOTH"}, {@code "LEADING"} or {@code "TRAILING"},
 *                  chosen by the caller for TRIM, LTRIM and RTRIM respectively
 * @param inputs the already-built SQL text of the children: the source string, and
 *               optionally the characters to trim (1 or 2 elements)
 * @return the SQL text of the TRIM expression
 */
protected String visitTrim(String direction, String[] inputs) {
  assert(inputs.length == 1 || inputs.length == 2);
  if (inputs.length == 1) {
    // No explicit trim string: trim whitespace per ANSI default.
    return "TRIM(" + direction + " FROM " + inputs[0] + ")";
  } else {
    // Note the argument order: inputs[1] is the trim string, inputs[0] the source.
    return "TRIM(" + direction + " " + inputs[1] + " FROM " + inputs[0] + ")";
  }
}
231266
}

sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.util
1919

20-
import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket}
20+
import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Lower, Multiply, Not, Or, Overlay, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, StringTranslate, StringTrim, StringTrimLeft, StringTrimRight, Substring, Subtract, UnaryMinus, Upper, WidthBucket}
2121
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
2222
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
2323
import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -200,6 +200,65 @@ class V2ExpressionBuilder(
200200
} else {
201201
None
202202
}
203+
case substring: Substring =>
204+
val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
205+
Seq(substring.str, substring.pos)
206+
} else {
207+
substring.children
208+
}
209+
val childrenExpressions = children.flatMap(generateExpression(_))
210+
if (childrenExpressions.length == children.length) {
211+
Some(new GeneralScalarExpression("SUBSTRING",
212+
childrenExpressions.toArray[V2Expression]))
213+
} else {
214+
None
215+
}
216+
case Upper(child) => generateExpression(child)
217+
.map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v)))
218+
case Lower(child) => generateExpression(child)
219+
.map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v)))
220+
case translate: StringTranslate =>
221+
val childrenExpressions = translate.children.flatMap(generateExpression(_))
222+
if (childrenExpressions.length == translate.children.length) {
223+
Some(new GeneralScalarExpression("TRANSLATE",
224+
childrenExpressions.toArray[V2Expression]))
225+
} else {
226+
None
227+
}
228+
case trim: StringTrim =>
229+
val childrenExpressions = trim.children.flatMap(generateExpression(_))
230+
if (childrenExpressions.length == trim.children.length) {
231+
Some(new GeneralScalarExpression("TRIM", childrenExpressions.toArray[V2Expression]))
232+
} else {
233+
None
234+
}
235+
case trim: StringTrimLeft =>
236+
val childrenExpressions = trim.children.flatMap(generateExpression(_))
237+
if (childrenExpressions.length == trim.children.length) {
238+
Some(new GeneralScalarExpression("LTRIM", childrenExpressions.toArray[V2Expression]))
239+
} else {
240+
None
241+
}
242+
case trim: StringTrimRight =>
243+
val childrenExpressions = trim.children.flatMap(generateExpression(_))
244+
if (childrenExpressions.length == trim.children.length) {
245+
Some(new GeneralScalarExpression("RTRIM", childrenExpressions.toArray[V2Expression]))
246+
} else {
247+
None
248+
}
249+
case overlay: Overlay =>
250+
val children = if (overlay.len == Literal(-1)) {
251+
Seq(overlay.input, overlay.replace, overlay.pos)
252+
} else {
253+
overlay.children
254+
}
255+
val childrenExpressions = children.flatMap(generateExpression(_))
256+
if (childrenExpressions.length == children.length) {
257+
Some(new GeneralScalarExpression("OVERLAY",
258+
childrenExpressions.toArray[V2Expression]))
259+
} else {
260+
None
261+
}
203262
// TODO supports other expressions
204263
case _ => None
205264
}

sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ private object H2Dialect extends JdbcDialect {
3131
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
3232

3333
// Scalar function names H2 can evaluate, so expressions the
// V2ExpressionSQLBuilder produces for them may be pushed down.
// OVERLAY is deliberately absent: the test suite in this change relies on
// H2 not supporting it. (The scrape showed both the pre- and post-change
// variants of this Set; this is the post-change state.)
private val supportedFunctions =
  Set("ABS", "COALESCE", "LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL",
    "SUBSTRING", "UPPER", "LOWER", "TRANSLATE", "TRIM")

override def isSupportedFunction(funcName: String): Boolean =
  supportedFunctions.contains(funcName)

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,24 @@ abstract class JdbcDialect extends Serializable with Logging{
251251
s"${this.getClass.getSimpleName} does not support function: $funcName")
252252
}
253253
}
254+
255+
/**
 * Generates OVERLAY SQL text, rejecting dialects that do not declare
 * support for it via `isSupportedFunction`.
 */
override def visitOverlay(inputs: Array[String]): String = {
  if (!isSupportedFunction("OVERLAY")) {
    throw new UnsupportedOperationException(
      s"${this.getClass.getSimpleName} does not support function: OVERLAY")
  }
  super.visitOverlay(inputs)
}
263+
264+
/**
 * Generates TRIM SQL text for the given direction (BOTH/LEADING/TRAILING),
 * rejecting dialects that do not declare support for TRIM.
 */
override def visitTrim(direction: String, inputs: Array[String]): String = {
  if (!isSupportedFunction("TRIM")) {
    throw new UnsupportedOperationException(
      s"${this.getClass.getSimpleName} does not support function: TRIM")
  }
  super.visitTrim(direction, inputs)
}
254272
}
255273

256274
/**

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,54 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
675675
}
676676
}
677677

678+
// Verifies that string-function predicates (SUBSTRING, UPPER, LOWER, TRIM,
// LTRIM, RTRIM, TRANSLATE) are pushed down to the H2 data source as V2
// filters, and that OVERLAY — unsupported by H2 — is NOT pushed down. Each
// case checks both the pushed plan fragment and the query result.
test("scan with filter push-down with string functions") {
  // SUBSTRING / UPPER / LOWER: all three predicates should be pushed.
  val df1 = sql("select * FROM h2.test.employee where " +
    "substr(name, 2, 1) = 'e'" +
    " AND upper(name) = 'JEN' AND lower(name) = 'jen' ")
  checkFiltersRemoved(df1)
  val expectedPlanFragment1 =
    "PushedFilters: [NAME IS NOT NULL, (SUBSTRING(NAME, 2, 1)) = 'e', " +
    "UPPER(NAME) = 'JEN', LOWER(NAME) = 'jen']"
  checkPushedInfo(df1, expectedPlanFragment1)
  checkAnswer(df1, Seq(Row(6, "jen", 12000, 1200, true)))

  // TRIM with and without an explicit trim string, plus TRANSLATE. Both trim
  // forms render as TRIM(BOTH ... FROM ...); the expected fragment is
  // truncated by the plan-string limit after "(TRANSLATE(NA".
  val df2 = sql("select * FROM h2.test.employee where " +
    "trim(name) = 'jen' AND trim('j', name) = 'en'" +
    "AND translate(name, 'e', 1) = 'j1n'")
  checkFiltersRemoved(df2)
  val expectedPlanFragment2 =
    "PushedFilters: [NAME IS NOT NULL, TRIM(BOTH FROM NAME) = 'jen', " +
    "(TRIM(BOTH 'j' FROM NAME)) = 'en', (TRANSLATE(NA..."
  checkPushedInfo(df2, expectedPlanFragment2)
  checkAnswer(df2, Seq(Row(6, "jen", 12000, 1200, true)))

  // LTRIM renders as TRIM(LEADING ...).
  val df3 = sql("select * FROM h2.test.employee where " +
    "ltrim(name) = 'jen' AND ltrim('j', name) = 'en'")
  checkFiltersRemoved(df3)
  val expectedPlanFragment3 =
    "PushedFilters: [TRIM(LEADING FROM NAME) = 'jen', " +
    "(TRIM(LEADING 'j' FROM NAME)) = 'en']"
  checkPushedInfo(df3, expectedPlanFragment3)
  checkAnswer(df3, Seq(Row(6, "jen", 12000, 1200, true)))

  // RTRIM renders as TRIM(TRAILING ...).
  val df4 = sql("select * FROM h2.test.employee where " +
    "rtrim(name) = 'jen' AND rtrim('n', name) = 'je'")
  checkFiltersRemoved(df4)
  val expectedPlanFragment4 =
    "PushedFilters: [TRIM(TRAILING FROM NAME) = 'jen', " +
    "(TRIM(TRAILING 'n' FROM NAME)) = 'je']"
  checkPushedInfo(df4, expectedPlanFragment4)
  checkAnswer(df4, Seq(Row(6, "jen", 12000, 1200, true)))

  // H2 does not support OVERLAY: only the implicit IS NOT NULL filter is
  // pushed; the OVERLAY predicate is evaluated by Spark, so filters are
  // NOT fully removed (hence the `false` argument).
  val df5 = sql("select * FROM h2.test.employee where OVERLAY(NAME, '1', 2, 1) = 'j1n'")
  checkFiltersRemoved(df5, false)
  val expectedPlanFragment5 =
    "PushedFilters: [NAME IS NOT NULL]"
  checkPushedInfo(df5, expectedPlanFragment5)
  checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true)))
}
725+
678726
test("scan with aggregate push-down: MAX AVG with filter and group by") {
679727
val df = sql("select MAX(SaLaRY), AVG(BONUS) FROM h2.test.employee where dept > 0" +
680728
" group by DePt")

0 commit comments

Comments
 (0)