Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
// Constant folding and strength reduction
TransposeWindow,
NullPropagation,
ConstantPropagation,
ConstraintPropagation,
FoldablePropagation,
OptimizeIn,
ConstantFolding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,100 +55,24 @@ object ConstantFolding extends Rule[LogicalPlan] {
}

/**
* Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding
* value in conjunctive [[Expression Expressions]]
* Substitutes expressions which can be statically narrowed by constrains.
* eg.
* {{{
* SELECT * FROM table WHERE i = 5 AND j = i + 3
* ==> SELECT * FROM table WHERE i = 5 AND j = 8
* SELECT * FROM table WHERE i = 5 AND j = i + 3 => SELECT * FROM table WHERE i = 5 AND j = 8
* SELECT * FROM table WHERE i <= 5 AND i = 5 => SELECT * FROM table WHERE i = 5
Copy link
Member

@dongjoon-hyun dongjoon-hyun Apr 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, @peter-toth . I don't think this aims the equal goal. Could you make a separate rule instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filter reduction should be a separate rule.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you all for the feedback. The 2 rules seemed similar and easy to combine to me, but I will not mix them then.

Copy link
Contributor Author

@peter-toth peter-toth May 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I separated the 2 rules. I added some improvement to the ConstantPropagation and created a new FilterReduction.

Copy link
Contributor Author

@peter-toth peter-toth May 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved all new filter reduction related logic to FilterReduction and extracted constant propagation enhancements to a separate PR.

* SELECT * FROM table WHERE i < j AND ... AND i > j => SELECT * FROM table WHERE false
* }}}
*
* Approach used:
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
* in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
object ConstraintPropagation extends Rule[LogicalPlan] with ConstraintHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f: Filter =>
val (newCondition, _) = traverse(f.condition, replaceChildren = true)
if (newCondition.isDefined) {
f.copy(condition = newCondition.get)
} else {
val newCondition = simplifyWithConstraints(f.condition)
if (newCondition fastEquals f.condition) {
f
} else {
f.copy(condition = newCondition)
}
}

type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]

/**
* Traverse a condition as a tree and replace attributes with constant values.
* - On matching [[And]], recursively traverse each children and get propagated mappings.
* If the current node is not child of another [[And]], replace all occurrences of the
* attributes with the corresponding constant values.
* - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
* of attribute => constant.
* - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping.
* - Otherwise, stop traversal and propagate empty mapping.
* @param condition condition to be traversed
* @param replaceChildren whether to replace attributes with constant values in children
* @return A tuple including:
* 1. Option[Expression]: optional changed condition after traversal
* 2. EqualityPredicates: propagated mapping of attribute => constant
*/
private def traverse(condition: Expression, replaceChildren: Boolean)
: (Option[Expression], EqualityPredicates) =
condition match {
case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e)))
case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e)))
case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
(None, Seq(((left, right), e)))
case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
(None, Seq(((right, left), e)))
case a: And =>
val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false)
val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false)
val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
} else {
if (newLeft.isDefined || newRight.isDefined) {
Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
} else {
None
}
}
(newSelf, equalityPredicates)
case o: Or =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newLeft, _) = traverse(o.left, replaceChildren = true)
val (newRight, _) = traverse(o.right, replaceChildren = true)
val newSelf = if (newLeft.isDefined || newRight.isDefined) {
Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right))))
} else {
None
}
(newSelf, Seq.empty)
case n: Not =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newChild, _) = traverse(n.child, replaceChildren = true)
(newChild.map(Not), Seq.empty)
case _ => (None, Seq.empty)
}

private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
: Expression = {
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
val predicates = equalityPredicates.map(_._2).toSet
def replaceConstants0(expression: Expression) = expression transform {
case a: AttributeReference => constantsMap.getOrElse(a, a)
}
condition transform {
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
}
}
}

/**
Expand Down Expand Up @@ -389,6 +313,7 @@ object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper {
case q: LogicalPlan => q transformExpressionsUp {
// True with equality
case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
case a EqualNullSafe b if a.foldable || b.foldable => EqualTo(a, b)
case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) =>
TrueLiteral
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ abstract class UnaryNode extends LogicalPlan {
var allConstraints = child.constraints.asInstanceOf[Set[Expression]]
projectList.foreach {
case a @ Alias(l: Literal, _) =>
allConstraints += EqualNullSafe(a.toAttribute, l)
allConstraints +=
(if (l.nullable) EqualNullSafe(a.toAttribute, l) else EqualTo(a.toAttribute, l))
case a @ Alias(e, _) =>
// For every alias in `projectList`, replace the reference in constraints by its attribute.
allConstraints ++= allConstraints.map(_ transform {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,253 @@ trait ConstraintHelper {
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
}

def simplifyWithConstraints(expression: Expression): Expression =
simplify(normalize(expression))._1

private def normalize(expression: Expression) = expression transform {
case GreaterThan(x, y) => LessThan(y, x)
case GreaterThanOrEqual(x, y) => LessThanOrEqual(y, x)
}

/**
* Traverse a condition as a tree and simplify expressions with constraints.
* - On matching [[And]], recursively traverse both children, simplify child expressions with
* propagated constraints from sibling and propagate up union of constraints.
* - If a child of [[And]] is [[LessThan]], [[LessThanOrEqual]], [[EqualTo]], [[EqualNullSafe]],
* [[GreaterThan]] or [[GreaterThanOrEqual]] propagate the constraint.
* - On matching [[Or]] or [[Not]], recursively traverse each children, propagate no constraints.
* - Otherwise, stop traversal and propagate no constraints.
* @param expression expression to be traversed
* @return A tuple including:
* 1. Expression: optionally changed condition after traversal
* 2. Seq[Expression]: propagated constraints
*/
private def simplify(expression: Expression): (Expression, Seq[Expression]) = expression match {
case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe | _: GreaterThan |
_: GreaterThanOrEqual )
if e.deterministic => (e, Seq(e))
case a @ And(left, right) =>
val (newLeft, leftConstraints) = simplify(left)
val simplifiedRight = simplify(right, leftConstraints)
val (simplifiedNewRight, rightConstraints) = simplify(simplifiedRight)
val simplifiedNewLeft = simplify(newLeft, rightConstraints)
val newAnd = if ((simplifiedNewLeft fastEquals left) &&
(simplifiedNewRight fastEquals right)) {
a
} else {
And(simplifiedNewLeft, simplifiedNewRight)
}
(newAnd, leftConstraints ++ rightConstraints)
case o @ Or(left, right) =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newLeft, _) = simplify(left)
val (newRight, _) = simplify(right)
val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) {
o
} else {
Or(newLeft, newRight)
}

(newOr, Seq.empty)
case n @ Not(child) =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newChild, _) = simplify(child)
val newNot = if (newChild fastEquals child) {
n
} else {
Not(newChild)
}
(newNot, Seq.empty)
case _ => (expression, Seq.empty)
}

private def simplify(expression: Expression, constraints: Seq[Expression]): Expression =
constraints.foldLeft(expression)((e, constraint) => simplify(e, constraint))

private def planEqual(x: Expression, y: Expression) =
!x.foldable && !y.foldable && x.canonicalized == y.canonicalized

private def valueEqual(x: Expression, y: Expression) =
x.foldable && y.foldable && EqualTo(x, y).eval(EmptyRow).asInstanceOf[Boolean]

private def valueLessThan(x: Expression, y: Expression) =
x.foldable && y.foldable && LessThan(x, y).eval(EmptyRow).asInstanceOf[Boolean]

private def valueLessThanOrEqual(x: Expression, y: Expression) =
x.foldable && y.foldable && LessThanOrEqual(x, y).eval(EmptyRow).asInstanceOf[Boolean]

private def simplify(expression: Expression, constraint: Expression): Expression =
constraint match {
case a LessThan b => expression transformUp {
case c LessThan d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) =>
Literal.TrueLiteral
case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) =>
Literal.FalseLiteral
case c LessThan d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) =>
Literal.TrueLiteral
case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) =>
Literal.FalseLiteral

case c LessThanOrEqual d
if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) =>
Literal.TrueLiteral
case c LessThanOrEqual d
if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) =>
Literal.FalseLiteral
case c LessThanOrEqual d
if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) =>
Literal.TrueLiteral
case c LessThanOrEqual d
if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) =>
Literal.FalseLiteral

case c EqualTo d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) =>
Literal.FalseLiteral
case c EqualTo d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) =>
Literal.FalseLiteral
case c EqualTo d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) =>
Literal.FalseLiteral
case c EqualTo d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) =>
Literal.FalseLiteral

case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) =>
Literal.FalseLiteral
case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) =>
Literal.FalseLiteral
case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) =>
Literal.FalseLiteral
case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) =>
Literal.FalseLiteral

case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d)
case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d)
case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d)
case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d)
}
case a LessThanOrEqual b => expression transformUp {
case c LessThan d if planEqual(b, d) && valueLessThan(c, a) =>
Literal.TrueLiteral
case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) =>
Literal.FalseLiteral
case c LessThan d if planEqual(a, c) && valueLessThan(b, d) =>
Literal.TrueLiteral
case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) =>
Literal.FalseLiteral

case c LessThanOrEqual d
if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) =>
Literal.TrueLiteral
case c LessThanOrEqual d if planEqual(b, c) && valueLessThan(d, a) =>
Literal.FalseLiteral
case c LessThanOrEqual d if planEqual(b, c) && (planEqual(a, d) || valueEqual(a, d)) =>
EqualTo(c, d)
case c LessThanOrEqual d
if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) =>
Literal.TrueLiteral
case c LessThanOrEqual d if planEqual(a, d) && valueLessThan(b, c) =>
Literal.FalseLiteral
case c LessThanOrEqual d if planEqual(a, d) && (planEqual(b, c) || valueEqual(b, c)) =>
EqualTo(c, d)

case c EqualTo d if planEqual(b, d) && valueLessThan(c, a) =>
Literal.FalseLiteral
case c EqualTo d if planEqual(b, c) && valueLessThan(d, a) =>
Literal.FalseLiteral
case c EqualTo d if planEqual(a, c) && valueLessThan(b, d) =>
Literal.FalseLiteral
case c EqualTo d if planEqual(a, d) && valueLessThan(b, c) =>
Literal.FalseLiteral

case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d)
case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d)
case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d)
case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d)
}
case a EqualTo b =>
if (b.foldable) {
expression transformUp { case c if planEqual(a, c) => b }
} else if (a.foldable) {
expression transformUp { case c if planEqual(b, c) => a }
} else {
expression transformUp {
case c LessThan d if planEqual(b, d) && planEqual(a, c) =>
Literal.FalseLiteral
case c LessThan d if planEqual(b, c) && planEqual(a, d) =>
Literal.FalseLiteral
case c LessThan d if planEqual(a, d) && planEqual(b, c) =>
Literal.FalseLiteral
case c LessThan d if planEqual(a, c) && planEqual(b, d) =>
Literal.FalseLiteral

case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) =>
Literal.TrueLiteral
case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) =>
Literal.TrueLiteral
case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) =>
Literal.TrueLiteral
case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) =>
Literal.TrueLiteral

case c EqualTo d if planEqual(b, d) && planEqual(a, c) =>
Literal.TrueLiteral
case c EqualTo d if planEqual(b, c) && planEqual(a, d) =>
Literal.TrueLiteral
case c EqualTo d if planEqual(a, d) && planEqual(b, c) =>
Literal.TrueLiteral
case c EqualTo d if planEqual(a, c) && planEqual(b, d) =>
Literal.TrueLiteral

case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) =>
Literal.TrueLiteral
case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) =>
Literal.TrueLiteral
case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) =>
Literal.TrueLiteral
case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) =>
Literal.TrueLiteral

case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d)
case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d)
case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d)
case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d)
}
}
case a EqualNullSafe b =>
if (b.foldable) {
expression transformUp { case c if planEqual(a, c) => b }
} else if (a.foldable) {
expression transformUp { case c if planEqual(b, c) => a }
} else {
expression transformUp {
case c LessThan d if planEqual(b, d) && planEqual(a, c) =>
Literal.FalseLiteral
case c LessThan d if planEqual(b, c) && planEqual(d, a) =>
Literal.FalseLiteral
case c LessThan d if planEqual(a, d) && planEqual(b, c) =>
Literal.FalseLiteral
case c LessThan d if planEqual(a, c) && planEqual(d, b) =>
Literal.FalseLiteral

case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) =>
EqualTo(c, d)
case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) =>
EqualTo(c, d)
case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) =>
EqualTo(c, d)
case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) =>
EqualTo(c, d)

case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) =>
Literal.TrueLiteral
case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) =>
Literal.TrueLiteral
case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) =>
Literal.TrueLiteral
case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) =>
Literal.TrueLiteral
}
}
case _ => expression
}
}
Loading