diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index 8952761f9ef3..58082d5ee09c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -94,6 +94,60 @@
*
 * - Since version: 3.3.0
*
*
+ * Name: ABS
+ *
+ * - SQL semantic: ABS(expr)
+ * - Since version: 3.3.0
+ *
+ *
+ * Name: COALESCE
+ *
+ * - SQL semantic: COALESCE(expr1, expr2)
+ * - Since version: 3.3.0
+ *
+ *
+ * Name: LN
+ *
+ * - SQL semantic: LN(expr)
+ * - Since version: 3.3.0
+ *
+ *
+ * Name: EXP
+ *
+ * - SQL semantic: EXP(expr)
+ * - Since version: 3.3.0
+ *
+ *
+ * Name: POWER
+ *
+ * - SQL semantic: POWER(expr, number)
+ * - Since version: 3.3.0
+ *
+ *
+ * Name: SQRT
+ *
+ * - SQL semantic: SQRT(expr)
+ * - Since version: 3.3.0
+ *
+ *
+ * Name: FLOOR
+ *
+ * - SQL semantic: FLOOR(expr)
+ * - Since version: 3.3.0
+ *
+ *
+ * Name: CEIL
+ *
+ * - SQL semantic: CEIL(expr)
+ * - Since version: 3.3.0
+ *
+ *
+ * Name: WIDTH_BUCKET
+ *
+ * - SQL semantic: WIDTH_BUCKET(expr)
+ * - Since version: 3.3.0
+ *
+ *
*
 * Note: SQL semantics conform to the ANSI standard, so some expressions are not supported when ANSI mode is off,
* including: add, subtract, multiply, divide, remainder, pmod.
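
For a sense of the shape connectors receive for these entries, here is a minimal sketch of POWER(dept, 2) as a v2 expression. It assumes code living under org.apache.spark.sql, where FieldReference is visible; the column name is illustrative only.

    import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression, LiteralValue}
    import org.apache.spark.sql.types.DoubleType

    // Sketch: POWER(dept, 2) reaches a connector as a named expression whose arguments are the children.
    val dept: Expression = FieldReference("dept")
    val two: Expression = LiteralValue(2.0d, DoubleType)
    val power = new GeneralScalarExpression("POWER", Array(dept, two))
    assert(power.name() == "POWER" && power.children().length == 2)
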
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index a7d1ed7f85e8..c9dfa2003e3c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -95,6 +95,13 @@ public String build(Expression expr) {
return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
case "ABS":
case "COALESCE":
+ case "LN":
+ case "EXP":
+ case "POWER":
+ case "SQRT":
+ case "FLOOR":
+ case "CEIL":
+ case "WIDTH_BUCKET":
return visitSQLFunction(name,
Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
case "CASE_WHEN": {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index b6b64804904e..f743844ebd81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2380,4 +2380,8 @@ object QueryCompilationErrors {
new AnalysisException(
"Sinks cannot request distribution and ordering in continuous execution mode")
}
+
+ def noSuchFunctionError(database: String, funcInfo: String): Throwable = {
+ new AnalysisException(s"$database does not support function: $funcInfo")
+ }
}
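
The new helper is consumed by the H2 dialect below; its only job is to raise an AnalysisException whose message names the dialect and the rendered call. For example (QueryCompilationErrors is an internal API, and the argument string here is illustrative):

    import org.apache.spark.sql.errors.QueryCompilationErrors

    // Illustrative: the error a dialect raises for a function it cannot compile.
    throw QueryCompilationErrors.noSuchFunctionError("H2", "WIDTH_BUCKET(DEPT, 1, 6, 3)")
    // => org.apache.spark.sql.AnalysisException: H2 does not support function: WIDTH_BUCKET(DEPT, 1, 6, 3)
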
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 37db499470aa..487b809d48a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.util
-import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket}
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -104,6 +104,32 @@ class V2ExpressionBuilder(
} else {
None
}
+ case Log(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v)))
+ case Exp(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v)))
+ case Pow(left, right) =>
+ val l = generateExpression(left)
+ val r = generateExpression(right)
+ if (l.isDefined && r.isDefined) {
+ Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get)))
+ } else {
+ None
+ }
+ case Sqrt(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v)))
+ case Floor(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v)))
+ case Ceil(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v)))
+ case wb: WidthBucket =>
+ val childrenExpressions = wb.children.flatMap(generateExpression(_))
+ if (childrenExpressions.length == wb.children.length) {
+ Some(new GeneralScalarExpression("WIDTH_BUCKET",
+ childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
case and: And =>
// AND expects predicate
val l = generateExpression(and.left, true)
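
Note the pattern shared by these branches: unary expressions map their single translated child, while POWER and WIDTH_BUCKET only push down when every child translates; otherwise the builder returns None and the expression stays on the Spark side. A standalone sketch of that all-or-nothing rule (a hypothetical helper, not part of this patch):

    // Hypothetical helper mirroring the rule used for WIDTH_BUCKET above:
    // translate every child, or refuse to translate the parent at all.
    def translateAll[A, B](children: Seq[A])(translate: A => Option[B]): Option[Seq[B]] = {
      val translated = children.flatMap(translate(_))
      if (translated.length == children.length) Some(translated) else None
    }
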
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 643376cdb126..0aa971c0d3ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -20,14 +20,40 @@ package org.apache.spark.sql.jdbc
import java.sql.SQLException
import java.util.Locale
+import scala.util.control.NonFatal
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.expressions.Expression
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
+import org.apache.spark.sql.errors.QueryCompilationErrors
private object H2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
+ class H2SQLBuilder extends JDBCSQLBuilder {
+ override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
+ funcName match {
+ case "WIDTH_BUCKET" =>
+ val functionInfo = super.visitSQLFunction(funcName, inputs)
+ throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo)
+ case _ => super.visitSQLFunction(funcName, inputs)
+ }
+ }
+ }
+
+ override def compileExpression(expr: Expression): Option[String] = {
+ val h2SQLBuilder = new H2SQLBuilder()
+ try {
+ Some(h2SQLBuilder.build(expr))
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Error occurs while compiling V2 expression", e)
+ None
+ }
+ }
+
override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
aggFunction match {
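
Taken together, the override and compileExpression mean a supported function compiles to SQL text, while a WIDTH_BUCKET expression makes build throw; the exception is caught here, None is returned, and that filter is simply evaluated by Spark. A rough sketch of the behaviour, assuming code in the org.apache.spark.sql.jdbc package (H2Dialect and FieldReference are not public API; names and quoting are illustrative):

    import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression}

    val sqrt = new GeneralScalarExpression("SQRT", Array[V2Expression](FieldReference("SALARY")))
    H2Dialect.compileExpression(sqrt)   // Some("SQRT(...)") once H2SQLBuilder renders it
    val wb = new GeneralScalarExpression("WIDTH_BUCKET", Array[V2Expression](FieldReference("DEPT")))
    H2Dialect.compileExpression(wb)     // None: H2SQLBuilder throws, compileExpression logs and gives up
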
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index e60af877e9c5..5cfa2f465a2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Sort}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when}
+import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count, count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
@@ -464,6 +464,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkPushedInfo(df5, expectedPlanFragment5)
checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true),
Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true)))
+
+ val df6 = spark.table("h2.test.employee")
+ .filter(ln($"dept") > 1)
+ .filter(exp($"salary") > 2000)
+ .filter(pow($"dept", 2) > 4)
+ .filter(sqrt($"salary") > 100)
+ .filter(floor($"dept") > 1)
+ .filter(ceil($"dept") > 1)
+ checkFiltersRemoved(df6, ansiMode)
+ val expectedPlanFragment6 = if (ansiMode) {
+ "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL, " +
+ "LN(CAST(DEPT AS double)) > 1.0, EXP(CAST(SALARY AS double)...,"
+ } else {
+ "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL]"
+ }
+ checkPushedInfo(df6, expectedPlanFragment6)
+ checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true)))
+
+ // H2 does not support width_bucket
+ val df7 = sql("""
+ |SELECT * FROM h2.test.employee
+ |WHERE width_bucket(dept, 1, 6, 3) > 1
+ |""".stripMargin)
+ checkFiltersRemoved(df7, false)
+ checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]")
+ checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true)))
}
}
}