Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
TransposeWindow,
NullPropagation,
ConstantPropagation,
FilterReduction,
FoldablePropagation,
OptimizeIn,
ConstantFolding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import scala.collection.mutable.{ArrayBuffer, Stack}
import scala.collection.mutable.{ArrayBuffer, Map, Stack}

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -55,99 +55,114 @@ object ConstantFolding extends Rule[LogicalPlan] {
}

/**
* Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding
* Substitutes [[Expression Expressions]] which can be statically evaluated with their corresponding
* value in conjunctive [[Expression Expressions]]
* eg.
* {{{
* SELECT * FROM table WHERE i = 5 AND j = i + 3
* ==> SELECT * FROM table WHERE i = 5 AND j = 8
* SELECT * FROM table WHERE i = 5 AND j = i + 3 => ... WHERE i = 5 AND j = 8
* SELECT * FROM table WHERE abs(i) = 5 AND j <= abs(i) + 3 => ... WHERE abs(i) = 5 AND j <= 8
* }}}
*
* Approach used:
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
* in the AND node.
* - Populate a mapping of expression => constant value by looking at all the deterministic equals
* predicates
* - Using this mapping, replace occurrence of the expressions with the corresponding constant
* values in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since new PR is created, we had better remove this change from this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, reverted.
Please note that some of the new UTs will require the enhanced constant propagation to work as expected.

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f: Filter =>
val (newCondition, _) = traverse(f.condition, replaceChildren = true)
if (newCondition.isDefined) {
f.copy(condition = newCondition.get)
} else {
val (newCondition, _) = traverse(f.condition)
if (newCondition fastEquals f.condition) {
f
} else {
f.copy(condition = newCondition)
}
}

type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]

/**
* Traverse a condition as a tree and replace attributes with constant values.
* Traverse a condition as a tree and replace expressions with constant values.
* - On matching [[And]], recursively traverse each children and get propagated mappings.
* If the current node is not child of another [[And]], replace all occurrences of the
* attributes with the corresponding constant values.
* expressions with the corresponding constant values.
* - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
* of attribute => constant.
* of expression => constant.
* - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping.
* - Otherwise, stop traversal and propagate empty mapping.
* @param condition condition to be traversed
* @param replaceChildren whether to replace attributes with constant values in children
* @param expression expression to be traversed
* @return A tuple including:
* 1. Option[Expression]: optional changed condition after traversal
* 2. EqualityPredicates: propagated mapping of attribute => constant
* 2. Map[Expression, Literal]: propagated mapping of expression => constant
*/
private def traverse(condition: Expression, replaceChildren: Boolean)
: (Option[Expression], EqualityPredicates) =
condition match {
case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e)))
case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e)))
case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
(None, Seq(((left, right), e)))
case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
(None, Seq(((right, left), e)))
case a: And =>
val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false)
val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false)
val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
private def traverse(expression: Expression): (Expression, Map[Expression, Literal]) =
expression match {
case e @ EqualTo(left, right: Literal) if !left.foldable && left.deterministic =>
(e, Map(left.canonicalized -> right))
case e @ EqualTo(left: Literal, right) if !right.foldable && right.deterministic =>
(e, Map(right.canonicalized -> left))
case e @ EqualNullSafe(left, right: Literal) if !left.foldable && left.deterministic =>
(e, Map(left.canonicalized -> right))
case e @ EqualNullSafe(left: Literal, right) if !right.foldable && right.deterministic =>
(e, Map(right.canonicalized -> left))
case a @ And(left, right) =>
val (newLeft, equalityPredicatesLeft) = traverse(left)
val replacedRight = replaceConstants(right, equalityPredicatesLeft)
val (replacedNewRight, equalityPredicatesRight) = traverse(replacedRight)
val replacedNewLeft = replaceConstants(newLeft, equalityPredicatesRight)
val newAnd = if ((replacedNewLeft fastEquals left) && (replacedNewRight fastEquals right)) {
a
} else {
if (newLeft.isDefined || newRight.isDefined) {
Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
} else {
None
}
And(replacedNewLeft, replacedNewRight)
}
(newSelf, equalityPredicates)
case o: Or =>
(newAnd, equalityPredicatesLeft ++= equalityPredicatesRight)
case o @ Or(left, right) =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newLeft, _) = traverse(o.left, replaceChildren = true)
val (newRight, _) = traverse(o.right, replaceChildren = true)
val newSelf = if (newLeft.isDefined || newRight.isDefined) {
Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right))))
val (newLeft, _) = traverse(left)
val (newRight, _) = traverse(right)
val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) {
o
} else {
None
Or(newLeft, newRight)
}
(newSelf, Seq.empty)
case n: Not =>

(newOr, Map.empty)
case n @ Not(child) =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newChild, _) = traverse(n.child, replaceChildren = true)
(newChild.map(Not), Seq.empty)
case _ => (None, Seq.empty)
val (newChild, _) = traverse(child)
val newNot = if (newChild fastEquals child) {
n
} else {
Not(newChild)
}
(newNot, Map.empty)
case _ => (expression, Map.empty)
}

private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
: Expression = {
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
val predicates = equalityPredicates.map(_._2).toSet
def replaceConstants0(expression: Expression) = expression transform {
case a: AttributeReference => constantsMap.getOrElse(a, a)
}
condition transform {
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
private def replaceConstants(
expression: Expression,
constants: Map[Expression, Literal]) =
expression transformUp {
case e if constants.contains(e.canonicalized) => constants(e.canonicalized)
}
}

/**
* Substitutes expressions which can be statically reduced by constraints.
* eg.
* {{{
* SELECT * FROM table WHERE i <= 5 AND i = 5 => ... WHERE i = 5
* SELECT * FROM table WHERE i < j AND ... AND i > j => ... WHERE false
* }}}
*/
object FilterReduction extends Rule[LogicalPlan] with ConstraintHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f: Filter =>
val newCondition = normalizeAndReduceWithConstraints(f.condition)
if (newCondition fastEquals f.condition) {
f
} else {
f.copy(condition = newCondition)
}
}
}

Expand Down
Loading