Skip to content

Commit 7d9e2b0

Browse files
committed
Add new test case.
1 parent ebba446 commit 7d9e2b0

2 files changed

Lines changed: 52 additions & 5 deletions

File tree

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ import org.apache.spark.sql.catalyst.rules._
2727
class InferFiltersFromConstraintsSuite extends PlanTest {
2828

2929
object Optimize extends RuleExecutor[LogicalPlan] {
30-
val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) ::
31-
Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) ::
32-
Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil
30+
val batches = Batch("InferFilters", FixedPoint(100), InferFiltersFromConstraints) ::
31+
Batch("PredicatePushdown", FixedPoint(100),
32+
PushPredicateThroughJoin,
33+
PushDownPredicate) ::
34+
Batch("CombineFilters", FixedPoint(100), CombineFilters) :: Nil
3335
}
3436

3537
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -120,4 +122,28 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
120122
val optimized = Optimize.execute(originalQuery)
121123
comparePlans(optimized, correctAnswer)
122124
}
125+
126+
test("don't generate constraints for recursive functions") {
127+
val t1 = testRelation.subquery('t1)
128+
val t2 = testRelation.subquery('t2)
129+
val t3 = testRelation.subquery('t3)
130+
131+
val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
132+
.join(t2, Inner,
133+
Some("t.a".attr === "t2.a".attr
134+
&& "t.d".attr === "t2.a".attr
135+
&& "t.int_col".attr === "t2.a".attr))
136+
.analyze
137+
val correctAnswer = t1.where(IsNotNull('a) && 'a === Coalesce(Seq('a, 'b))
138+
&& IsNotNull('b) && 'b === Coalesce(Seq('a, 'b))
139+
&& IsNotNull(Coalesce(Seq('a, 'b))) && 'a === 'b)
140+
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
141+
.join(t2.where(IsNotNull('a)), Inner,
142+
Some("t.a".attr === "t2.a".attr
143+
&& "t.d".attr === "t2.a".attr
144+
&& "t.int_col".attr === "t2.a".attr))
145+
.analyze
146+
val optimized = Optimize.execute(originalQuery)
147+
comparePlans(optimized, correctAnswer)
148+
}
123149
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
23-
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
23+
import org.apache.spark.sql.catalyst.plans.logical._
2424
import org.apache.spark.sql.catalyst.util._
2525

2626
/**
@@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
5656
* ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
5757
* etc., will all now be equivalent.
5858
 * - Sample: the seed will be replaced by 0L.
59+
 * - Join conditions will be re-sorted by hashCode.
5960
*/
6061
private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
6162
plan transform {
6263
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
63-
Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
64+
Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
65+
.reduce(And), child)
6466
case sample: Sample =>
6567
sample.copy(seed = 0L)(true)
68+
case join @ Join(left, right, joinType, condition) if condition.isDefined =>
69+
val newCondition =
70+
splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
71+
.reduce(And)
72+
Join(left, right, joinType, Some(newCondition))
6673
}
6774
}
6875

76+
/**
77+
 * Rewrite [[EqualTo]] and [[EqualNullSafe]] operators to normalize operand order. The following cases will be
78+
* equivalent:
79+
* 1. (a = b), (b = a);
80+
* 2. (a <=> b), (b <=> a).
81+
*/
82+
private def rewriteEqual(condition: Expression): Expression = condition match {
83+
case eq @ EqualTo(l: Expression, r: Expression) =>
84+
Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
85+
case eq @ EqualNullSafe(l: Expression, r: Expression) =>
86+
Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
87+
case _ => condition // Don't reorder.
88+
}
89+
6990
/** Fails the test if the two plans do not match */
7091
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
7192
val normalized1 = normalizePlan(normalizeExprIds(plan1))

0 commit comments

Comments
 (0)