Skip to content

Commit f7bdea3

Browse files
wangyumcloud-fan
authored andcommitted
[SPARK-33884][SQL] Simplify CaseWhenclauses with (true and false) and (false and true)
### What changes were proposed in this pull request? This pr simplify `CaseWhen`clauses with (true and false) and (false and true): Expression | cond.nullable | After simplify -- | -- | -- case when cond then true else false end | true | cond <=> true case when cond then true else false end | false | cond case when cond then false else true end | true | !(cond <=> true) case when cond then false else true end | false | !cond ### Why are the changes needed? Improve query performance. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #30898 from wangyum/SPARK-33884. Authored-by: Yuming Wang <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 379afcd commit f7bdea3

File tree

3 files changed

+51
-1
lines changed

3 files changed

+51
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,11 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
486486
case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l)
487487
case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l)
488488

489+
case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral)) =>
490+
if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond
491+
case CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)) =>
492+
if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond)
493+
489494
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
490495
// If there are branches that are always false, remove them.
491496
// If there are no more branches left, just use the else value.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class PushFoldableIntoBranchesSuite
141141
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(Literal(2)))
142142
assert(!nonDeterministic.deterministic)
143143
assertEquivalent(EqualTo(nonDeterministic, Literal(2)),
144-
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(TrueLiteral)))
144+
GreaterThanOrEqual(Rand(1), Literal(0.5)))
145145
assertEquivalent(EqualTo(nonDeterministic, Literal(3)),
146146
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(FalseLiteral)))
147147

@@ -269,4 +269,13 @@ class PushFoldableIntoBranchesSuite
269269
Literal.create(null, BooleanType))
270270
}
271271
}
272+
273+
test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
274+
assertEquivalent(
275+
EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(0)),
276+
'a > 10 <=> TrueLiteral)
277+
assertEquivalent(
278+
EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(1)),
279+
Not('a > 10 <=> TrueLiteral))
280+
}
272281
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,4 +243,40 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
243243
Literal.create(null, IntegerType))
244244
}
245245
}
246+
247+
test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
248+
// verify the boolean equivalence of all transformations involved
249+
val fields = Seq(
250+
'cond.boolean.notNull,
251+
'cond_nullable.boolean,
252+
'a.boolean,
253+
'b.boolean
254+
)
255+
val Seq(cond, cond_nullable, a, b) = fields.zipWithIndex.map { case (f, i) => f.at(i) }
256+
257+
val exprs = Seq(
258+
// actual expressions of the transformations: original -> transformed
259+
CaseWhen(Seq((cond, TrueLiteral)), FalseLiteral) -> cond,
260+
CaseWhen(Seq((cond, FalseLiteral)), TrueLiteral) -> !cond,
261+
CaseWhen(Seq((cond_nullable, TrueLiteral)), FalseLiteral) -> (cond_nullable <=> true),
262+
CaseWhen(Seq((cond_nullable, FalseLiteral)), TrueLiteral) -> (!(cond_nullable <=> true)))
263+
264+
// check plans
265+
for ((originalExpr, expectedExpr) <- exprs) {
266+
assertEquivalent(originalExpr, expectedExpr)
267+
}
268+
269+
// check evaluation
270+
val binaryBooleanValues = Seq(true, false)
271+
val ternaryBooleanValues = Seq(true, false, null)
272+
for (condVal <- binaryBooleanValues;
273+
condNullableVal <- ternaryBooleanValues;
274+
aVal <- ternaryBooleanValues;
275+
bVal <- ternaryBooleanValues;
276+
(originalExpr, expectedExpr) <- exprs) {
277+
val inputRow = create_row(condVal, condNullableVal, aVal, bVal)
278+
val optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow)
279+
checkEvaluation(originalExpr, optimizedVal, inputRow)
280+
}
281+
}
246282
}

0 commit comments

Comments
 (0)