From d5e4905502e8797dd1f5f6faa586abe40c3aa7c4 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 23 Mar 2022 14:11:27 +0800 Subject: [PATCH 1/5] [SPARK-38633][SQL] Support push down Cast to JDBC data source V2 --- .../spark/sql/connector/expressions/Cast.java | 39 +++++++++++++++++++ .../expressions/GeneralScalarExpression.java | 3 +- .../util/V2ExpressionSQLBuilder.java | 10 +++++ .../catalyst/util/V2ExpressionBuilder.scala | 6 ++- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 21 +++++++++- 5 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java new file mode 100644 index 0000000000000..d3bc140009933 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.types.DataType; + +/** + * Represents a cast expression in the public logical expression API. + * + * @since 3.3.0 + */ +@Evolving +public class Cast extends GeneralScalarExpression { + private DataType dataType; + + public Cast(Expression expression, DataType dataType) { + super("CAST", new Expression[]{ expression }); + this.dataType = dataType; + } + + public Expression expression() { return children()[0]; } + public DataType dataType() { return dataType; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index 8952761f9ef34..3c06f768fa5f4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -22,13 +22,14 @@ import java.util.Objects; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Cast; import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder; /** * The general representation of SQL scalar expressions, which contains the upper-cased * expression name and all the children expressions. Please also see {@link Predicate} - * for the supported predicate expressions. + * for the supported predicate expressions and {@link Cast} for the cast expression. *

* The currently supported SQL scalar expressions: *

    diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 1df01d29cbdd1..af621e847a04b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -21,10 +21,12 @@ import java.util.List; import java.util.stream.Collectors; +import org.apache.spark.sql.connector.expressions.Cast; import org.apache.spark.sql.connector.expressions.Expression; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.expressions.GeneralScalarExpression; import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.types.DataType; /** * The builder to generate SQL from V2 expressions. @@ -80,6 +82,10 @@ public String build(Expression expr) { return visitBinaryArithmetic( name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); } + case "CAST": + assert e instanceof Cast; + Cast cast = (Cast) e; + return visitCast(name, build(cast.expression()), cast.dataType()); case "AND": return visitAnd(name, build(e.children()[0]), build(e.children()[1])); case "OR": @@ -167,6 +173,10 @@ protected String visitBinaryArithmetic(String name, String l, String r) { return l + " " + name + " " + r; } + protected String visitCast(String name, String l, DataType dataType) { + return name + "( " + l + " AS " + dataType.typeName() + " )"; + } + protected String visitAnd(String name, String l, String r) { return "(" + l + ") " + name + " (" + r + ")"; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index fbd6884358b0a..5fd01ac5636b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus} -import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} +import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus} +import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableColumn import org.apache.spark.sql.types.BooleanType @@ -93,6 +93,8 @@ class V2ExpressionBuilder( } else { None } + case Cast(child, dataType, _, true) => + generateExpression(child).map(v => new V2Cast(v, dataType)) case and: And => // AND expects predicate val l = generateExpression(and.left, true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index afbdc604b8a18..d9d50e031c868 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -352,7 +352,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) } - test("scan with complex filter push-down") { + test("scan with filter push-down with ansi mode") { Seq(false, true).foreach { ansiMode => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val df = spark.table("h2.test.people").filter($"id" + 1 > 1) @@ -404,6 +404,25 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkPushedInfo(df3, expectedPlanFragment3) checkAnswer(df3, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true))) + + val df4 = spark.table("h2.test.employee") + .filter(($"salary" > 1000d).and($"salary" < 12000d)) + + checkFiltersRemoved(df4, ansiMode) + + df4.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = if (ansiMode) { + "PushedFilters: [SALARY IS NOT NULL, " + + "CAST( SALARY AS double ) > 1000.0, CAST( SALARY AS double ) < 12000.0], " + } else { + "PushedFilters: [SALARY IS NOT NULL], " + } + checkKeywordsExistsInExplain(df4, expected_plan_fragment) + } + + checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true), + Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true))) } } } From ef8b9a413193781f1f0a2ffb777aed53826ca0f1 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 23 Mar 2022 15:45:23 +0800 Subject: [PATCH 2/5] Update code --- .../spark/sql/connector/expressions/GeneralScalarExpression.java | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index 3c06f768fa5f4..87032e2b32e9f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -22,7 +22,6 @@ import java.util.Objects; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Cast; import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder; From 6e50730ac69dce1ae8d942e538dd2fb7b684f415 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 23 Mar 2022 18:39:42 +0800 Subject: [PATCH 3/5] Update code --- .../java/org/apache/spark/sql/connector/expressions/Cast.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java index d3bc140009933..058960b12b35d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java @@ -23,7 +23,7 @@ /** * Represents a cast expression in the public logical expression API. * - * @since 3.3.0 + * @since 3.4.0 */ @Evolving public class Cast extends GeneralScalarExpression { From 9e58bb40432398a4b6c2aec2b40c8903162412fb Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 28 Mar 2022 14:17:26 +0800 Subject: [PATCH 4/5] Update code --- .../spark/sql/connector/expressions/Cast.java | 14 ++++++++++---- .../sql/connector/util/V2ExpressionSQLBuilder.java | 11 +++++------ .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 2 +- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java index 058960b12b35d..26b97b46fe2ef 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java @@ -17,23 +17,29 @@ package org.apache.spark.sql.connector.expressions; +import java.io.Serializable; + import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.DataType; /** * Represents a cast expression in the public logical expression API. * - * @since 3.4.0 + * @since 3.3.0 */ @Evolving -public class Cast extends GeneralScalarExpression { +public class Cast implements Expression, Serializable { + private Expression expression; private DataType dataType; public Cast(Expression expression, DataType dataType) { - super("CAST", new Expression[]{ expression }); + this.expression = expression; this.dataType = dataType; } - public Expression expression() { return children()[0]; } + public Expression expression() { return expression; } public DataType dataType() { return dataType; } + + @Override + public Expression[] children() { return new Expression[]{ expression() }; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index af621e847a04b..c8d924db75aed 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -38,6 +38,9 @@ public String build(Expression expr) { return visitLiteral((Literal) expr); } else if (expr instanceof NamedReference) { return visitNamedReference((NamedReference) expr); + } else if (expr instanceof Cast) { + Cast cast = (Cast) expr; + return visitCast(build(cast.expression()), cast.dataType()); } else if (expr instanceof GeneralScalarExpression) { GeneralScalarExpression e = (GeneralScalarExpression) expr; String name = e.name(); @@ -82,10 +85,6 @@ public String build(Expression expr) { return visitBinaryArithmetic( name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); } - case "CAST": - assert e instanceof Cast; - Cast cast = (Cast) e; - return visitCast(name, build(cast.expression()), cast.dataType()); case "AND": return visitAnd(name, build(e.children()[0]), build(e.children()[1])); case "OR": @@ -173,8 +172,8 @@ protected String visitBinaryArithmetic(String name, String l, String r) { return l + " " + name + " " + r; } - protected String visitCast(String name, String l, DataType dataType) { - return name + "( " + l + " AS " + dataType.typeName() + " )"; + protected String visitCast(String l, DataType dataType) { + return "CAST(" + l + " AS " + dataType.typeName() + ")"; } protected String visitAnd(String name, String l, String r) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index d9d50e031c868..0dd9503473c6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -414,7 +414,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = if (ansiMode) { "PushedFilters: [SALARY IS NOT NULL, " + - "CAST( SALARY AS double ) > 1000.0, CAST( SALARY AS double ) < 12000.0], " + "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], " } else { "PushedFilters: [SALARY IS NOT NULL], " } From 6985249f2f49cb114061ebfb245e2e362863059f Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 28 Mar 2022 16:55:11 +0800 Subject: [PATCH 5/5] Update code --- .../sql/connector/expressions/GeneralScalarExpression.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index 87032e2b32e9f..8952761f9ef34 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -28,7 +28,7 @@ /** * The general representation of SQL scalar expressions, which contains the upper-cased * expression name and all the children expressions. Please also see {@link Predicate} - * for the supported predicate expressions and {@link Cast} for the cast expression. + * for the supported predicate expressions. *

    * The currently supported SQL scalar expressions: *