Skip to content

Commit 64e3320

Browse files
committed
Infer additional constraints from attribute equality
1 parent dbf2a7c commit 64e3320

2 files changed

Lines changed: 34 additions & 0 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
3232
*/
3333
protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
3434
constraints
35+
.union(inferAdditionalConstraints(constraints))
3536
.union(constructIsNotNullConstraints(constraints))
3637
.filter(constraint =>
3738
constraint.references.nonEmpty && constraint.references.subsetOf(outputSet))
@@ -61,6 +62,25 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
6162
}.foldLeft(Set.empty[Expression])(_ union _.toSet)
6263
}
6364

65+
/**
66+
* Infers an additional set of constraints from a given set of equality constraints.
67+
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
68+
* additional constraint of the form `b = 5`
69+
*/
70+
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
71+
constraints.map {
72+
case eq @ EqualTo(l: Attribute, r: Attribute) =>
73+
(constraints -- Set(eq)).map(_ transform {
74+
case a: Attribute if a.semanticEquals(l) => r
75+
}).union(
76+
(constraints -- Set(eq)).map(_ transform {
77+
case a: Attribute if a.semanticEquals(r) => l
78+
}))
79+
case _ =>
80+
Set.empty[Expression]
81+
}.foldLeft(Set.empty[Expression])(_ union _) -- constraints
82+
}
83+
6484
/**
6585
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
6686
* example, if this set contains the expression `a = 2` then that expression is guaranteed to

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
158158
tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
159159
tr1.resolveQuoted("a", caseInsensitiveResolution).get ===
160160
tr2.resolveQuoted("a", caseInsensitiveResolution).get,
161+
tr2.resolveQuoted("a", caseInsensitiveResolution).get > 10,
161162
IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
162163
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
163164
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))))
@@ -203,4 +204,17 @@ class ConstraintPropagationSuite extends SparkFunSuite {
203204
.join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr))
204205
.analyze.constraints.isEmpty)
205206
}
207+
208+
test("infer additional constraints in filters") {
209+
val tr = LocalRelation('a.int, 'b.int, 'c.int)
210+
211+
verifyConstraints(tr
212+
.where('a.attr > 10 && 'a.attr === 'b.attr)
213+
.analyze.constraints,
214+
ExpressionSet(Seq(resolveColumn(tr, "a") > 10,
215+
resolveColumn(tr, "b") > 10,
216+
resolveColumn(tr, "a") === resolveColumn(tr, "b"),
217+
IsNotNull(resolveColumn(tr, "a")),
218+
IsNotNull(resolveColumn(tr, "b")))))
219+
}
206220
}

0 commit comments

Comments
 (0)