Skip to content

Commit 773c823

Browse files
gatorsmilecloud-fan
authored and committed
[SPARK-25714][BACKPORT-2.2] Fix Null Handling in the Optimizer rule BooleanSimplification
This PR is to backport #22702 to branch 2.2. --- ## What changes were proposed in this pull request? ```Scala val df1 = Seq(("abc", 1), (null, 3)).toDF("col1", "col2") df1.write.mode(SaveMode.Overwrite).parquet("/tmp/test1") val df2 = spark.read.parquet("/tmp/test1") df2.filter("col1 = 'abc' OR (col1 != 'abc' AND col2 == 3)").show() ``` Before the PR, it returns both rows. After the fix, it returns `Row("abc", 1)`. This is to fix the bug in NULL handling in BooleanSimplification. This is a bug introduced in the Spark 1.6 release. ## How was this patch tested? Added test cases. Closes #22719 from gatorsmile/cherrypickSpark-257142.2. Authored-by: gatorsmile <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 9a74cb3 commit 773c823

3 files changed

Lines changed: 153 additions & 33 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ case class Not(child: Expression)
120120

121121
override def inputTypes: Seq[DataType] = Seq(BooleanType)
122122

123+
// +---------+-----------+
124+
// | CHILD | NOT CHILD |
125+
// +---------+-----------+
126+
// | TRUE | FALSE |
127+
// | FALSE | TRUE |
128+
// | UNKNOWN | UNKNOWN |
129+
// +---------+-----------+
123130
protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean]
124131

125132
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -331,6 +338,13 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
331338

332339
override def sqlOperator: String = "AND"
333340

341+
// +---------+---------+---------+---------+
342+
// | AND | TRUE | FALSE | UNKNOWN |
343+
// +---------+---------+---------+---------+
344+
// | TRUE | TRUE | FALSE | UNKNOWN |
345+
// | FALSE | FALSE | FALSE | FALSE |
346+
// | UNKNOWN | UNKNOWN | FALSE | UNKNOWN |
347+
// +---------+---------+---------+---------+
334348
override def eval(input: InternalRow): Any = {
335349
val input1 = left.eval(input)
336350
if (input1 == false) {
@@ -433,6 +447,13 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
433447

434448
override def sqlOperator: String = "OR"
435449

450+
// +---------+---------+---------+---------+
451+
// | OR | TRUE | FALSE | UNKNOWN |
452+
// +---------+---------+---------+---------+
453+
// | TRUE | TRUE | TRUE | TRUE |
454+
// | FALSE | TRUE | FALSE | UNKNOWN |
455+
// | UNKNOWN | TRUE | UNKNOWN | UNKNOWN |
456+
// +---------+---------+---------+---------+
436457
override def eval(input: InternalRow): Any = {
437458
val input1 = left.eval(input)
438459
if (input1 == true) {
@@ -583,6 +604,13 @@ case class EqualTo(left: Expression, right: Expression)
583604

584605
override def symbol: String = "="
585606

607+
// +---------+---------+---------+---------+
608+
// | = | TRUE | FALSE | UNKNOWN |
609+
// +---------+---------+---------+---------+
610+
// | TRUE | TRUE | FALSE | UNKNOWN |
611+
// | FALSE | FALSE | TRUE | UNKNOWN |
612+
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
613+
// +---------+---------+---------+---------+
586614
protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)
587615

588616
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -618,6 +646,13 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
618646

619647
override def nullable: Boolean = false
620648

649+
// +---------+---------+---------+---------+
650+
// | <=> | TRUE | FALSE | UNKNOWN |
651+
// +---------+---------+---------+---------+
652+
// | TRUE | TRUE | FALSE | FALSE |
653+
// | FALSE | FALSE | TRUE | FALSE |
654+
// | UNKNOWN | FALSE | FALSE | TRUE |
655+
// +---------+---------+---------+---------+
621656
override def eval(input: InternalRow): Any = {
622657
val input1 = left.eval(input)
623658
val input2 = right.eval(input)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,37 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
167167
case a And b if a.semanticEquals(b) => a
168168
case a Or b if a.semanticEquals(b) => a
169169

170-
case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c)
171-
case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b)
172-
case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c)
173-
case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c)
174-
175-
case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c)
176-
case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b)
177-
case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c)
178-
case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c)
170+
// The following optimizations are applicable only when the operands are not nullable,
171+
// since AND and OR handle NULL differently under three-valued logic.
172+
// See the chart:
173+
// +---------+---------+---------+---------+
174+
// | operand | operand | OR | AND |
175+
// +---------+---------+---------+---------+
176+
// | TRUE | TRUE | TRUE | TRUE |
177+
// | TRUE | FALSE | TRUE | FALSE |
178+
// | FALSE | FALSE | FALSE | FALSE |
179+
// | UNKNOWN | TRUE | TRUE | UNKNOWN |
180+
// | UNKNOWN | FALSE | UNKNOWN | FALSE |
181+
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
182+
// +---------+---------+---------+---------+
183+
184+
// (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
185+
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
186+
// (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
187+
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
188+
// ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
189+
case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
190+
// ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
191+
case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)
192+
193+
// (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
194+
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
195+
// (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
196+
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
197+
// ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
198+
case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
199+
// ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
200+
case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)
179201

180202
// Common factor elimination for conjunction
181203
case and @ (left And right) =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._
2929
import org.apache.spark.sql.internal.SQLConf
3030
import org.apache.spark.sql.types.BooleanType
3131

32-
class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
32+
class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with PredicateHelper {
3333

3434
object Optimize extends RuleExecutor[LogicalPlan] {
3535
val batches =
@@ -71,6 +71,14 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
7171
comparePlans(actual, correctAnswer)
7272
}
7373

74+
private def checkConditionInNotNullableRelation(
75+
input: Expression, expected: Expression): Unit = {
76+
val plan = testNotNullableRelationWithData.where(input).analyze
77+
val actual = Optimize.execute(plan)
78+
val correctAnswer = testNotNullableRelationWithData.where(expected).analyze
79+
comparePlans(actual, correctAnswer)
80+
}
81+
7482
private def checkConditionInNotNullableRelation(
7583
input: Expression, expected: LogicalPlan): Unit = {
7684
val plan = testNotNullableRelationWithData.where(input).analyze
@@ -119,42 +127,55 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
119127
'a === 'b || 'b > 3 && 'a > 3 && 'a < 5)
120128
}
121129

122-
test("a && (!a || b)") {
123-
checkCondition('a && (!'a || 'b ), 'a && 'b)
130+
test("e && (!e || f) - not nullable") {
131+
checkConditionInNotNullableRelation('e && (!'e || 'f ), 'e && 'f)
124132

125-
checkCondition('a && ('b || !'a ), 'a && 'b)
133+
checkConditionInNotNullableRelation('e && ('f || !'e ), 'e && 'f)
126134

127-
checkCondition((!'a || 'b ) && 'a, 'b && 'a)
135+
checkConditionInNotNullableRelation((!'e || 'f ) && 'e, 'f && 'e)
128136

129-
checkCondition(('b || !'a ) && 'a, 'b && 'a)
137+
checkConditionInNotNullableRelation(('f || !'e ) && 'e, 'f && 'e)
130138
}
131139

132-
test("a < 1 && (!(a < 1) || b)") {
133-
checkCondition('a < 1 && (!('a < 1) || 'b), ('a < 1) && 'b)
134-
checkCondition('a < 1 && ('b || !('a < 1)), ('a < 1) && 'b)
140+
test("e && (!e || f) - nullable") {
141+
Seq ('e && (!'e || 'f ),
142+
'e && ('f || !'e ),
143+
(!'e || 'f ) && 'e,
144+
('f || !'e ) && 'e,
145+
'e || (!'e && 'f),
146+
'e || ('f && !'e),
147+
('e && 'f) || !'e,
148+
('f && 'e) || !'e).foreach { expr =>
149+
checkCondition(expr, expr)
150+
}
151+
}
135152

136-
checkCondition('a <= 1 && (!('a <= 1) || 'b), ('a <= 1) && 'b)
137-
checkCondition('a <= 1 && ('b || !('a <= 1)), ('a <= 1) && 'b)
153+
test("a < 1 && (!(a < 1) || f) - not nullable") {
154+
checkConditionInNotNullableRelation('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f)
155+
checkConditionInNotNullableRelation('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f)
138156

139-
checkCondition('a > 1 && (!('a > 1) || 'b), ('a > 1) && 'b)
140-
checkCondition('a > 1 && ('b || !('a > 1)), ('a > 1) && 'b)
157+
checkConditionInNotNullableRelation('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f)
158+
checkConditionInNotNullableRelation('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f)
141159

142-
checkCondition('a >= 1 && (!('a >= 1) || 'b), ('a >= 1) && 'b)
143-
checkCondition('a >= 1 && ('b || !('a >= 1)), ('a >= 1) && 'b)
160+
checkConditionInNotNullableRelation('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f)
161+
checkConditionInNotNullableRelation('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f)
162+
163+
checkConditionInNotNullableRelation('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f)
164+
checkConditionInNotNullableRelation('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f)
144165
}
145166

146-
test("a < 1 && ((a >= 1) || b)") {
147-
checkCondition('a < 1 && ('a >= 1 || 'b ), ('a < 1) && 'b)
148-
checkCondition('a < 1 && ('b || 'a >= 1), ('a < 1) && 'b)
167+
test("a < 1 && ((a >= 1) || f) - not nullable") {
168+
checkConditionInNotNullableRelation('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f)
169+
checkConditionInNotNullableRelation('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f)
149170

150-
checkCondition('a <= 1 && ('a > 1 || 'b ), ('a <= 1) && 'b)
151-
checkCondition('a <= 1 && ('b || 'a > 1), ('a <= 1) && 'b)
171+
checkConditionInNotNullableRelation('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f)
172+
checkConditionInNotNullableRelation('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f)
152173

153-
checkCondition('a > 1 && (('a <= 1) || 'b), ('a > 1) && 'b)
154-
checkCondition('a > 1 && ('b || ('a <= 1)), ('a > 1) && 'b)
174+
checkConditionInNotNullableRelation('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f)
175+
checkConditionInNotNullableRelation('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f)
155176

156-
checkCondition('a >= 1 && (('a < 1) || 'b), ('a >= 1) && 'b)
157-
checkCondition('a >= 1 && ('b || ('a < 1)), ('a >= 1) && 'b)
177+
checkConditionInNotNullableRelation('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f)
178+
checkConditionInNotNullableRelation('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f)
158179
}
159180

160181
test("DeMorgan's law") {
@@ -217,4 +238,46 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
217238
checkCondition('e || !'f, testRelationWithData.where('e || !'f).analyze)
218239
checkCondition(!'f || 'e, testRelationWithData.where(!'f || 'e).analyze)
219240
}
241+
242+
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
243+
val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
244+
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
245+
comparePlans(actual, correctAnswer)
246+
}
247+
248+
test("filter reduction - positive cases") {
249+
val fields = Seq(
250+
'col1NotNULL.boolean.notNull,
251+
'col2NotNULL.boolean.notNull
252+
)
253+
val Seq(col1NotNULL, col2NotNULL) = fields.zipWithIndex.map { case (f, i) => f.at(i) }
254+
255+
val exprs = Seq(
256+
// actual expressions of the transformations: original -> transformed
257+
(col1NotNULL && (!col1NotNULL || col2NotNULL)) -> (col1NotNULL && col2NotNULL),
258+
(col1NotNULL && (col2NotNULL || !col1NotNULL)) -> (col1NotNULL && col2NotNULL),
259+
((!col1NotNULL || col2NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL),
260+
((col2NotNULL || !col1NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL),
261+
262+
(col1NotNULL || (!col1NotNULL && col2NotNULL)) -> (col1NotNULL || col2NotNULL),
263+
(col1NotNULL || (col2NotNULL && !col1NotNULL)) -> (col1NotNULL || col2NotNULL),
264+
((!col1NotNULL && col2NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL),
265+
((col2NotNULL && !col1NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL)
266+
)
267+
268+
// check plans
269+
for ((originalExpr, expectedExpr) <- exprs) {
270+
assertEquivalent(originalExpr, expectedExpr)
271+
}
272+
273+
// check evaluation
274+
val binaryBooleanValues = Seq(true, false)
275+
for (col1NotNULLVal <- binaryBooleanValues;
276+
col2NotNULLVal <- binaryBooleanValues;
277+
(originalExpr, expectedExpr) <- exprs) {
278+
val inputRow = create_row(col1NotNULLVal, col2NotNULLVal)
279+
val optimizedVal = evaluate(expectedExpr, inputRow)
280+
checkEvaluation(originalExpr, optimizedVal, inputRow)
281+
}
282+
}
220283
}

0 commit comments

Comments
 (0)