Skip to content

Commit 2b7bfba

Browse files
committed
[SPARK-26078][SQL][BACKPORT-2.4] Dedup self-join attributes on IN subqueries
When there is a self-join as a result of an IN subquery, the join condition may be invalid, resulting in trivially true predicates and wrong results. This PR deduplicates the subquery output in order to avoid the issue. Added UT. Closes apache#23057 from mgaido91/SPARK-26078. Authored-by: Marco Gaido <marcogaido91@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent c0f4082 commit 2b7bfba

2 files changed

Lines changed: 97 additions & 38 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
1919

2020
import scala.collection.mutable.ArrayBuffer
2121

22-
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.AnalysisException
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
2525
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -43,31 +43,53 @@ import org.apache.spark.sql.types._
4343
* condition.
4444
*/
4545
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
46-
private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {
46+
47+
private def buildJoin(
48+
outerPlan: LogicalPlan,
49+
subplan: LogicalPlan,
50+
joinType: JoinType,
51+
condition: Option[Expression]): Join = {
52+
// Deduplicate conflicting attributes if any.
53+
val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition)
54+
Join(outerPlan, dedupSubplan, joinType, condition)
55+
}
56+
57+
private def dedupSubqueryOnSelfJoin(
58+
outerPlan: LogicalPlan,
59+
subplan: LogicalPlan,
60+
valuesOpt: Option[Seq[Expression]],
61+
condition: Option[Expression] = None): LogicalPlan = {
4762
// SPARK-21835: It is possible that the two sides of the join have conflicting attributes,
4863
// the produced join then becomes unresolved and breaks structural integrity. We should
49-
// de-duplicate conflicting attributes. We don't use transformation here because we only
50-
// care about the most top join converted from correlated predicate subquery.
51-
case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) =>
52-
val duplicates = right.outputSet.intersect(left.outputSet)
53-
if (duplicates.nonEmpty) {
54-
val aliasMap = AttributeMap(duplicates.map { dup =>
55-
dup -> Alias(dup, dup.toString)()
56-
}.toSeq)
57-
val aliasedExpressions = right.output.map { ref =>
58-
aliasMap.getOrElse(ref, ref)
59-
}
60-
val newRight = Project(aliasedExpressions, right)
61-
val newJoinCond = joinCond.map { condExpr =>
62-
condExpr transform {
63-
case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
64+
// de-duplicate conflicting attributes.
65+
// SPARK-26078: it may also happen that the subquery has conflicting attributes with the outer
66+
// values. In this case, the resulting join would contain trivially true conditions (e.g.
67+
// id#3 = id#3) which cannot be de-duplicated after. In this method, if there are conflicting
68+
// attributes in the join condition, the subquery's conflicting attributes are changed using
69+
// a projection which aliases them and resolves the problem.
70+
val outerReferences = valuesOpt.map(values =>
71+
AttributeSet.fromAttributeSets(values.map(_.references))).getOrElse(AttributeSet.empty)
72+
val outerRefs = outerPlan.outputSet ++ outerReferences
73+
val duplicates = outerRefs.intersect(subplan.outputSet)
74+
if (duplicates.nonEmpty) {
75+
condition.foreach { e =>
76+
val conflictingAttrs = e.references.intersect(duplicates)
77+
if (conflictingAttrs.nonEmpty) {
78+
throw new AnalysisException("Found conflicting attributes " +
79+
s"${conflictingAttrs.mkString(",")} in the condition joining outer plan:\n " +
80+
s"$outerPlan\nand subplan:\n $subplan")
6481
}
65-
}
66-
Join(left, newRight, joinType, newJoinCond)
67-
} else {
68-
j
6982
}
70-
case _ => joinPlan
83+
val rewrites = AttributeMap(duplicates.map { dup =>
84+
dup -> Alias(dup, dup.toString)()
85+
}.toSeq)
86+
val aliasedExpressions = subplan.output.map { ref =>
87+
rewrites.getOrElse(ref, ref)
88+
}
89+
Project(aliasedExpressions, subplan)
90+
} else {
91+
subplan
92+
}
7193
}
7294

7395
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -85,25 +107,27 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
85107
withSubquery.foldLeft(newFilter) {
86108
case (p, Exists(sub, conditions, _)) =>
87109
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
88-
// Deduplicate conflicting attributes if any.
89-
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
110+
buildJoin(outerPlan, sub, LeftSemi, joinCond)
90111
case (p, Not(Exists(sub, conditions, _))) =>
91112
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
92-
// Deduplicate conflicting attributes if any.
93-
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
113+
buildJoin(outerPlan, sub, LeftAnti, joinCond)
94114
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
95-
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
96-
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
97115
// Deduplicate conflicting attributes if any.
98-
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
116+
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
117+
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
118+
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
119+
Join(outerPlan, newSub, LeftSemi, joinCond)
99120
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
100121
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
101122
// Construct the condition. A NULL in one of the conditions is regarded as a positive
102123
// result; such a row will be filtered out by the Anti-Join operator.
103124

104125
// Note that this will almost certainly be planned as a Broadcast Nested Loop join.
105126
// Use EXISTS if performance matters to you.
106-
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
127+
128+
// Deduplicate conflicting attributes if any.
129+
val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
130+
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
107131
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
108132
// Expand the NOT IN expression with the NULL-aware semantic
109133
// to its full form. That is from:
@@ -118,8 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
118142
// will have the final conditions in the LEFT ANTI as
119143
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
120144
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
121-
// Deduplicate conflicting attributes if any.
122-
dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
145+
Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond))
123146
case (p, predicate) =>
124147
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
125148
Project(p.output, Filter(newCond.get, inputPlan))
@@ -140,16 +163,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
140163
e transformUp {
141164
case Exists(sub, conditions, _) =>
142165
val exists = AttributeReference("exists", BooleanType, nullable = false)()
143-
// Deduplicate conflicting attributes if any.
144-
newPlan = dedupJoin(
145-
Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
166+
newPlan =
167+
buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
146168
exists
147169
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
148170
val exists = AttributeReference("exists", BooleanType, nullable = false)()
149-
val inConditions = values.zip(sub.output).map(EqualTo.tupled)
150-
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
151171
// Deduplicate conflicting attributes if any.
152-
newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions))
172+
val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
173+
val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
174+
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
175+
newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions)
153176
exists
154177
}
155178
}

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,4 +1268,40 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
12681268
assert(getNumSortsInQuery(query5) == 1)
12691269
}
12701270
}
1271+
1272+
test("SPARK-26078: deduplicate fake self joins for IN subqueries") {
1273+
withTempView("a", "b") {
1274+
Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a")
1275+
Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b")
1276+
1277+
val df1 = spark.sql(
1278+
"""
1279+
|SELECT id,num,source FROM (
1280+
| SELECT id, num, 'a' as source FROM a
1281+
| UNION ALL
1282+
| SELECT id, num, 'b' as source FROM b
1283+
|) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
1284+
""".stripMargin)
1285+
checkAnswer(df1, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
1286+
val df2 = spark.sql(
1287+
"""
1288+
|SELECT id,num,source FROM (
1289+
| SELECT id, num, 'a' as source FROM a
1290+
| UNION ALL
1291+
| SELECT id, num, 'b' as source FROM b
1292+
|) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2)
1293+
""".stripMargin)
1294+
checkAnswer(df2, Seq(Row("b", 1, "a"), Row("b", 1, "b")))
1295+
val df3 = spark.sql(
1296+
"""
1297+
|SELECT id,num,source FROM (
1298+
| SELECT id, num, 'a' as source FROM a
1299+
| UNION ALL
1300+
| SELECT id, num, 'b' as source FROM b
1301+
|) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR
1302+
|c.id IN (SELECT id FROM b WHERE num = 3)
1303+
""".stripMargin)
1304+
checkAnswer(df3, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
1305+
}
1306+
}
12711307
}

0 commit comments

Comments
 (0)