Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -782,26 +782,65 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
(leftEvaluateCondition, rightEvaluateCondition, commonCondition)
}

/**
* Infer the condition using equalTo predicate transitivity
*/
private def inferConditions (
condition: Seq[Expression],
attrSet: AttributeSet,
commonCondition: Seq[Expression]) = {
val attrMap = AttributeMap(commonCondition.collect {
case EqualTo(l: AttributeReference, r: AttributeReference) if attrSet.contains(l) =>
(l.toAttribute, r)
case EqualTo(l: AttributeReference, r: AttributeReference) =>
(r.toAttribute, l)
})
condition.collect {
case EqualTo(ar: AttributeReference, l: Literal) if attrMap.contains(ar) =>
EqualTo(attrMap(ar), l)
case GreaterThan(ar: AttributeReference, l: Literal) if attrMap.contains(ar) =>
GreaterThan(attrMap(ar), l)
case GreaterThanOrEqual(ar: AttributeReference, l: Literal) if attrMap.contains(ar) =>
GreaterThanOrEqual(attrMap(ar), l)
case LessThan(ar: AttributeReference, l: Literal) if attrMap.contains(ar) =>
LessThan(attrMap(ar), l)
case LessThanOrEqual(ar: AttributeReference, l: Literal) if attrMap.contains(ar) =>
LessThanOrEqual(attrMap(ar), l)
case In(ar: AttributeReference, l) if attrMap.contains(ar) =>
In(attrMap(ar), l)
case IsNull(ar: AttributeReference) if attrMap.contains(ar) =>
IsNull(attrMap(ar))
case IsNotNull(ar: AttributeReference) if attrMap.contains(ar) =>
IsNotNull(attrMap(ar))
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// push the where condition down into join filter
case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
split(splitConjunctivePredicates(filterCondition), left, right)
val leftInferredFilterConditions =
(leftFilterConditions ++
inferConditions(rightFilterConditions, right.outputSet, commonFilterCondition)).distinct
val rightInferredFilterConditions =
(rightFilterConditions ++
inferConditions(leftFilterConditions, left.outputSet, commonFilterCondition)).distinct

joinType match {
case Inner =>
// push down the single side `where` condition into respective sides
val newLeft = leftFilterConditions.
val newLeft = leftInferredFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = rightFilterConditions.
val newRight = rightInferredFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And)

Join(newLeft, newRight, Inner, newJoinCond)
case RightOuter =>
// push down the right side only `where` condition
val newLeft = left
val newRight = rightFilterConditions.
val newRight = rightInferredFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = joinCondition
val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond)
Expand All @@ -810,7 +849,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case _ @ (LeftOuter | LeftSemi) =>
// push down the left side only `where` condition
val newLeft = leftFilterConditions.
val newLeft = leftInferredFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = right
val newJoinCond = joinCondition
Expand All @@ -825,20 +864,26 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
case f @ Join(left, right, joinType, joinCondition) =>
val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
val leftInferredJoinConditions =
(leftJoinConditions ++
inferConditions(rightJoinConditions, right.outputSet, commonJoinCondition)).distinct
val rightInferredJoinConditions =
(rightJoinConditions ++
inferConditions(leftJoinConditions, left.outputSet, commonJoinCondition)).distinct

joinType match {
case _ @ (Inner | LeftSemi) =>
// push down the single side only join filter for both sides sub queries
val newLeft = leftJoinConditions.
val newLeft = leftInferredJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = rightJoinConditions.
val newRight = rightInferredJoinConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = commonJoinCondition.reduceLeftOption(And)

Join(newLeft, newRight, joinType, newJoinCond)
case RightOuter =>
// push down the left side only join filter for left side sub query
val newLeft = leftJoinConditions.
val newLeft = leftInferredJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = right
val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
Expand All @@ -847,7 +892,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
case LeftOuter =>
// push down the right side only join filter for right sub query
val newLeft = left
val newRight = rightJoinConditions.
val newRight = rightInferredJoinConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class FilterPushdownSuite extends PlanTest {
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b >= 1)
val left = testRelation.where('b >= 1 && 'a >= 2)
val right = testRelation1.where('d >= 2)
val correctAnswer =
left.join(right, LeftSemi, Option("a".attr === "d".attr)).analyze
Expand Down Expand Up @@ -537,8 +537,8 @@ class FilterPushdownSuite extends PlanTest {

val optimized = Optimize.execute(originalQuery.analyze)
val lleft = testRelation.where('a >= 3).subquery('z)
val left = testRelation.where('a === 1).subquery('x)
val right = testRelation.subquery('y)
val left = testRelation.where('a === 1 && 'b >= 3).subquery('x)
val right = testRelation.where('b >= 3).subquery('y)
val correctAnswer =
lleft.join(
left.join(right, condition = Some("x.b".attr === "y.b".attr)),
Expand Down Expand Up @@ -750,4 +750,56 @@ class FilterPushdownSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

test("joins: push to both sides after predicate transitivity") {
val x = testRelation.subquery('x)
val y = testRelation1.subquery('y)

val originalQuery = {
x.join(y)
.where("x.a".attr === 1 && "y.d".attr === "x.a".attr && 'd.isNotNull)
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('a === 1 && 'a.isNotNull)
val right = testRelation1.where('d.isNotNull && 'd === 1)
val correctAnswer =
left.join(right, condition = Some("d".attr === "a".attr)).analyze

comparePlans(optimized, correctAnswer)
}

test("joins: push down left outer join after predicate transitivity ") {
val x = testRelation.subquery('x)
val y = testRelation1.subquery('y)

val originalQuery = {
x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr && "y.d".attr >= 2))
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation
val right = testRelation1.where('d >= 2)
val correctAnswer =
left.join(right, LeftOuter, Option("a".attr === "d".attr)).analyze

comparePlans(optimized, correctAnswer)
}

test("joins: push down right outer join after predicate transitivity ") {
val x = testRelation.subquery('x)
val y = testRelation1.subquery('y)

val originalQuery = {
x.join(y, RightOuter, Option("x.a".attr === "y.d".attr && "y.d".attr >= 2))
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('a >= 2)
val right = testRelation1
val correctAnswer =
left.join(right, RightOuter, Option("d".attr >= 2 && "a".attr === "d".attr)).analyze

comparePlans(optimized, correctAnswer)
}
}