Skip to content

Commit 13ca642

Browse files
aokolnychyidongjoon-hyun
authored andcommitted
[SPARK-33736][SQL] Handle MERGE in ReplaceNullWithFalseInPredicate (apache#908)
### What changes were proposed in this pull request? This PR adds MERGE operations to `ReplaceNullWithFalseInPredicate`. ### Why are the changes needed? These changes are needed to optimize conditions of MERGE operations and match the existing logic for UPDATE and DELETE. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? This PR comes with new tests.
1 parent bb235ce commit 13ca642

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20-
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If}
21-
import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or}
20+
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If, LambdaFunction, Literal, MapFilter, Or}
2221
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
23-
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan, UpdateTable}
22+
import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateTable}
2423
import org.apache.spark.sql.catalyst.rules.Rule
2524
import org.apache.spark.sql.types.BooleanType
2625
import org.apache.spark.util.Utils
@@ -55,6 +54,11 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
5554
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond)))
5655
case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
5756
case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
57+
case m @ MergeIntoTable(_, _, cond, matchedActions, notMatchedActions) =>
58+
m.copy(
59+
mergeCondition = replaceNullWithFalse(cond),
60+
matchedActions = replaceNullWithFalse(matchedActions),
61+
notMatchedActions = replaceNullWithFalse(notMatchedActions))
5862
case p: LogicalPlan => p transformExpressions {
5963
case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
6064
case cw @ CaseWhen(branches, _) =>
@@ -109,4 +113,13 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
109113
e
110114
}
111115
}
116+
117+
private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = {
118+
mergeActions.map {
119+
case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(replaceNullWithFalse(cond)))
120+
case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
121+
case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(replaceNullWithFalse(cond)))
122+
case other => other
123+
}
124+
}
112125
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
2424
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
2525
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
2626
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
27-
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
27+
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, DeleteFromTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, UpdateAction, UpdateTable}
2828
import org.apache.spark.sql.catalyst.rules.RuleExecutor
2929
import org.apache.spark.sql.internal.SQLConf
3030
import org.apache.spark.sql.types.{BooleanType, IntegerType}
@@ -50,6 +50,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
5050
testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
5151
testDelete(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
5252
testUpdate(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
53+
testMerge(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
5354
}
5455

5556
test("Not expected type - replaceNullWithFalse") {
@@ -68,6 +69,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
6869
testJoin(originalCond, expectedCond = FalseLiteral)
6970
testDelete(originalCond, expectedCond = FalseLiteral)
7071
testUpdate(originalCond, expectedCond = FalseLiteral)
72+
testMerge(originalCond, expectedCond = FalseLiteral)
7173
}
7274

7375
test("replace nulls in nested expressions in branches of If") {
@@ -79,6 +81,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
7981
testJoin(originalCond, expectedCond = FalseLiteral)
8082
testDelete(originalCond, expectedCond = FalseLiteral)
8183
testUpdate(originalCond, expectedCond = FalseLiteral)
84+
testMerge(originalCond, expectedCond = FalseLiteral)
8285
}
8386

8487
test("replace null in elseValue of CaseWhen") {
@@ -91,6 +94,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
9194
testJoin(originalCond, expectedCond)
9295
testDelete(originalCond, expectedCond)
9396
testUpdate(originalCond, expectedCond)
97+
testMerge(originalCond, expectedCond)
9498
}
9599

96100
test("replace null in branch values of CaseWhen") {
@@ -102,6 +106,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
102106
testJoin(originalCond, expectedCond = FalseLiteral)
103107
testDelete(originalCond, expectedCond = FalseLiteral)
104108
testUpdate(originalCond, expectedCond = FalseLiteral)
109+
testMerge(originalCond, expectedCond = FalseLiteral)
105110
}
106111

107112
test("replace null in branches of If inside CaseWhen") {
@@ -120,6 +125,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
120125
testJoin(originalCond, expectedCond)
121126
testDelete(originalCond, expectedCond)
122127
testUpdate(originalCond, expectedCond)
128+
testMerge(originalCond, expectedCond)
123129
}
124130

125131
test("replace null in complex CaseWhen expressions") {
@@ -141,6 +147,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
141147
testJoin(originalCond, expectedCond)
142148
testDelete(originalCond, expectedCond)
143149
testUpdate(originalCond, expectedCond)
150+
testMerge(originalCond, expectedCond)
144151
}
145152

146153
test("replace null in Or") {
@@ -150,6 +157,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
150157
testJoin(originalCond, expectedCond)
151158
testDelete(originalCond, expectedCond)
152159
testUpdate(originalCond, expectedCond)
160+
testMerge(originalCond, expectedCond)
153161
}
154162

155163
test("replace null in And") {
@@ -158,6 +166,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
158166
testJoin(originalCond, expectedCond = FalseLiteral)
159167
testDelete(originalCond, expectedCond = FalseLiteral)
160168
testUpdate(originalCond, expectedCond = FalseLiteral)
169+
testMerge(originalCond, expectedCond = FalseLiteral)
161170
}
162171

163172
test("replace nulls in nested And/Or expressions") {
@@ -168,6 +177,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
168177
testJoin(originalCond, expectedCond = FalseLiteral)
169178
testDelete(originalCond, expectedCond = FalseLiteral)
170179
testUpdate(originalCond, expectedCond = FalseLiteral)
180+
testMerge(originalCond, expectedCond = FalseLiteral)
171181
}
172182

173183
test("replace null in And inside branches of If") {
@@ -179,6 +189,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
179189
testJoin(originalCond, expectedCond = FalseLiteral)
180190
testDelete(originalCond, expectedCond = FalseLiteral)
181191
testUpdate(originalCond, expectedCond = FalseLiteral)
192+
testMerge(originalCond, expectedCond = FalseLiteral)
182193
}
183194

184195
test("replace null in branches of If inside And") {
@@ -192,6 +203,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
192203
testJoin(originalCond, expectedCond = FalseLiteral)
193204
testDelete(originalCond, expectedCond = FalseLiteral)
194205
testUpdate(originalCond, expectedCond = FalseLiteral)
206+
testMerge(originalCond, expectedCond = FalseLiteral)
195207
}
196208

197209
test("replace null in branches of If inside another If") {
@@ -203,6 +215,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
203215
testJoin(originalCond, expectedCond = FalseLiteral)
204216
testDelete(originalCond, expectedCond = FalseLiteral)
205217
testUpdate(originalCond, expectedCond = FalseLiteral)
218+
testMerge(originalCond, expectedCond = FalseLiteral)
206219
}
207220

208221
test("replace null in CaseWhen inside another CaseWhen") {
@@ -212,6 +225,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
212225
testJoin(originalCond, expectedCond = FalseLiteral)
213226
testDelete(originalCond, expectedCond = FalseLiteral)
214227
testUpdate(originalCond, expectedCond = FalseLiteral)
228+
testMerge(originalCond, expectedCond = FalseLiteral)
215229
}
216230

217231
test("inability to replace null in non-boolean branches of If") {
@@ -226,6 +240,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
226240
testJoin(originalCond = condition, expectedCond = condition)
227241
testDelete(originalCond = condition, expectedCond = condition)
228242
testUpdate(originalCond = condition, expectedCond = condition)
243+
testMerge(originalCond = condition, expectedCond = condition)
229244
}
230245

231246
test("inability to replace null in non-boolean values of CaseWhen") {
@@ -395,6 +410,21 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
395410
test((rel, expr) => UpdateTable(rel, Seq.empty, Some(expr)), originalCond, expectedCond)
396411
}
397412

413+
private def testMerge(originalCond: Expression, expectedCond: Expression): Unit = {
414+
val func = (rel: LogicalPlan, expr: Expression) => {
415+
val assignments = Seq(
416+
Assignment('i, 'i),
417+
Assignment('b, 'b),
418+
Assignment('a, 'a),
419+
Assignment('m, 'm)
420+
)
421+
val matchedActions = UpdateAction(Some(expr), assignments) :: DeleteAction(Some(expr)) :: Nil
422+
val notMatchedActions = InsertAction(Some(expr), assignments) :: Nil
423+
MergeIntoTable(rel, rel, mergeCondition = expr, matchedActions, notMatchedActions)
424+
}
425+
test(func, originalCond, expectedCond)
426+
}
427+
398428
private def testHigherOrderFunc(
399429
argument: Expression,
400430
createExpr: (Expression, Expression) => Expression,

0 commit comments

Comments
 (0)