diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 0defbf0b6fd7..0cc971a909bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -306,13 +306,13 @@ object OptimizeIn extends Rule[LogicalPlan] { } } - /** * Simplifies boolean expressions: * 1. Simplifies expressions whose answer can be determined without evaluating both sides. * 2. Eliminates / extracts common factors. * 3. Merge same expressions * 4. Removes `Not` operator. + * 5. Simplifies expression (key1==key2) || (isnull(key1)&&isnull(key2)) to key1 <=> key2 */ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( @@ -412,6 +412,16 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { } } + case Or(EqualTo(l, r), And(IsNull(c1), IsNull(c2))) + if (l.semanticEquals(c1) && r.semanticEquals(c2)) || + (l.semanticEquals(c2) && r.semanticEquals(c1)) => + EqualNullSafe(l, r) + + case Or(And(IsNull(c1), IsNull(c2)), EqualTo(l, r)) + if (l.semanticEquals(c1) && r.semanticEquals(c2)) || + (l.semanticEquals(c2) && r.semanticEquals(c1)) => + EqualNullSafe(l, r) + // Common factor elimination for disjunction case or @ (left Or right) => // 1. Split left and right to get the conjunctive predicates, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index fc2697d55f6d..2b4631db9e2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -63,6 +63,11 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper { comparePlans(actual, expected) } + private def checkCondition(input: LogicalPlan, expected: LogicalPlan): Unit = { + val actual = Optimize.execute(input) + comparePlans(actual, expected) + } + private def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze val actual = Optimize.execute(plan) @@ -278,6 +283,18 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper { testRelationWithData.where(Or($"e".isNotNull, Literal(null, BooleanType))).analyze) } + test("SPARK-40177: simplify (a == b) || (a == null and b == null) => a <=> b") { + checkCondition(testRelation.where(Or(EqualTo($"a", $"b"), + And($"a".isNull, $"b".isNull))).analyze, + testRelation.where(EqualNullSafe($"a", $"b")).analyze) + + checkCondition(testRelation.where(Or(And($"a".isNull, $"b".isNull), + EqualTo($"a", $"b"))).analyze, testRelation.where(EqualNullSafe($"a", $"b")).analyze) + + checkCondition(testRelation.where(Or(And($"a".isNull, $"b".isNull), + EqualTo($"b", $"a"))).analyze, testRelation.where(EqualNullSafe($"b", $"a")).analyze) + } + test("Complementation Laws - negative case") { checkCondition($"e" && !$"f", testRelationWithData.where($"e" && !$"f").analyze) checkCondition(!$"f" && $"e", testRelationWithData.where(!$"f" && $"e").analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyJoinConditionSuite.scala new file mode 100644 index 000000000000..1bf12b7523f4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyJoinConditionSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions.{IsNotNull, IsNull} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class SimplifyJoinConditionSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Simplify Join Condition", Once, + BooleanSimplification) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int, 'e.int) + + test("Simple condition with null check on right side of or") { + val originalQuery = testRelation + .join(testRelation1, condition = Some(('b === 'd)||(IsNull('b) && IsNull('d)))) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .join(testRelation1, condition = Some('b <=> 'd)) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("Simple condition with null check on left side of or") { + val originalQuery = testRelation + .join(testRelation1, condition = Some((IsNull('b) && IsNull('d)) || ('b === 'd))) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .join(testRelation1, condition = Some('b <=> 'd)) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("Simple condition with is not null check on one column") { + val originalQuery = testRelation + .join(testRelation1, condition = Some((IsNull('b) && IsNotNull('d)) || ('b === 'd))) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .join(testRelation1, condition = Some((IsNull('b) && IsNotNull('d)) || ('b === 'd))) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("multiple equal null safe conditions separated by and") { + val originalQuery = testRelation.join(testRelation1, + condition = Some(((IsNull('b) && IsNull('d)) || ('b === 'd)) && + ((IsNull('a) && IsNull('e)) || ('a === 'e)))) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .join(testRelation1, condition = Some('b <=> 'd && 'a <=> 'e)) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("multiple equal null safe conditions separated by or") { + val originalQuery = testRelation.join(testRelation1, + condition = Some(((IsNull('b) && IsNull('d)) || ('b === 'd)) || + ((IsNull('a) && IsNull('e)) || ('a === 'e)))) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .join(testRelation1, condition = Some('b <=> 'd || 'a <=> 'e)) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("Condition with another or in expression") { + val originalQuery = testRelation.join(testRelation1, + condition = Some((IsNull('b) && IsNull('d)) || ('b === 'd) || ('a === 'e))) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation.join(testRelation1, + condition = Some('b <=> 'd || ('a === 'e))) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("Condition with another and in expression so that and gets calculated first") { + val originalQuery = testRelation.join(testRelation1, + condition = Some((IsNull('b) && IsNull('d)) || ('b === 'd) && ('a === 'e))) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation.join(testRelation1, + condition = Some((IsNull('b) && IsNull('d)) || ('b === 'd) && ('a === 'e))) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("Condition with another and in expression") { + val originalQuery = testRelation.join(testRelation1, + condition = Some(((IsNull('b) && IsNull('d)) || ('b === 'd)) && ('a === 'e))) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation.join(testRelation1, + condition = Some('b <=> 'd && ('a === 'e))) + .analyze + comparePlans(optimized, correctAnswer) + } +}