Skip to content

Commit 1beb40c

Browse files
Committed with message: "address comments"
1 parent 65fca4f commit 1beb40c

File tree

2 files changed

+9
-14
lines changed

2 files changed

+9
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,18 +114,20 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
114114
// Deduplicate conflicting attributes if any.
115115
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
116116
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
117+
// Deduplicate conflicting attributes if any.
117118
val newSub = dedupSubqueryOnSelfJoin(values, sub)
118119
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
119120
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
120-
// Deduplicate conflicting attributes if any.
121-
dedupJoin(Join(outerPlan, newSub, LeftSemi, joinCond))
121+
Join(outerPlan, newSub, LeftSemi, joinCond)
122122
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
123123
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
124124
// Construct the condition. A NULL in one of the conditions is regarded as a positive
125125
// result; such a row will be filtered out by the Anti-Join operator.
126126

127127
// Note that will almost certainly be planned as a Broadcast Nested Loop join.
128128
// Use EXISTS if performance matters to you.
129+
130+
// Deduplicate conflicting attributes if any.
129131
val newSub = dedupSubqueryOnSelfJoin(values, sub)
130132
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
131133
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
@@ -142,8 +144,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
142144
// will have the final conditions in the LEFT ANTI as
143145
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
144146
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
145-
// Deduplicate conflicting attributes if any.
146-
dedupJoin(Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond)))
147+
Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond))
147148
case (p, predicate) =>
148149
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
149150
Project(p.output, Filter(newCond.get, inputPlan))
@@ -170,11 +171,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
170171
exists
171172
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
172173
val exists = AttributeReference("exists", BooleanType, nullable = false)()
174+
// Deduplicate conflicting attributes if any.
173175
val newSub = dedupSubqueryOnSelfJoin(values, sub)
174176
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
175177
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
176-
// Deduplicate conflicting attributes if any.
177-
newPlan = dedupJoin(Join(newPlan, newSub, ExistenceJoin(exists), newConditions))
178+
newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions)
178179
exists
179180
}
180181
}

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,14 +1284,8 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
12841284

12851285
test("SPARK-26078: deduplicate fake self joins for IN subqueries") {
12861286
withTempView("a", "b") {
1287-
def genTestViewWithName(name: String): Unit = {
1288-
val df = spark.createDataFrame(
1289-
spark.sparkContext.parallelize(Seq(Row("a", 2), Row("b", 1))),
1290-
StructType(Seq(StructField("id", StringType), StructField("num", IntegerType))))
1291-
df.createOrReplaceTempView(name)
1292-
}
1293-
genTestViewWithName("a")
1294-
genTestViewWithName("b")
1287+
Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a")
1288+
Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b")
12951289

12961290
val df1 = spark.sql(
12971291
"""

0 commit comments

Comments (0)