Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
TransposeWindow,
NullPropagation,
ConstantPropagation,
FilterReduction,
FoldablePropagation,
OptimizeIn,
ConstantFolding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,99 +55,108 @@ object ConstantFolding extends Rule[LogicalPlan] {
}

/**
* Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding
* Substitutes [[Expression Expressions]] which can be statically evaluated with their corresponding
* value in conjunctive [[Expression Expressions]]
* eg.
* {{{
* SELECT * FROM table WHERE i = 5 AND j = i + 3
* ==> SELECT * FROM table WHERE i = 5 AND j = 8
* SELECT * FROM table WHERE i = 5 AND j = i + 3 => ... WHERE i = 5 AND j = 8
* SELECT * FROM table WHERE abs(i) = 5 AND j <= abs(i) + 3 => ... WHERE abs(i) = 5 AND j <= 8
* }}}
*
* Approach used:
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
* in the AND node.
* - Populate a mapping of expression => constant value by looking at all the deterministic equals
* predicates
* - Using this mapping, replace occurrence of the expressions with the corresponding constant
* values in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since a new PR has been created, we should remove this change from this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, reverted.
Please note that some of the new UTs will require the enhanced constant propagation to work as expected.

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f: Filter =>
val (newCondition, _) = traverse(f.condition, replaceChildren = true)
if (newCondition.isDefined) {
f.copy(condition = newCondition.get)
} else {
val (newCondition, _) = traverse(f.condition)
if (newCondition fastEquals f.condition) {
f
} else {
f.copy(condition = newCondition)
}
}

type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]

/**
* Traverse a condition as a tree and replace attributes with constant values.
* Traverse a condition as a tree and replace expressions with constant values.
* - On matching [[And]], recursively traverse each children and get propagated mappings.
* If the current node is not child of another [[And]], replace all occurrences of the
* attributes with the corresponding constant values.
* expressions with the corresponding constant values.
* - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
* of attribute => constant.
* of expression => constant.
* - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping.
* - Otherwise, stop traversal and propagate empty mapping.
* @param condition condition to be traversed
* @param replaceChildren whether to replace attributes with constant values in children
* @param expression expression to be traversed
* @return A tuple including:
* 1. Option[Expression]: optional changed condition after traversal
* 2. EqualityPredicates: propagated mapping of attribute => constant
* 2. Seq[(Expression, Literal)]: propagated mapping of expression => constant
*/
private def traverse(condition: Expression, replaceChildren: Boolean)
: (Option[Expression], EqualityPredicates) =
condition match {
case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e)))
case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e)))
case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
(None, Seq(((left, right), e)))
case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
(None, Seq(((right, left), e)))
case a: And =>
val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false)
val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false)
val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
private def traverse(expression: Expression): (Expression, Seq[(Expression, Literal)]) =
expression match {
case e @ EqualTo(left, right: Literal) if e.deterministic => (e, Seq((left, right)))
case e @ EqualTo(left: Literal, right) if e.deterministic => (e, Seq((right, left)))
case e @ EqualNullSafe(left, right: Literal) if e.deterministic => (e, Seq((left, right)))
case e @ EqualNullSafe(left: Literal, right) if e.deterministic => (e, Seq((right, left)))
case a @ And(left, right) =>
val (newLeft, equalityPredicatesLeft) = traverse(left)
val replacedRight = replaceConstants(right, equalityPredicatesLeft)
val (replacedNewRight, equalityPredicatesRight) = traverse(replacedRight)
val replacedNewLeft = replaceConstants(newLeft, equalityPredicatesRight)
val newAnd = if ((replacedNewLeft fastEquals left) && (replacedNewRight fastEquals right)) {
a
} else {
if (newLeft.isDefined || newRight.isDefined) {
Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
} else {
None
}
And(replacedNewLeft, replacedNewRight)
}
(newSelf, equalityPredicates)
case o: Or =>
(newAnd, equalityPredicatesLeft ++ equalityPredicatesRight)
case o @ Or(left, right) =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newLeft, _) = traverse(o.left, replaceChildren = true)
val (newRight, _) = traverse(o.right, replaceChildren = true)
val newSelf = if (newLeft.isDefined || newRight.isDefined) {
Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right))))
val (newLeft, _) = traverse(left)
val (newRight, _) = traverse(right)
val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) {
o
} else {
None
Or(newLeft, newRight)
}
(newSelf, Seq.empty)
case n: Not =>

(newOr, Seq.empty)
case n @ Not(child) =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newChild, _) = traverse(n.child, replaceChildren = true)
(newChild.map(Not), Seq.empty)
case _ => (None, Seq.empty)
val (newChild, _) = traverse(child)
val newNot = if (newChild fastEquals child) {
n
} else {
Not(newChild)
}
(newNot, Seq.empty)
case _ => (expression, Seq.empty)
}

private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
: Expression = {
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
val predicates = equalityPredicates.map(_._2).toSet
def replaceConstants0(expression: Expression) = expression transform {
case a: AttributeReference => constantsMap.getOrElse(a, a)
}
condition transform {
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
}
private def replaceConstants(expression: Expression, constants: Seq[(Expression, Literal)]) =
constants.foldLeft(expression)((e, constant) => e transformUp {
case e if e.canonicalized == constant._1.canonicalized => constant._2
})
}

/**
 * Rewrites the condition of every [[Filter]] by reducing it with the constraints that its own
 * conjuncts imply.
 * eg.
 * {{{
 *   SELECT * FROM table WHERE i <= 5 AND i = 5            => ... WHERE i = 5
 *   SELECT * FROM table WHERE i < j AND ... AND i > j     => ... WHERE false
 * }}}
 */
object FilterReduction extends Rule[LogicalPlan] with ConstraintHelper {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case filter: Filter =>
      val reduced = normalizeAndReduceWithConstraints(filter.condition)
      // Keep the original node when nothing changed so that the plan compares equal to itself.
      if (reduced fastEquals filter.condition) filter else filter.copy(condition = reduced)
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,193 @@ trait ConstraintHelper {
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
}

/**
 * Normalizes `expression` (see `normalize`) and then simplifies it with the constraints that its
 * own conjuncts provide.
 *
 * @param expression the predicate to simplify
 * @return the simplified predicate; the collected constraints are discarded
 */
def normalizeAndReduceWithConstraints(expression: Expression): Expression = {
  val (reduced, _) = reduceWithConstraints(normalize(expression))
  reduced
}

/**
 * Rewrites every "greater than" comparison into the mirrored "less than" form, so that later
 * matching only has to deal with [[LessThan]] and [[LessThanOrEqual]].
 */
private def normalize(expression: Expression) = expression transform {
  case GreaterThan(lhs, rhs) => LessThan(rhs, lhs)
  case GreaterThanOrEqual(lhs, rhs) => LessThanOrEqual(rhs, lhs)
}

/**
 * Traverses a condition as a tree and simplifies expressions with constraints.
 * - This function assumes that the expression has been normalized using `normalize`, i.e. it
 *   contains no [[GreaterThan]] / [[GreaterThanOrEqual]].
 * - A deterministic [[LessThan]], [[LessThanOrEqual]], [[EqualTo]] or [[EqualNullSafe]] is
 *   returned unchanged and propagated as a constraint.
 * - On matching [[And]], both children are traversed recursively, each child is simplified with
 *   the constraints propagated from its sibling, and the union of the constraints is propagated
 *   up.
 * - On matching [[Or]], both children are traversed recursively but no constraints are propagated.
 * - On matching [[Not]], the subtree is left untouched. Simplifying a conjunct with the
 *   constraints of its sibling is only sound where NULL and false are interchangeable, which does
 *   not hold below a [[Not]]: `NOT(a < b AND b < a)` is NULL when `a` is NULL, whereas the reduced
 *   `NOT(a < b AND false)` would be true.
 * - Otherwise, traversal stops and no constraints are propagated.
 *
 * @param expression expression to be traversed
 * @return A tuple including:
 *         1. Expression: the (possibly unchanged) expression after traversal
 *         2. Seq[Expression]: propagated constraints
 */
private def reduceWithConstraints(expression: Expression): (Expression, Seq[Expression]) =
  expression match {
    case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe)
        if e.deterministic => (e, Seq(e))
    case a @ And(left, right) =>
      val (newLeft, leftConstraints) = reduceWithConstraints(left)
      // Simplify the right side with what the left side guarantees before collecting the right
      // side's own constraints, then use those to simplify the left side in turn.
      val simplifiedRight = reduceWithConstraints(right, leftConstraints)
      val (simplifiedNewRight, rightConstraints) = reduceWithConstraints(simplifiedRight)
      val simplifiedNewLeft = reduceWithConstraints(newLeft, rightConstraints)
      val newAnd = if ((simplifiedNewLeft fastEquals left) &&
          (simplifiedNewRight fastEquals right)) {
        a
      } else {
        And(simplifiedNewLeft, simplifiedNewRight)
      }
      (newAnd, leftConstraints ++ rightConstraints)
    case o @ Or(left, right) =>
      // Ignore the constraints from children since they are only propagated through And.
      val (newLeft, _) = reduceWithConstraints(left)
      val (newRight, _) = reduceWithConstraints(right)
      val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) {
        o
      } else {
        Or(newLeft, newRight)
      }

      (newOr, Seq.empty)
    case _: Not =>
      // Do not descend: below a Not, NULL and false are no longer interchangeable, so reducing a
      // nested And with its sibling's constraints could change which rows pass the filter.
      (expression, Seq.empty)
    case _ => (expression, Seq.empty)
  }

/**
 * Simplifies `expression` with each of the given constraints in turn, feeding the result of one
 * reduction into the next.
 */
private def reduceWithConstraints(expression: Expression, constraints: Seq[Expression]) =
  constraints.foldLeft(expression) { (reducedSoFar, nextConstraint) =>
    reduceWithConstraint(reducedSoFar, nextConstraint)
  }

/** True when neither expression is foldable and both have the same canonicalized form. */
private def planEqual(x: Expression, y: Expression) = {
  val anyFoldable = x.foldable || y.foldable
  !anyFoldable && x.canonicalized == y.canonicalized
}

/**
 * True when both expressions are foldable and `x = y` evaluates to true on an empty row.
 * A NULL comparison result unboxes to false.
 */
private def valueEqual(x: Expression, y: Expression) =
  if (x.foldable && y.foldable) {
    val comparison = EqualTo(x, y)
    comparison.eval(EmptyRow).asInstanceOf[Boolean]
  } else {
    false
  }

/**
 * True when both expressions are foldable and `x < y` evaluates to true on an empty row.
 * A NULL comparison result unboxes to false.
 */
private def valueLessThan(x: Expression, y: Expression) =
  if (x.foldable && y.foldable) {
    val comparison = LessThan(x, y)
    comparison.eval(EmptyRow).asInstanceOf[Boolean]
  } else {
    false
  }

/**
 * True when both expressions are foldable and `x <= y` evaluates to true on an empty row.
 * A NULL comparison result unboxes to false.
 */
private def valueLessThanOrEqual(x: Expression, y: Expression) =
  if (x.foldable && y.foldable) {
    val comparison = LessThanOrEqual(x, y)
    comparison.eval(EmptyRow).asInstanceOf[Boolean]
  } else {
    false
  }

/**
 * Simplifies `expression` with a single `constraint` that is a sibling conjunct of it, i.e. the
 * value of `expression` only matters for rows where `constraint` is true.
 *
 * Two kinds of rewrites are used:
 *  - Exact rewrites produce the very same value (NULL included) as the original predicate
 *    whenever `constraint` is true. They are applied bottom-up to every sub-expression.
 *  - Null-as-false rewrites may turn a false into a NULL or a NULL into a false. They are applied
 *    only to predicates reachable from the root of `expression` through [[And]] / [[Or]] nodes,
 *    because anywhere else (e.g. below [[Not]]) the swap is observable: with `c` NULL,
 *    `a < b AND NOT(c <=> b)` is true, but `a < b AND NOT(c = b)` is NULL.
 *
 * Constraints other than [[LessThan]], [[LessThanOrEqual]], [[EqualTo]] and [[EqualNullSafe]]
 * leave `expression` unchanged.
 *
 * @param expression the normalized expression to simplify
 * @param constraint a predicate assumed to hold wherever the result of `expression` is used
 * @return the simplified expression
 */
private def reduceWithConstraint(expression: Expression, constraint: Expression): Expression = {
  // `c op d` compares exactly the two plans of the constraint `a op b`, in either order.
  def bothOperandsMatch(a: Expression, b: Expression, c: Expression, d: Expression): Boolean =
    (planEqual(a, c) && planEqual(b, d)) || (planEqual(a, d) && planEqual(b, c))

  // `c op d` shares at least one plan with the constraint `a op b`.
  def anyOperandMatches(a: Expression, b: Expression, c: Expression, d: Expression): Boolean =
    planEqual(b, d) || planEqual(b, c) || planEqual(a, c) || planEqual(a, d)

  // When `a` and `b` are known to be non-null, `c <=> d` sharing an operand with them can only
  // differ from `c = d` by being false instead of NULL.
  def nullSafeToEqual(a: Expression, b: Expression): PartialFunction[Expression, Expression] = {
    case c EqualNullSafe d if anyOperandMatches(a, b, c, d) => EqualTo(c, d)
  }

  val exactRules: PartialFunction[Expression, Expression] = constraint match {
    // a < b holds, hence a and b are non-null and a != b.
    case a LessThan b => {
      case c LessThan d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) =>
        Literal.TrueLiteral
      case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) =>
        Literal.FalseLiteral
      case c LessThan d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) =>
        Literal.TrueLiteral
      case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) =>
        Literal.FalseLiteral

      case c LessThanOrEqual d
          if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) =>
        Literal.TrueLiteral
      case c LessThanOrEqual d
          if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) =>
        Literal.FalseLiteral
      case c LessThanOrEqual d
          if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) =>
        Literal.TrueLiteral
      case c LessThanOrEqual d
          if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) =>
        Literal.FalseLiteral

      case c EqualTo d if bothOperandsMatch(a, b, c, d) => Literal.FalseLiteral
      case c EqualNullSafe d if bothOperandsMatch(a, b, c, d) => Literal.FalseLiteral
    }

    // a <= b holds, hence a and b are non-null.
    case a LessThanOrEqual b => {
      case c LessThan d if planEqual(b, d) && valueLessThan(c, a) =>
        Literal.TrueLiteral
      case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) =>
        Literal.FalseLiteral
      case c LessThan d if planEqual(a, c) && valueLessThan(b, d) =>
        Literal.TrueLiteral
      case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) =>
        Literal.FalseLiteral

      case c LessThanOrEqual d
          if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) =>
        Literal.TrueLiteral
      case c LessThanOrEqual d if planEqual(b, c) && valueLessThan(d, a) =>
        Literal.FalseLiteral
      // b <= a together with a <= b leaves equality as the only possibility.
      case c LessThanOrEqual d if planEqual(b, c) && (planEqual(a, d) || valueEqual(a, d)) =>
        EqualTo(c, d)
      case c LessThanOrEqual d
          if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) =>
        Literal.TrueLiteral
      case c LessThanOrEqual d if planEqual(a, d) && valueLessThan(b, c) =>
        Literal.FalseLiteral
      case c LessThanOrEqual d if planEqual(a, d) && (planEqual(b, c) || valueEqual(b, c)) =>
        EqualTo(c, d)

      // Both operands are known to be non-null, so <=> and = agree exactly.
      case c EqualNullSafe d if bothOperandsMatch(a, b, c, d) => EqualTo(c, d)
    }

    // a = b holds, hence a and b are non-null and equal.
    case a EqualTo b => {
      case c LessThan d if bothOperandsMatch(a, b, c, d) => Literal.FalseLiteral
      case c LessThanOrEqual d if bothOperandsMatch(a, b, c, d) => Literal.TrueLiteral
      case c EqualTo d if bothOperandsMatch(a, b, c, d) => Literal.TrueLiteral
      case c EqualNullSafe d if bothOperandsMatch(a, b, c, d) => Literal.TrueLiteral
    }

    // a <=> b holds: either both are NULL or both are non-null and equal.
    case a EqualNullSafe b => {
      // a <= b is NULL when both are NULL and true otherwise, which is exactly a = b.
      case c LessThanOrEqual d if bothOperandsMatch(a, b, c, d) => EqualTo(c, d)
      case c EqualNullSafe d if bothOperandsMatch(a, b, c, d) => Literal.TrueLiteral
    }

    case _ => PartialFunction.empty[Expression, Expression]
  }

  val nullAsFalseRules: PartialFunction[Expression, Expression] = constraint match {
    case a LessThan b => nullSafeToEqual(a, b)
    case a LessThanOrEqual b => nullSafeToEqual(a, b)
    case a EqualTo b => nullSafeToEqual(a, b)
    case a EqualNullSafe b => {
      // a < b is NULL (not false) when both sides are NULL, so this is not an exact rewrite.
      case c LessThan d if bothOperandsMatch(a, b, c, d) => Literal.FalseLiteral
    }
    case _ => PartialFunction.empty[Expression, Expression]
  }

  // And / Or keep the "NULL behaves like false" property of their parent, any other node does
  // not, so the descent stops there.
  def reduceNullAsFalse(e: Expression): Expression = e match {
    case _: And | _: Or => e.mapChildren(reduceNullAsFalse)
    case other => nullAsFalseRules.applyOrElse(other, identity[Expression])
  }

  reduceNullAsFalse(expression transformUp exactRules)
}
}
Loading