Skip to content

Commit 2f3997f

Browse files
yeshengm authored and gatorsmile committed
[SPARK-28306][SQL][FOLLOWUP] Fix NormalizeFloatingNumbers rule idempotence for equi-join with <=> predicates
## What changes were proposed in this pull request?

Idempotence of the `NormalizeFloatingNumbers` rule was broken due to the implementation of `ExtractEquiJoinKeys`. There is no reason not to remove `EqualNullSafe` join keys from an equi-join's `otherPredicates`.

## How was this patch tested?

A new UT.

Closes #25126 from yeshengm/spark-28306.

Authored-by: Yesheng Ma <kimi.ysma@gmail.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
1 parent 8d1e87a commit 2f3997f

2 files changed

Lines changed: 24 additions & 7 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,23 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
118118
// Replace null with default value for joining key, then those rows with null in it could
119119
// be joined together
120120
case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
121-
Some((Coalesce(Seq(l, Literal.default(l.dataType))),
122-
Coalesce(Seq(r, Literal.default(r.dataType)))))
121+
Seq((Coalesce(Seq(l, Literal.default(l.dataType))),
122+
Coalesce(Seq(r, Literal.default(r.dataType)))),
123+
(IsNull(l), IsNull(r))
124+
)
123125
case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
124-
Some((Coalesce(Seq(r, Literal.default(r.dataType))),
125-
Coalesce(Seq(l, Literal.default(l.dataType)))))
126+
Seq((Coalesce(Seq(r, Literal.default(r.dataType))),
127+
Coalesce(Seq(l, Literal.default(l.dataType)))),
128+
(IsNull(r), IsNull(l))
129+
)
126130
case other => None
127131
}
128132
val otherPredicates = predicates.filterNot {
129133
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
130-
case EqualTo(l, r) =>
134+
case Equality(l, r) =>
131135
canEvaluate(l, left) && canEvaluate(r, right) ||
132136
canEvaluate(l, right) && canEvaluate(r, left)
133-
case other => false
137+
case _ => false
134138
}
135139

136140
if (joinKeys.nonEmpty) {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
1919

2020
import org.apache.spark.sql.catalyst.dsl.expressions._
2121
import org.apache.spark.sql.catalyst.dsl.plans._
22-
import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
22+
import org.apache.spark.sql.catalyst.expressions.{And, IsNull, KnownFloatingPointNormalized}
2323
import org.apache.spark.sql.catalyst.plans.PlanTest
2424
import org.apache.spark.sql.catalyst.plans.logical._
2525
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -78,5 +78,18 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {
7878

7979
comparePlans(doubleOptimized, correctAnswer)
8080
}
81+
82+
test("normalize floating points in join keys (equal null safe) - idempotence") {
83+
val query = testRelation1.join(testRelation2, condition = Some(a <=> b))
84+
85+
val optimized = Optimize.execute(query)
86+
val doubleOptimized = Optimize.execute(optimized)
87+
val joinCond = IsNull(a) === IsNull(b) &&
88+
KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(a, 0.0))) ===
89+
KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(b, 0.0)))
90+
val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))
91+
92+
comparePlans(doubleOptimized, correctAnswer)
93+
}
8194
}
8295

0 commit comments

Comments (0)