diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index 8952761f9ef3..58082d5ee09c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -94,6 +94,60 @@ *
  • Since version: 3.3.0
  • * * + *
  • Name: ABS + * + *
  • + *
  • Name: COALESCE + * + *
  • + *
  • Name: LN + * + *
  • + *
  • Name: EXP + * + *
  • + *
  • Name: POWER + * + *
  • + *
  • Name: SQRT + * + *
  • + *
  • Name: FLOOR + * + *
  • + *
  • Name: CEIL + * + *
  • + *
  • Name: WIDTH_BUCKET + * + *
  • * * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off, * including: add, subtract, multiply, divide, remainder, pmod. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index a7d1ed7f85e8..c9dfa2003e3c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -95,6 +95,13 @@ public String build(Expression expr) { return visitUnaryArithmetic(name, inputToSQL(e.children()[0])); case "ABS": case "COALESCE": + case "LN": + case "EXP": + case "POWER": + case "SQRT": + case "FLOOR": + case "CEIL": + case "WIDTH_BUCKET": return visitSQLFunction(name, Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); case "CASE_WHEN": { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index b6b64804904e..f743844ebd81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2380,4 +2380,8 @@ object QueryCompilationErrors { new AnalysisException( "Sinks cannot request distribution and ordering in continuous execution mode") } + + def noSuchFunctionError(database: String, funcInfo: String): Throwable = { + new AnalysisException(s"$database does not support function: $funcInfo") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 37db499470aa..487b809d48a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket} import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableColumn @@ -104,6 +104,32 @@ class V2ExpressionBuilder( } else { None } + case Log(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v))) + case Exp(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v))) + case Pow(left, right) => + val l = generateExpression(left) + val r = generateExpression(right) + if (l.isDefined && r.isDefined) { + Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get))) + } else { + None + } + case Sqrt(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v))) + case Floor(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v))) + case Ceil(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v))) + case wb: WidthBucket => + val childrenExpressions = wb.children.flatMap(generateExpression(_)) + if (childrenExpressions.length == wb.children.length) { + Some(new GeneralScalarExpression("WIDTH_BUCKET", + childrenExpressions.toArray[V2Expression])) + } else { + None + } case and: And => // AND expects predicate val l = generateExpression(and.left, true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 643376cdb126..0aa971c0d3ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -20,14 +20,40 @@ package org.apache.spark.sql.jdbc import java.sql.SQLException import java.util.Locale +import scala.util.control.NonFatal + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.errors.QueryCompilationErrors private object H2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2") + class H2SQLBuilder extends JDBCSQLBuilder { + override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { + funcName match { + case "WIDTH_BUCKET" => + val functionInfo = super.visitSQLFunction(funcName, inputs) + throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo) + case _ => super.visitSQLFunction(funcName, inputs) + } + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val h2SQLBuilder = new H2SQLBuilder() + try { + Some(h2SQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } + } + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index e60af877e9c5..5cfa2f465a2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Sort} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when} +import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count, count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -464,6 +464,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkPushedInfo(df5, expectedPlanFragment5) checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true), Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true))) + + val df6 = spark.table("h2.test.employee") + .filter(ln($"dept") > 1) + .filter(exp($"salary") > 2000) + .filter(pow($"dept", 2) > 4) + .filter(sqrt($"salary") > 100) + .filter(floor($"dept") > 1) + .filter(ceil($"dept") > 1) + checkFiltersRemoved(df6, ansiMode) + val expectedPlanFragment6 = if (ansiMode) { + "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL, " + + "LN(CAST(DEPT AS double)) > 1.0, EXP(CAST(SALARY AS double)...," + } else { + "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL]" + } + checkPushedInfo(df6, expectedPlanFragment6) + checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true))) + + // H2 does not support width_bucket + val df7 = sql(""" + |SELECT * FROM h2.test.employee + |WHERE width_bucket(dept, 1, 6, 3) > 1 + |""".stripMargin) + checkFiltersRemoved(df7, false) + checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]") + checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true))) } } }