@@ -74,20 +74,36 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* additional constraint of the form `b = 5`
*/
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
// Collect aliases from expressions to avoid producing a non-converging set of constraints
// for recursive deductions.
// For more details, infer https://issues.apache.org/jira/browse/SPARK-17733
Contributor: typo "infer" -> "refer" (to)?

val aliasMap = AttributeMap((expressions ++ children.flatMap(_.expressions)).collect {
Contributor: Why not using AttributeSet?

Contributor Author: Yes, AttributeSet is a better choice here.

Member: Since aliasMap is referenced at a number of places, let's just make this a private lazy val and move it outside of this method in QueryPlan.

case a: Alias => (a.toAttribute, a.child)
})

var inferredConstraints = Set.empty[Expression]
constraints.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
inferredConstraints ++= (constraints - eq).map(_ transform {
case a: Attribute if a.semanticEquals(l) => r
case a: Attribute if a.semanticEquals(l) && !isRecursiveDeduction(a, r, aliasMap) => r
})
inferredConstraints ++= (constraints - eq).map(_ transform {
case a: Attribute if a.semanticEquals(r) => l
case a: Attribute if a.semanticEquals(r) && !isRecursiveDeduction(l, a, aliasMap) => l
})
case _ => // No inference
}
inferredConstraints -- constraints
}
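Following up on the Member suggestion above about aliasMap: a minimal sketch, assuming it simply moves to a class-level member of QueryPlan with the same collect logic (not necessarily the exact shape that ends up being merged):

  // Collect aliases from this node's and its children's expressions once, so every
  // constraint-related helper can reuse the map instead of rebuilding it per call.
  private lazy val aliasMap: AttributeMap[Expression] = AttributeMap(
    (expressions ++ children.flatMap(_.expressions)).collect {
      case a: Alias => (a.toAttribute, a.child)
    })

With that in place, isRecursiveDeduction below could drop its aliasMap parameter and read the lazy val directly.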

private def isRecursiveDeduction(
left: Attribute,
right: Attribute,
aliasMap: AttributeMap[Expression]): Boolean = {
val leftExpression = aliasMap.getOrElse(left, left)
val rightExpression = aliasMap.getOrElse(right, right)
leftExpression.containsChild(rightExpression) || rightExpression.containsChild(leftExpression)
}
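To make the guard concrete, a small worked example using the attribute names from the new InferFiltersFromConstraintsSuite test further down (illustration only, not part of this diff):

  // The projection aliases Coalesce(a, b) as int_col, so aliasMap contains
  //   int_col -> Coalesce(a, b)
  // and the constraints end up equating int_col with that same a (via t.a = t2.a and
  // t.int_col = t2.a). For the pair (a, int_col):
  //   leftExpression  = a                 (a has no alias entry, falls back to itself)
  //   rightExpression = Coalesce(a, b)    (int_col resolves through aliasMap)
  // Coalesce(a, b) contains a as a child, so isRecursiveDeduction returns true, the
  // substitution is skipped, and the inferred constraint set stays finite.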

/**
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
* example, if this set contains the expression `a = 2` then that expression is guaranteed to
@@ -27,9 +27,11 @@ import org.apache.spark.sql.catalyst.rules._
class InferFiltersFromConstraintsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) ::
Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) ::
Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil
val batches = Batch("InferFilters", FixedPoint(100), InferFiltersFromConstraints) ::
Batch("PredicatePushdown", FixedPoint(100),
PushPredicateThroughJoin,
PushDownPredicate) ::
Batch("CombineFilters", FixedPoint(100), CombineFilters) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -120,4 +122,28 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("don't generate constraints for recursive functions") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)
val t3 = testRelation.subquery('t3)

val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2, Inner,
Some("t.a".attr === "t2.a".attr
&& "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1.where(IsNotNull('a) && 'a === Coalesce(Seq('a, 'b))
&& IsNotNull('b) && 'b === Coalesce(Seq('a, 'b))
&& IsNotNull(Coalesce(Seq('a, 'b))) && 'a === 'b)
Contributor Author: These predicates are inferred from t.a = t2.a, t.d = t2.a and t.int_col = t2.a, which is in line with our expectation.

.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2.where(IsNotNull('a)), Inner,
Some("t.a".attr === "t2.a".attr
&& "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}
}
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._

/**
@@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
* ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2))
* etc., will all now be equivalent.
* - Sample: the seed will be replaced by 0L.
* - Join conditions will be re-sorted by hashCode.
*/
private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
.reduce(And), child)
case sample: Sample =>
sample.copy(seed = 0L)(true)
case join @ Join(left, right, joinType, condition) if condition.isDefined =>
val newCondition =
splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
.reduce(And)
Join(left, right, joinType, Some(newCondition))
}
}

/**
* Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be
* equivalent:
* 1. (a = b), (b = a);
* 2. (a <=> b), (b <=> a).
*/
private def rewriteEqual(condition: Expression): Expression = condition match {
case eq @ EqualTo(l: Expression, r: Expression) =>
Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
case eq @ EqualNullSafe(l: Expression, r: Expression) =>
Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
case _ => condition // Don't reorder.
}
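A brief usage sketch of what the new normalization buys, written as it might appear inside a PlanTest subclass with the catalyst DSL imported (hypothetical relation, illustration only):

  val rel = LocalRelation('a.int, 'b.int)
  // The two filters write the EqualTo operands in opposite order; rewriteEqual sorts both
  // operand pairs by hashCode during normalizePlan, so the plans now compare as equal.
  comparePlans(rel.where('a === 'b).analyze, rel.where('b === 'a).analyze)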

/** Fails the test if the two plans do not match */
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
val normalized1 = normalizePlan(normalizeExprIds(plan1))
24 changes: 24 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2678,4 +2678,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
}

test("SPARK-17733 InferFiltersFromConstraints rule never terminates for query") {
Contributor: can we construct a unit test rather than an end-to-end test here?

Contributor Author: Yes - perhaps we could add new test cases in InferFiltersFromConstraintsSuite.

Contributor: +1

Member: Given that you already have a unit test for cases like these, how about we remove this now? This test was randomly generated to catch issues like this and in its current form, it isn't very obvious how this query has anything to do with InferFiltersFromConstraints.

withTempView("tmpv") {
spark.range(10).toDF("a").createTempView("tmpv")

// Just ensure the following query executes successfully to completion.
assert(sql(
"""
|SELECT
| *
|FROM (
| SELECT
| COALESCE(t1.a, t2.a) AS int_col,
| t1.a,
| t2.a AS b
| FROM tmpv t1
| CROSS JOIN tmpv t2
|) t1
|INNER JOIN tmpv t2
|ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
""".stripMargin).count() > 0
)
}
}
}