Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ import org.apache.spark.sql.catalyst.rules._
class InferFiltersFromConstraintsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) ::
Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) ::
Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil
val batches = Batch("InferFilters", FixedPoint(100), InferFiltersFromConstraints) ::
Batch("PredicatePushdown", FixedPoint(100),
PushPredicateThroughJoin,
PushDownPredicate) ::
Batch("CombineFilters", FixedPoint(100), CombineFilters) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
Expand Down Expand Up @@ -120,4 +122,28 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("don't generate constraints for recursive functions") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)
val t3 = testRelation.subquery('t3)

val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2, Inner,
Some("t.a".attr === "t2.a".attr
&& "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1.where(IsNotNull('a) && 'a === Coalesce(Seq('a, 'b))
&& IsNotNull('b) && 'b === Coalesce(Seq('a, 'b))
&& IsNotNull(Coalesce(Seq('a, 'b))) && 'a === 'b)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These predicates are infered from t.a = t2.a, t.d = t2.a, t.int_col = t2.a, which in line with our expectation.

.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2.where(IsNotNull('a)), Inner,
Some("t.a".attr === "t2.a".attr
&& "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._

/**
Expand Down Expand Up @@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
* ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
* etc., will all now be equivalent.
* - Sample the seed will replaced by 0L.
* - Join conditions will be resorted by hashCode.
*/
private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
.reduce(And), child)
case sample: Sample =>
sample.copy(seed = 0L)(true)
case join @ Join(left, right, joinType, condition) if condition.isDefined =>
val newCondition =
splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
.reduce(And)
Join(left, right, joinType, Some(newCondition))
}
}

/**
* Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be
* equivalent:
* 1. (a = b), (b = a);
* 2. (a <=> b), (b <=> a).
*/
private def rewriteEqual(condition: Expression): Expression = condition match {
case eq @ EqualTo(l: Expression, r: Expression) =>
Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
case eq @ EqualNullSafe(l: Expression, r: Expression) =>
Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
case _ => condition // Don't reorder.
}

/** Fails the test if the two plans do not match */
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
val normalized1 = normalizePlan(normalizeExprIds(plan1))
Expand Down