Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression =>
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.execution.datasources.PushableExpression
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType}
import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType}

/**
* The builder to generate V2 expressions from catalyst expressions.
Expand Down Expand Up @@ -98,45 +98,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
generateExpression(child).map(v => new V2Cast(v, dataType))
case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) =>
generateAggregateFunc(aggregateFunction, isDistinct)
case Abs(child, true) => generateExpressionWithName("ABS", Seq(child))
case Coalesce(children) => generateExpressionWithName("COALESCE", children)
case Greatest(children) => generateExpressionWithName("GREATEST", children)
case Least(children) => generateExpressionWithName("LEAST", children)
case Rand(child, hideSeed) =>
case Abs(_, true) => generateExpressionWithName("ABS", expr, isPredicate)
case _: Coalesce => generateExpressionWithName("COALESCE", expr, isPredicate)
case _: Greatest => generateExpressionWithName("GREATEST", expr, isPredicate)
case _: Least => generateExpressionWithName("LEAST", expr, isPredicate)
case Rand(_, hideSeed) =>
if (hideSeed) {
Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression]))
} else {
generateExpressionWithName("RAND", Seq(child))
generateExpressionWithName("RAND", expr, isPredicate)
}
case log: Logarithm => generateExpressionWithName("LOG", log.children)
case Log10(child) => generateExpressionWithName("LOG10", Seq(child))
case Log2(child) => generateExpressionWithName("LOG2", Seq(child))
case Log(child) => generateExpressionWithName("LN", Seq(child))
case Exp(child) => generateExpressionWithName("EXP", Seq(child))
case pow: Pow => generateExpressionWithName("POWER", pow.children)
case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child))
case Floor(child) => generateExpressionWithName("FLOOR", Seq(child))
case Ceil(child) => generateExpressionWithName("CEIL", Seq(child))
case round: Round => generateExpressionWithName("ROUND", round.children)
case Sin(child) => generateExpressionWithName("SIN", Seq(child))
case Sinh(child) => generateExpressionWithName("SINH", Seq(child))
case Cos(child) => generateExpressionWithName("COS", Seq(child))
case Cosh(child) => generateExpressionWithName("COSH", Seq(child))
case Tan(child) => generateExpressionWithName("TAN", Seq(child))
case Tanh(child) => generateExpressionWithName("TANH", Seq(child))
case Cot(child) => generateExpressionWithName("COT", Seq(child))
case Asin(child) => generateExpressionWithName("ASIN", Seq(child))
case Asinh(child) => generateExpressionWithName("ASINH", Seq(child))
case Acos(child) => generateExpressionWithName("ACOS", Seq(child))
case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child))
case Atan(child) => generateExpressionWithName("ATAN", Seq(child))
case Atanh(child) => generateExpressionWithName("ATANH", Seq(child))
case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children)
case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child))
case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child))
case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child))
case Signum(child) => generateExpressionWithName("SIGN", Seq(child))
case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children)
case _: Logarithm => generateExpressionWithName("LOG", expr, isPredicate)
case _: Log10 => generateExpressionWithName("LOG10", expr, isPredicate)
case _: Log2 => generateExpressionWithName("LOG2", expr, isPredicate)
case _: Log => generateExpressionWithName("LN", expr, isPredicate)
case _: Exp => generateExpressionWithName("EXP", expr, isPredicate)
case _: Pow => generateExpressionWithName("POWER", expr, isPredicate)
case _: Sqrt => generateExpressionWithName("SQRT", expr, isPredicate)
case _: Floor => generateExpressionWithName("FLOOR", expr, isPredicate)
case _: Ceil => generateExpressionWithName("CEIL", expr, isPredicate)
case _: Round => generateExpressionWithName("ROUND", expr, isPredicate)
case _: Sin => generateExpressionWithName("SIN", expr, isPredicate)
case _: Sinh => generateExpressionWithName("SINH", expr, isPredicate)
case _: Cos => generateExpressionWithName("COS", expr, isPredicate)
case _: Cosh => generateExpressionWithName("COSH", expr, isPredicate)
case _: Tan => generateExpressionWithName("TAN", expr, isPredicate)
case _: Tanh => generateExpressionWithName("TANH", expr, isPredicate)
case _: Cot => generateExpressionWithName("COT", expr, isPredicate)
case _: Asin => generateExpressionWithName("ASIN", expr, isPredicate)
case _: Asinh => generateExpressionWithName("ASINH", expr, isPredicate)
case _: Acos => generateExpressionWithName("ACOS", expr, isPredicate)
case _: Acosh => generateExpressionWithName("ACOSH", expr, isPredicate)
case _: Atan => generateExpressionWithName("ATAN", expr, isPredicate)
case _: Atanh => generateExpressionWithName("ATANH", expr, isPredicate)
case _: Atan2 => generateExpressionWithName("ATAN2", expr, isPredicate)
case _: Cbrt => generateExpressionWithName("CBRT", expr, isPredicate)
case _: ToDegrees => generateExpressionWithName("DEGREES", expr, isPredicate)
case _: ToRadians => generateExpressionWithName("RADIANS", expr, isPredicate)
case _: Signum => generateExpressionWithName("SIGN", expr, isPredicate)
case _: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", expr, isPredicate)
case and: And =>
// AND expects predicate
val l = generateExpression(and.left, true)
Expand Down Expand Up @@ -187,57 +187,56 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
assert(v.isInstanceOf[V2Predicate])
new V2Not(v.asInstanceOf[V2Predicate])
}
case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child))
case BitwiseNot(child) => generateExpressionWithName("~", Seq(child))
case CaseWhen(branches, elseValue) =>
case UnaryMinus(_, true) => generateExpressionWithName("-", expr, isPredicate)
case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate)
case caseWhen @ CaseWhen(branches, elseValue) =>
val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept isPredicate=true for the conditions of CaseWhen

val values = branches.map(_._2).flatMap(generateExpression(_, true))
if (conditions.length == branches.length && values.length == branches.length) {
val values = branches.map(_._2).flatMap(generateExpression(_))
val elseExprOpt = elseValue.flatMap(generateExpression(_))
if (conditions.length == branches.length && values.length == branches.length &&
elseExprOpt.size == elseValue.size) {
val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
Seq[V2Expression](c, v)
}
if (elseValue.isDefined) {
elseValue.flatMap(generateExpression(_)).map { v =>
val children = (branchExpressions :+ v).toArray[V2Expression]
// The children looks like [condition1, value1, ..., conditionN, valueN, elseValue]
new V2Predicate("CASE_WHEN", children)
}
val children = (branchExpressions ++ elseExprOpt).toArray[V2Expression]
// The children looks like [condition1, value1, ..., conditionN, valueN (, elseValue)]
if (isPredicate && caseWhen.dataType.isInstanceOf[BooleanType]) {
Some(new V2Predicate("CASE_WHEN", children))
} else {
// The children looks like [condition1, value1, ..., conditionN, valueN]
Some(new V2Predicate("CASE_WHEN", branchExpressions.toArray[V2Expression]))
Some(new GeneralScalarExpression("CASE_WHEN", children))
}
} else {
None
}
case iff: If => generateExpressionWithName("CASE_WHEN", iff.children)
case _: If => generateExpressionWithName("CASE_WHEN", expr, isPredicate)
case substring: Substring =>
val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
Seq(substring.str, substring.pos)
} else {
substring.children
}
generateExpressionWithName("SUBSTRING", children)
case Upper(child) => generateExpressionWithName("UPPER", Seq(child))
case Lower(child) => generateExpressionWithName("LOWER", Seq(child))
generateExpressionWithNameByChildren("SUBSTRING", children, substring.dataType, isPredicate)
case _: Upper => generateExpressionWithName("UPPER", expr, isPredicate)
case _: Lower => generateExpressionWithName("LOWER", expr, isPredicate)
case BitLength(child) if child.dataType.isInstanceOf[StringType] =>
generateExpressionWithName("BIT_LENGTH", Seq(child))
generateExpressionWithName("BIT_LENGTH", expr, isPredicate)
case Length(child) if child.dataType.isInstanceOf[StringType] =>
generateExpressionWithName("CHAR_LENGTH", Seq(child))
case concat: Concat => generateExpressionWithName("CONCAT", concat.children)
case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children)
case trim: StringTrim => generateExpressionWithName("TRIM", trim.children)
case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children)
case trim: StringTrimRight => generateExpressionWithName("RTRIM", trim.children)
generateExpressionWithName("CHAR_LENGTH", expr, isPredicate)
case _: Concat => generateExpressionWithName("CONCAT", expr, isPredicate)
case _: StringTranslate => generateExpressionWithName("TRANSLATE", expr, isPredicate)
case _: StringTrim => generateExpressionWithName("TRIM", expr, isPredicate)
case _: StringTrimLeft => generateExpressionWithName("LTRIM", expr, isPredicate)
case _: StringTrimRight => generateExpressionWithName("RTRIM", expr, isPredicate)
case overlay: Overlay =>
val children = if (overlay.len == Literal(-1)) {
Seq(overlay.input, overlay.replace, overlay.pos)
} else {
overlay.children
}
generateExpressionWithName("OVERLAY", children)
case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children)
case date: DateDiff => generateExpressionWithName("DATE_DIFF", date.children)
case date: TruncDate => generateExpressionWithName("TRUNC", date.children)
generateExpressionWithNameByChildren("OVERLAY", children, overlay.dataType, isPredicate)
case _: DateAdd => generateExpressionWithName("DATE_ADD", expr, isPredicate)
case _: DateDiff => generateExpressionWithName("DATE_DIFF", expr, isPredicate)
case _: TruncDate => generateExpressionWithName("TRUNC", expr, isPredicate)
case Second(child, _) =>
generateExpression(child).map(v => new V2Extract("SECOND", v))
case Minute(child, _) =>
Expand Down Expand Up @@ -270,12 +269,12 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
generateExpression(child).map(v => new V2Extract("WEEK", v))
case YearOfWeek(child) =>
generateExpression(child).map(v => new V2Extract("YEAR_OF_WEEK", v))
case encrypt: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", encrypt.children)
case decrypt: AesDecrypt => generateExpressionWithName("AES_DECRYPT", decrypt.children)
case Crc32(child) => generateExpressionWithName("CRC32", Seq(child))
case Md5(child) => generateExpressionWithName("MD5", Seq(child))
case Sha1(child) => generateExpressionWithName("SHA1", Seq(child))
case sha2: Sha2 => generateExpressionWithName("SHA2", sha2.children)
case _: AesEncrypt => generateExpressionWithName("AES_ENCRYPT", expr, isPredicate)
case _: AesDecrypt => generateExpressionWithName("AES_DECRYPT", expr, isPredicate)
case _: Crc32 => generateExpressionWithName("CRC32", expr, isPredicate)
case _: Md5 => generateExpressionWithName("MD5", expr, isPredicate)
case _: Sha1 => generateExpressionWithName("SHA1", expr, isPredicate)
case _: Sha2 => generateExpressionWithName("SHA2", expr, isPredicate)
// TODO supports other expressions
case ApplyFunctionExpression(function, children) =>
val childrenExpressions = children.flatMap(generateExpression(_))
Expand Down Expand Up @@ -380,10 +379,26 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
}

private def generateExpressionWithName(
v2ExpressionName: String, children: Seq[Expression]): Option[V2Expression] = {
v2ExpressionName: String,
expr: Expression,
isPredicate: Boolean): Option[V2Expression] = {
generateExpressionWithNameByChildren(
v2ExpressionName, expr.children, expr.dataType, isPredicate)
}

private def generateExpressionWithNameByChildren(
v2ExpressionName: String,
children: Seq[Expression],
dataType: DataType,
isPredicate: Boolean): Option[V2Expression] = {
val childrenExpressions = children.flatMap(generateExpression(_))
if (childrenExpressions.length == children.length) {
Some(new GeneralScalarExpression(v2ExpressionName, childrenExpressions.toArray[V2Expression]))
if (isPredicate && dataType.isInstanceOf[BooleanType]) {
Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression]))
} else {
Some(new GeneralScalarExpression(
v2ExpressionName, childrenExpressions.toArray[V2Expression]))
}
} else {
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,44 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
)
}
}

test("SPARK-47463: Pushed down v2 filter with (if / case when / nvl) expression") {
withTempView("t1") {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
spark.read.format(classOf[AdvancedDataSourceV2WithV2Filter].getName).load()
.createTempView("t1")
val df1 = sql(
s"""
|select * from
|(select if(i = 1, i, 0) as c from t1) t
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this test is sufficient. `If` can be a predicate, and before this PR we didn't return a V2Predicate for it, which caused errors.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this test is sufficient. `If` can be a predicate, and before this PR we didn't return a V2Predicate for it, which caused errors.

makes sense to me

|where t.c > 0
|""".stripMargin
)
val result1 = df1.collect()
assert(result1.length == 1)

val df2 = sql(
s"""
|select * from
|(select case when i = 1 then i else 0 end as c from t1) t
|where t.c > 0
|""".stripMargin
)
val result2 = df2.collect()
assert(result2.length == 1)

val df3 = sql(
s"""
|select * from
|(select nvl(cast(i as boolean), false) c from t1) t
|where t.c is true
|""".stripMargin
)
val result3 = df3.collect()
assert(result3.length > 0)
}
}
}
}

case class RangeInputPartition(start: Int, end: Int) extends InputPartition
Expand Down