@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._
2929import org .apache .spark .sql .internal .SQLConf
3030import org .apache .spark .sql .types .BooleanType
3131
32- class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
32+ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with PredicateHelper {
3333
3434 object Optimize extends RuleExecutor [LogicalPlan ] {
3535 val batches =
@@ -71,6 +71,14 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
7171 comparePlans(actual, correctAnswer)
7272 }
7373
74+ private def checkConditionInNotNullableRelation (
75+ input : Expression , expected : Expression ): Unit = {
76+ val plan = testNotNullableRelationWithData.where(input).analyze
77+ val actual = Optimize .execute(plan)
78+ val correctAnswer = testNotNullableRelationWithData.where(expected).analyze
79+ comparePlans(actual, correctAnswer)
80+ }
81+
7482 private def checkConditionInNotNullableRelation (
7583 input : Expression , expected : LogicalPlan ): Unit = {
7684 val plan = testNotNullableRelationWithData.where(input).analyze
@@ -119,42 +127,55 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
119127 ' a === ' b || ' b > 3 && ' a > 3 && ' a < 5 )
120128 }
121129
122- test(" a && (!a || b) " ) {
123- checkCondition( ' a && (! ' a || ' b ), ' a && ' b )
130+ test(" e && (!e || f) - not nullable " ) {
131+ checkConditionInNotNullableRelation( ' e && (! ' e || ' f ), ' e && ' f )
124132
125- checkCondition( ' a && (' b || ! ' a ), ' a && ' b )
133+ checkConditionInNotNullableRelation( ' e && (' f || ! ' e ), ' e && ' f )
126134
127- checkCondition ((! ' a || ' b ) && ' a , ' b && ' a )
135+ checkConditionInNotNullableRelation ((! ' e || ' f ) && ' e , ' f && ' e )
128136
129- checkCondition ((' b || ! ' a ) && ' a , ' b && ' a )
137+ checkConditionInNotNullableRelation ((' f || ! ' e ) && ' e , ' f && ' e )
130138 }
131139
132- test(" a < 1 && (!(a < 1) || b)" ) {
133- checkCondition(' a < 1 && (! (' a < 1 ) || ' b ), (' a < 1 ) && ' b )
134- checkCondition(' a < 1 && (' b || ! (' a < 1 )), (' a < 1 ) && ' b )
140+ test(" e && (!e || f) - nullable" ) {
141+ Seq (' e && (! ' e || ' f ),
142+ ' e && (' f || ! ' e ),
143+ (! ' e || ' f ) && ' e ,
144+ (' f || ! ' e ) && ' e ,
145+ ' e || (! ' e && ' f ),
146+ ' e || (' f && ! ' e ),
147+ (' e && ' f ) || ! ' e ,
148+ (' f && ' e ) || ! ' e ).foreach { expr =>
149+ checkCondition(expr, expr)
150+ }
151+ }
135152
136- checkCondition(' a <= 1 && (! (' a <= 1 ) || ' b ), (' a <= 1 ) && ' b )
137- checkCondition(' a <= 1 && (' b || ! (' a <= 1 )), (' a <= 1 ) && ' b )
153+ test(" a < 1 && (!(a < 1) || f) - not nullable" ) {
154+ checkConditionInNotNullableRelation(' a < 1 && (! (' a < 1 ) || ' f ), (' a < 1 ) && ' f )
155+ checkConditionInNotNullableRelation(' a < 1 && (' f || ! (' a < 1 )), (' a < 1 ) && ' f )
138156
139- checkCondition (' a > 1 && (! (' a > 1 ) || ' b ), (' a > 1 ) && ' b )
140- checkCondition (' a > 1 && (' b || ! (' a > 1 )), (' a > 1 ) && ' b )
157+ checkConditionInNotNullableRelation (' a <= 1 && (! (' a <= 1 ) || ' f ), (' a <= 1 ) && ' f )
158+ checkConditionInNotNullableRelation (' a <= 1 && (' f || ! (' a <= 1 )), (' a <= 1 ) && ' f )
141159
142- checkCondition(' a >= 1 && (! (' a >= 1 ) || ' b ), (' a >= 1 ) && ' b )
143- checkCondition(' a >= 1 && (' b || ! (' a >= 1 )), (' a >= 1 ) && ' b )
160+ checkConditionInNotNullableRelation(' a > 1 && (! (' a > 1 ) || ' f ), (' a > 1 ) && ' f )
161+ checkConditionInNotNullableRelation(' a > 1 && (' f || ! (' a > 1 )), (' a > 1 ) && ' f )
162+
163+ checkConditionInNotNullableRelation(' a >= 1 && (! (' a >= 1 ) || ' f ), (' a >= 1 ) && ' f )
164+ checkConditionInNotNullableRelation(' a >= 1 && (' f || ! (' a >= 1 )), (' a >= 1 ) && ' f )
144165 }
145166
146- test(" a < 1 && ((a >= 1) || b) " ) {
147- checkCondition (' a < 1 && (' a >= 1 || ' b ), (' a < 1 ) && ' b )
148- checkCondition (' a < 1 && (' b || ' a >= 1 ), (' a < 1 ) && ' b )
167+ test(" a < 1 && ((a >= 1) || f) - not nullable " ) {
168+ checkConditionInNotNullableRelation (' a < 1 && (' a >= 1 || ' f ), (' a < 1 ) && ' f )
169+ checkConditionInNotNullableRelation (' a < 1 && (' f || ' a >= 1 ), (' a < 1 ) && ' f )
149170
150- checkCondition (' a <= 1 && (' a > 1 || ' b ), (' a <= 1 ) && ' b )
151- checkCondition (' a <= 1 && (' b || ' a > 1 ), (' a <= 1 ) && ' b )
171+ checkConditionInNotNullableRelation (' a <= 1 && (' a > 1 || ' f ), (' a <= 1 ) && ' f )
172+ checkConditionInNotNullableRelation (' a <= 1 && (' f || ' a > 1 ), (' a <= 1 ) && ' f )
152173
153- checkCondition (' a > 1 && ((' a <= 1 ) || ' b ), (' a > 1 ) && ' b )
154- checkCondition (' a > 1 && (' b || (' a <= 1 )), (' a > 1 ) && ' b )
174+ checkConditionInNotNullableRelation (' a > 1 && ((' a <= 1 ) || ' f ), (' a > 1 ) && ' f )
175+ checkConditionInNotNullableRelation (' a > 1 && (' f || (' a <= 1 )), (' a > 1 ) && ' f )
155176
156- checkCondition (' a >= 1 && ((' a < 1 ) || ' b ), (' a >= 1 ) && ' b )
157- checkCondition (' a >= 1 && (' b || (' a < 1 )), (' a >= 1 ) && ' b )
177+ checkConditionInNotNullableRelation (' a >= 1 && ((' a < 1 ) || ' f ), (' a >= 1 ) && ' f )
178+ checkConditionInNotNullableRelation (' a >= 1 && (' f || (' a < 1 )), (' a >= 1 ) && ' f )
158179 }
159180
160181 test(" DeMorgan's law" ) {
@@ -217,4 +238,46 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
217238 checkCondition(' e || ! ' f , testRelationWithData.where(' e || ! ' f ).analyze)
218239 checkCondition(! ' f || ' e , testRelationWithData.where(! ' f || ' e ).analyze)
219240 }
241+
242+ protected def assertEquivalent (e1 : Expression , e2 : Expression ): Unit = {
243+ val correctAnswer = Project (Alias (e2, " out" )() :: Nil , OneRowRelation ).analyze
244+ val actual = Optimize .execute(Project (Alias (e1, " out" )() :: Nil , OneRowRelation ).analyze)
245+ comparePlans(actual, correctAnswer)
246+ }
247+
248+ test(" filter reduction - positive cases" ) {
249+ val fields = Seq (
250+ ' col1NotNULL .boolean.notNull,
251+ ' col2NotNULL .boolean.notNull
252+ )
253+ val Seq (col1NotNULL, col2NotNULL) = fields.zipWithIndex.map { case (f, i) => f.at(i) }
254+
255+ val exprs = Seq (
256+ // actual expressions of the transformations: original -> transformed
257+ (col1NotNULL && (! col1NotNULL || col2NotNULL)) -> (col1NotNULL && col2NotNULL),
258+ (col1NotNULL && (col2NotNULL || ! col1NotNULL)) -> (col1NotNULL && col2NotNULL),
259+ ((! col1NotNULL || col2NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL),
260+ ((col2NotNULL || ! col1NotNULL) && col1NotNULL) -> (col2NotNULL && col1NotNULL),
261+
262+ (col1NotNULL || (! col1NotNULL && col2NotNULL)) -> (col1NotNULL || col2NotNULL),
263+ (col1NotNULL || (col2NotNULL && ! col1NotNULL)) -> (col1NotNULL || col2NotNULL),
264+ ((! col1NotNULL && col2NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL),
265+ ((col2NotNULL && ! col1NotNULL) || col1NotNULL) -> (col2NotNULL || col1NotNULL)
266+ )
267+
268+ // check plans
269+ for ((originalExpr, expectedExpr) <- exprs) {
270+ assertEquivalent(originalExpr, expectedExpr)
271+ }
272+
273+ // check evaluation
274+ val binaryBooleanValues = Seq (true , false )
275+ for (col1NotNULLVal <- binaryBooleanValues;
276+ col2NotNULLVal <- binaryBooleanValues;
277+ (originalExpr, expectedExpr) <- exprs) {
278+ val inputRow = create_row(col1NotNULLVal, col2NotNULLVal)
279+ val optimizedVal = evaluate(expectedExpr, inputRow)
280+ checkEvaluation(originalExpr, optimizedVal, inputRow)
281+ }
282+ }
220283}
0 commit comments