@@ -74,14 +74,26 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
   * additional constraint of the form `b = 5`
   */
  private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
    // Collect aliases from expressions, so we can avoid producing a non-converging set of
    // constraints for recursive functions.
    //
    // Don't apply the transform on a constraint if the attribute used as the replacement is an
    // alias, because then both `QueryPlan.inferAdditionalConstraints` and
    // `UnaryNode.getAliasedConstraints` apply and may produce a non-converging set of
    // constraints.
    // For more details, refer to https://issues.apache.org/jira/browse/SPARK-17733
Member:
This comment doesn't seem to give a lot of context about the underlying issue. How about we just add a top-level comment for this method summarizing the issue and remove this? Perhaps, something along the lines of the following:

  /**
   * Infers an additional set of constraints from a given set of equality constraints.
   * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
   * additional constraint of the form `b = 5`.
   *
   * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
   * as they are often useless and can lead to a non-converging set of constraints.
   */
  private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression]
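
To make the suggested behavior concrete, here is a toy sketch of the basic inference (an editor's illustration over a hypothetical mini-AST, not Catalyst code): from (`a = 5`, `a = b`), rewriting the remaining constraints through the equality yields `b = 5`.

  // Toy constraint type; `EqAttr` is an attribute equality, `EqLit` a literal binding.
  sealed trait Constraint
  case class EqAttr(l: String, r: String) extends Constraint     // a = b
  case class EqLit(attr: String, value: Int) extends Constraint  // a = 5

  def inferAdditional(constraints: Set[Constraint]): Set[Constraint] =
    constraints.flatMap {
      // For every attribute equality, rewrite the remaining literal constraints.
      case eq @ EqAttr(l, r) =>
        (constraints - eq).collect {
          case EqLit(`l`, v) => EqLit(r, v): Constraint
          case EqLit(`r`, v) => EqLit(l, v): Constraint
        }
      case _ => Set.empty[Constraint]
    }

  // inferAdditional(Set(EqAttr("a", "b"), EqLit("a", 5))) == Set(EqLit("b", 5))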

    val aliasSet = AttributeSet((expressions ++ children.flatMap(_.expressions)).collect {
      case a: Alias => a.toAttribute
    })

    var inferredConstraints = Set.empty[Expression]
    constraints.foreach {
      case eq @ EqualTo(l: Attribute, r: Attribute) =>
        inferredConstraints ++= (constraints - eq).map(_ transform {
-         case a: Attribute if a.semanticEquals(l) => r
+         case a: Attribute if a.semanticEquals(l) && !aliasSet.contains(r) => r
@sameeragarwal (Member), Oct 18, 2016:
@jiangxb1987 isn't this a fairly restrictive way to solve this problem? You are essentially not inferring any additional constraints from those that contain aliases. For e.g., if we have a subquery SELECT a AS a1, b AS b1 WHERE a1 = 1 AND a1 = b1, this change would never allow us to infer a filter/constraint on b = 1. Can we identify and just disallow recursive constraints?

@jiangxb1987 (Contributor, Author), Oct 18, 2016:
Thank you a lot for your advice! Perhaps we should generate the following two sets:

  1. The set of Alias expressions that have references to other expressions, for instance, Alias(f(b, c), "a") or Alias(a, "a1");
  2. Sets of equivalence classes built from the EqualTo operators in constraints, e.g., when we have a = b, c = b, and e = f, the sets would be ((a, b, c), (e, f)).

Here, for any expression used to infer new constraints, we should check that either it's not in our AliasSet, or its references don't contain any expressions in the corresponding equivalence class.

I'll update this check rule ASAP. Thank you for helping!
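
An editor's sketch of step 2 in the comment above, using plain strings in place of Catalyst attributes (not the PR's eventual code): equivalence classes can be built by folding over the equality pairs and merging any classes they touch.

  def equivalenceClasses(equalities: Seq[(String, String)]): Set[Set[String]] =
    equalities.foldLeft(Set.empty[Set[String]]) { case (classes, (l, r)) =>
      // Merge every existing class that contains either side of the equality.
      val (touching, rest) = classes.partition(c => c.contains(l) || c.contains(r))
      rest + (touching.flatten + l + r)
    }

  // equivalenceClasses(Seq("a" -> "b", "c" -> "b", "e" -> "f"))
  //   == Set(Set("a", "b", "c"), Set("e", "f"))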

        })
        inferredConstraints ++= (constraints - eq).map(_ transform {
-         case a: Attribute if a.semanticEquals(r) => l
+         case a: Attribute if a.semanticEquals(r) && !aliasSet.contains(l) => l
        })
      case _ => // No inference
    }
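
For context on why SPARK-17733 diverges, a minimal sketch with a toy expression type (editor's illustration, not Catalyst code): once an alias such as `int_col = Coalesce(a, b)` takes part in an equality `a = int_col`, substituting the aliased expression for `a` grows the constraint on every fixed-point pass.

  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Coalesce(children: Seq[Expr]) extends Expr

  // Replace every occurrence of `target` with `replacement`.
  def substitute(e: Expr, target: Attr, replacement: Expr): Expr = e match {
    case attr: Attr if attr == target => replacement
    case Coalesce(cs) => Coalesce(cs.map(substitute(_, target, replacement)))
    case other => other
  }

  val a = Attr("a")
  val intCol = Coalesce(Seq(Attr("a"), Attr("b")))
  // Coalesce(a, b) -> Coalesce(Coalesce(a, b), b) -> Coalesce(Coalesce(Coalesce(a, b), b), b) ...
  val steps = Iterator.iterate(intCol: Expr)(substitute(_, a, intCol)).take(3).toList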
@@ -27,9 +27,12 @@ import org.apache.spark.sql.catalyst.rules._
class InferFiltersFromConstraintsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
-   val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) ::
-     Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) ::
-     Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil
+   val batches =
@jiangxb1987 (Contributor, Author):
Previous batches will not apply InferFiltersFromConstraints after PushPredicateThroughJoin.

Batch("InferAndPushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushDownPredicate,
InferFiltersFromConstraints,
CombineFilters) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
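
For readers unfamiliar with RuleExecutor: FixedPoint(100) means the batch's rules are re-applied in order until the plan stops changing or 100 iterations are reached, which is why a non-converging constraint set makes the optimizer loop until it hits the limit. A minimal sketch of that contract (editor's illustration, not Spark's implementation):

  def runToFixedPoint[Plan](plan: Plan, rules: Seq[Plan => Plan], maxIterations: Int): Plan = {
    var current = plan
    var iteration = 0
    var changed = true
    while (changed && iteration < maxIterations) {
      // One pass applies every rule in order; stop once a full pass changes nothing.
      val next = rules.foldLeft(current)((p, rule) => rule(p))
      changed = next != current
      current = next
      iteration += 1
    }
    current
  }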
@@ -120,4 +123,64 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
    val optimized = Optimize.execute(originalQuery)
    comparePlans(optimized, correctAnswer)
  }

test("inner join with alias: alias contains multiple attributes") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)

val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val currectAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b)))
Member:
nit: typo correctAnswer

&&'a === Coalesce(Seq('a, 'b)))
Member:
nit: 2 spaces

      .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
      .join(t2.where(IsNotNull('a)), Inner,
        Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
      .analyze
    val optimized = Optimize.execute(originalQuery)
    comparePlans(optimized, currectAnswer)
  }

test("inner join with alias: alias contains single attributes") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)

val originalQuery = t1.select('a, 'b.as('d)).as("t")
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val currectAnswer = t1.where(IsNotNull('a) && IsNotNull('b)
Member:
nit: typo

&& 'a <=> 'a && 'b <=> 'b &&'a === 'b)
Member:
nit: 2 spaces

      .select('a, 'b.as('d)).as("t")
      .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
        Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
      .analyze
    val optimized = Optimize.execute(originalQuery)
    comparePlans(optimized, currectAnswer)
  }

test("inner join with alias: don't generate constraints for recursive functions") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)

val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2, Inner,
Some("t.a".attr === "t2.a".attr
&& "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
&& 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
&& 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
Some("t.a".attr === "t2.a".attr
&& "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}
}
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._

/**
@@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
* ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
* etc., will all now be equivalent.
* - Sample the seed will replaced by 0L.
* - Join conditions will be resorted by hashCode.
*/
  private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
    plan transform {
      case filter @ Filter(condition: Expression, child: LogicalPlan) =>
-       Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
+       Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
+         .reduce(And), child)
      case sample: Sample =>
        sample.copy(seed = 0L)(true)
      case join @ Join(left, right, joinType, condition) if condition.isDefined =>
        val newCondition =
          splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
            .reduce(And)
        Join(left, right, joinType, Some(newCondition))
    }
  }

  /**
   * Rewrite [[EqualTo]] and [[EqualNullSafe]] operators into a deterministic order, so that the
   * following cases become equivalent:
   * 1. (a = b), (b = a);
   * 2. (a <=> b), (b <=> a).
   */
  private def rewriteEqual(condition: Expression): Expression = condition match {
    case eq @ EqualTo(l: Expression, r: Expression) =>
      Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
    case eq @ EqualNullSafe(l: Expression, r: Expression) =>
      Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
    case _ => condition // Don't reorder.
  }
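
The idea behind rewriteEqual, shown on plain pairs (an editor's sketch, not the test-harness code): sorting both sides of a symmetric comparison by a stable key makes `a = b` and `b = a` normalize to the same value.

  case class Eq(l: String, r: String)

  def normalizeEq(e: Eq): Eq = {
    // Order operands by a stable key so symmetric comparisons become identical.
    val Seq(x, y) = Seq(e.l, e.r).sortBy(_.hashCode())
    Eq(x, y)
  }

  assert(normalizeEq(Eq("a", "b")) == normalizeEq(Eq("b", "a")))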

  /** Fails the test if the two plans do not match */
  protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
    val normalized1 = normalizePlan(normalizeExprIds(plan1))
36 changes: 32 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql

import java.io.File
import java.math.MathContext
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp

import scala.concurrent.duration._

import org.scalatest.concurrent.Eventually._

import org.apache.spark.{AccumulatorSuite, SparkException}
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
@@ -2678,4 +2679,31 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
}

test("SPARK-17733 InferFiltersFromConstraints rule never terminates for query") {
Contributor:
can we construct a unit test rather than an end-to-end test here?

@jiangxb1987 (Contributor, Author):
Yes - perhaps we could add new test cases in InferFiltersFromConstraintsSuite.

Contributor:
+1

Member:
Given that you already have a unit test for cases like these, how about we remove this now? This test was randomly generated to catch issues like this and in its current form, it isn't very obvious how this query has anything to do with InferFiltersFromConstraints.

withTempView("tmpv") {
spark.range(10).toDF("a").createTempView("tmpv")

// Just ensure the following query will successfully execute complete.
val query =
"""
|SELECT
| *
|FROM (
| SELECT
| COALESCE(t1.a, t2.a) AS int_col,
| t1.a,
| t2.a AS b
| FROM tmpv t1
| CROSS JOIN tmpv t2
|) t1
|INNER JOIN tmpv t2
|ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
""".stripMargin

eventually(timeout(60 seconds)) {
assert(sql(query).count() > 0)
}
}
}
}