Skip to content

Commit 7ffcfcf

Browse files
wangyum authored and cloud-fan committed
[SPARK-33847][SQL] Simplify CaseWhen if elseValue is None
### What changes were proposed in this pull request?

1. Enhance `ReplaceNullWithFalseInPredicate` to replace None of elseValue inside `CaseWhen` with `FalseLiteral` if all branches are `FalseLiteral`. The use case is:

```sql
create table t1 using parquet as select id from range(10);
explain select id from t1 where (CASE WHEN id = 1 THEN 'a' WHEN id = 3 THEN 'b' end) = 'c';
```

Before this PR:

```
== Physical Plan ==
*(1) Filter CASE WHEN (id#1L = 1) THEN false WHEN (id#1L = 3) THEN false END
+- *(1) ColumnarToRow
   +- FileScan parquet default.t1[id#1L] Batched: true, DataFilters: [CASE WHEN (id#1L = 1) THEN false WHEN (id#1L = 3) THEN false END], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/spark-warehouse/org.apache.spark.sql.DataF..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint>
```

After this PR:

```
== Physical Plan ==
LocalTableScan <empty>, [id#1L]
```

2. Enhance `SimplifyConditionals` to replace a `CaseWhen` with a null literal of the same data type if elseValue is None and all branch outputs are null.

### Why are the changes needed?

Improve query performance.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test.

Closes #30852 from wangyum/SPARK-33847.

Authored-by: Yuming Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 303df64 commit 7ffcfcf

File tree

6 files changed

+78
-7
lines changed

6 files changed

+78
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,12 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
9393
val newBranches = cw.branches.map { case (cond, value) =>
9494
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
9595
}
96-
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
97-
CaseWhen(newBranches, newElseValue)
96+
if (newBranches.forall(_._2 == FalseLiteral) && cw.elseValue.isEmpty) {
97+
FalseLiteral
98+
} else {
99+
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
100+
CaseWhen(newBranches, newElseValue)
101+
}
98102
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
99103
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
100104
case e if e.dataType == BooleanType =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,10 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
525525
} else {
526526
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
527527
}
528+
529+
case e @ CaseWhen(branches, None)
530+
if branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) =>
531+
Literal(null, e.dataType)
528532
}
529533
}
530534
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,4 +258,15 @@ class PushFoldableIntoBranchesSuite
258258
EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None).cast(StringType), Literal("4")),
259259
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
260260
}
261+
262+
test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
263+
Seq(a, LessThan(Rand(1), Literal(0.5))).foreach { condition =>
264+
assertEquivalent(
265+
EqualTo(CaseWhen(Seq((condition, Literal.create(null, IntegerType)))), Literal(2)),
266+
Literal.create(null, BooleanType))
267+
assertEquivalent(
268+
EqualTo(CaseWhen(Seq((condition, Literal("str")))).cast(IntegerType), Literal(2)),
269+
Literal.create(null, BooleanType))
270+
}
271+
}
261272
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,39 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
380380
testProjection(originalExpr = column, expectedExpr = column)
381381
}
382382

383+
test("replace None of elseValue inside CaseWhen if all branches are FalseLiteral") {
384+
val allFalseBranches = Seq(
385+
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
386+
(UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral)
387+
val allFalseCond = CaseWhen(allFalseBranches)
388+
389+
val nonAllFalseBranches = Seq(
390+
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
391+
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
392+
val nonAllFalseCond = CaseWhen(nonAllFalseBranches, FalseLiteral)
393+
394+
testFilter(allFalseCond, FalseLiteral)
395+
testJoin(allFalseCond, FalseLiteral)
396+
testDelete(allFalseCond, FalseLiteral)
397+
testUpdate(allFalseCond, FalseLiteral)
398+
399+
testFilter(nonAllFalseCond, nonAllFalseCond)
400+
testJoin(nonAllFalseCond, nonAllFalseCond)
401+
testDelete(nonAllFalseCond, nonAllFalseCond)
402+
testUpdate(nonAllFalseCond, nonAllFalseCond)
403+
}
404+
405+
test("replace None of elseValue inside CaseWhen if all branches are null") {
406+
val allNullBranches = Seq(
407+
(UnresolvedAttribute("i") < Literal(10)) -> Literal.create(null, BooleanType),
408+
(UnresolvedAttribute("i") > Literal(40)) -> Literal.create(null, BooleanType))
409+
val allFalseCond = CaseWhen(allNullBranches)
410+
testFilter(allFalseCond, FalseLiteral)
411+
testJoin(allFalseCond, FalseLiteral)
412+
testDelete(allFalseCond, FalseLiteral)
413+
testUpdate(allFalseCond, FalseLiteral)
414+
}
415+
383416
private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = {
384417
test((rel, exp) => rel.where(exp), originalCond, expectedCond)
385418
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,4 +215,12 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
215215
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
216216
LessThanOrEqual(Rand(0), UnresolvedAttribute("a")))
217217
}
218+
219+
test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
220+
Seq(GreaterThan('a, 1), GreaterThan(Rand(0), 1)).foreach { condition =>
221+
assertEquivalent(
222+
CaseWhen((condition, Literal.create(null, IntegerType)) :: Nil, None),
223+
Literal.create(null, IntegerType))
224+
}
225+
}
218226
}

sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ import org.apache.spark.sql.types.BooleanType
2727
class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with SharedSparkSession {
2828
import testImplicits._
2929

30+
private def checkPlanIsEmptyLocalScan(df: DataFrame): Unit =
31+
df.queryExecution.executedPlan match {
32+
case s: LocalTableScanExec => assert(s.rows.isEmpty)
33+
case p => fail(s"$p is not LocalTableScanExec")
34+
}
35+
3036
test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") {
3137
withTable("t1", "t2") {
3238
Seq((1, true), (2, false)).toDF("l", "b").write.saveAsTable("t1")
@@ -64,11 +70,6 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared
6470

6571
checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true))
6672
}
67-
68-
def checkPlanIsEmptyLocalScan(df: DataFrame): Unit = df.queryExecution.executedPlan match {
69-
case s: LocalTableScanExec => assert(s.rows.isEmpty)
70-
case p => fail(s"$p is not LocalTableScanExec")
71-
}
7273
}
7374

7475
test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") {
@@ -112,4 +113,14 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared
112113
assertNoLiteralNullInPlan(q3)
113114
}
114115
}
116+
117+
test("SPARK-33847: replace None of elseValue inside CaseWhen to FalseLiteral") {
118+
withTable("t1") {
119+
Seq((1, 1), (2, 2)).toDF("a", "b").write.saveAsTable("t1")
120+
val t1 = spark.table("t1")
121+
val q1 = t1.filter("(CASE WHEN a > 1 THEN 1 END) = 0")
122+
checkAnswer(q1, Seq.empty)
123+
checkPlanIsEmptyLocalScan(q1)
124+
}
125+
}
115126
}

0 commit comments

Comments
 (0)