Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,46 @@ class EquivalentExpressions {
}
}

/**
* Adds only expressions which are common in each of given expressions, in a recursive way.
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`,
* the common expression `(c + 1)` will be added into `equivalenceMap`.
*/
def addCommonExprs(exprs: Seq[Expression], addFunc: Expression => Boolean = addExpr): Unit = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be private as well.

var exprSetForAll = ExpressionSet()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a question; Any reason to use ExpressionSet instead of mutable.Set[Expr]? I thought we can write it like this;

  def addCommonExprs(exprs: Seq[Expression], addFunc: Expression => Boolean = addExpr): Unit = {
    val exprSetForAll = mutable.Set[Expr]()

    addExprTree(exprs.head, (expr: Expression) => {
      if (expr.deterministic) {
        val e = Expr(expr)
        if (exprSetForAll.contains(e)) {
          true
        } else {
          exprSetForAll += e
          false
        }
      } else {
        false
      }
    })

    val commonExprSet = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
      val otherExprSet = mutable.Set[Expr]()
      addExprTree(expr, (expr: Expression) => {
        if (expr.deterministic) {
          val e = Expr(expr)
          if (otherExprSet.contains(e)) {
            true
          } else {
            otherExprSet += e
            false
          }
        } else {
          false
        }
      })
      exprSet.intersect(otherExprSet)
    }

    commonExprSet.foreach(expr => addFunc(expr.e))
  }

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They actually the same effect. Expr compares expressions by semanticeEquals which compares canonicalized expressions. ExpressionSet compares canonicalized internally too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about the overhead? It seems the current code clones a whole expr set per update though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good point. ExpressionSet has an overhead of maintaining originals.


addExprTree(exprs.head, (expr: Expression) => {
if (exprSetForAll.contains(expr)) {
true
} else {
exprSetForAll += expr
false
}
})

exprs.tail.foreach { expr =>
var exprSet = ExpressionSet()
addExprTree(expr, (expr: Expression) => {
if (exprSet.contains(expr)) {
true
} else {
exprSet += expr
false
}
})
exprSetForAll = exprSetForAll.intersect(exprSet)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the order of exprs affect the result?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we take the intersection of all expressions. So the order doesn't matter.

}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to handle head and tail seperately?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For expression head, we add underlying expressions into exprSetForAll set. But for expressions in tail, we keep intersect between exprSetForAll and exprSet.

We can merge two blocks, but in the block we need to check if current expression is head expression and do different logic based on the check.

I prefer current one since it looks simpler.


exprSetForAll.foreach(addFunc)
}

/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
*/
def addExprTree(expr: Expression): Unit = {
def addExprTree(
expr: Expression,
addFunc: Expression => Boolean = addExpr): Unit = {
val skip = expr.isInstanceOf[LeafExpression] ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
Expand All @@ -96,8 +131,21 @@ class EquivalentExpressions {
case other => other.children
}

if (!skip && !addExpr(expr)) {
childrenToRecurse.foreach(addExprTree)
// For some special expressions we cannot just recurse into all of its children, but we can
// recursively add the common expressions shared between all of its children.
def commonChildrenToRecurse: Seq[Seq[Expression]] = expr match {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit. Although this is used only here, can we declare this outside of this function as a private method? Currently, addExprTree seems to grow unnecessarily.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.

case i: If => Seq(Seq(i.trueValue, i.falseValue))
case c: CaseWhen =>
val conditions = c.branches.tail.map(_._1)
val values = c.branches.map(_._2) ++ c.elseValue
Seq(conditions, values)
case c: Coalesce => Seq(c.children.tail)
case _ => Nil
}

if (!skip && !addFunc(expr)) {
childrenToRecurse.foreach(addExprTree(_, addFunc))
commonChildrenToRecurse.filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ class CodegenContext extends Logging {
val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]

// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree)
expressions.foreach(equivalentExpressions.addExprTree(_))

// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,20 +146,109 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
equivalence.addExprTree(add)
// the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}

test("Children of conditional expressions") {
val condition = And(Literal(true), Literal(false))
test("Children of conditional expressions: If") {
val add = Add(Literal(1), Literal(2))
val ifExpr = If(condition, add, add)
val condition = GreaterThan(add, Literal(3))

val equivalence = new EquivalentExpressions
equivalence.addExprTree(ifExpr)
// the `add` inside `If` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
// only ifExpr and its predicate expression
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2)
val ifExpr1 = If(condition, add, add)
val equivalence1 = new EquivalentExpressions
equivalence1.addExprTree(ifExpr1)

// `add` is in both two branches of `If` and predicate.
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add))
// one-time expressions: only ifExpr and its predicate expression
assert(equivalence1.getAllEquivalentExprs.count(_.size == 1) == 2)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use contains method? HashMap can not guarantee the order

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will create a follow-up for making sure it will not possibly flaky. Thanks.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created #30371.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, @leoluan2009 and @viirya . The follow-up is merged to reduce the flakiness.

assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).last == Seq(condition))

// Repeated `add` is only in one branch, so we don't count it.
val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add))
val equivalence2 = new EquivalentExpressions
equivalence2.addExprTree(ifExpr2)

assert(equivalence2.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence2.getAllEquivalentExprs.count(_.size == 1) == 3)

val ifExpr3 = If(condition, ifExpr1, ifExpr1)
val equivalence3 = new EquivalentExpressions
equivalence3.addExprTree(ifExpr3)

// `add`: 2, `condition`: 2
assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 2)
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add))
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).last == Seq(condition, condition))

// `ifExpr1`, `ifExpr3`
assert(equivalence3.getAllEquivalentExprs.count(_.size == 1) == 2)
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1))
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).last == Seq(ifExpr3))
}

test("Children of conditional expressions: CaseWhen") {
val add1 = Add(Literal(1), Literal(2))
val add2 = Add(Literal(2), Literal(3))
val conditions1 = (GreaterThan(add2, Literal(3)), add1) ::
(GreaterThan(add2, Literal(4)), add1) ::
(GreaterThan(add2, Literal(5)), add1) :: Nil

val caseWhenExpr1 = CaseWhen(conditions1, None)
val equivalence1 = new EquivalentExpressions
equivalence1.addExprTree(caseWhenExpr1)

// `add2` is repeatedly in all conditions.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add1 is also repeated. Why it's not included?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We treat the first condition specially because it is definitely run. So it counts one for add2. Other conditions all contain add2 so it counts for one. That is where the count 2 comes from for add2.

For add1, although all values contain it, it is definitely run, so we count it one. If no other expression contains add1, we don't extract subexpression for add1 as it will run just once (we only run one value of CaseWhen).

assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2))

val conditions2 = (GreaterThan(add1, Literal(3)), add1) ::
(GreaterThan(add2, Literal(4)), add1) ::
(GreaterThan(add2, Literal(5)), add1) :: Nil

val caseWhenExpr2 = CaseWhen(conditions2, None)
val equivalence2 = new EquivalentExpressions
equivalence2.addExprTree(caseWhenExpr2)

// `add2` is repeatedly in all branch values, and first predicate.
assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence2.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add1, add1))

val conditions3 = (GreaterThan(add1, Literal(3)), add2) ::
(GreaterThan(add2, Literal(4)), add1) ::
(GreaterThan(add2, Literal(5)), add1) :: Nil

val caseWhenExpr3 = CaseWhen(conditions3, None)
val equivalence3 = new EquivalentExpressions
equivalence3.addExprTree(caseWhenExpr3)
assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 0)
}

test("Children of conditional expressions: Coalesce") {
val add1 = Add(Literal(1), Literal(2))
val add2 = Add(Literal(2), Literal(3))
val conditions1 = GreaterThan(add2, Literal(3)) ::
GreaterThan(add2, Literal(4)) ::
GreaterThan(add2, Literal(5)) :: Nil

val coalesceExpr1 = Coalesce(conditions1)
val equivalence1 = new EquivalentExpressions
equivalence1.addExprTree(coalesceExpr1)

// `add2` is repeatedly in all conditions.
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2))

val conditions2 = GreaterThan(add1, Literal(3)) ::
GreaterThan(add2, Literal(4)) ::
GreaterThan(add2, Literal(5)) :: Nil

val coalesceExpr2 = Coalesce(conditions2)
val equivalence2 = new EquivalentExpressions
equivalence2.addExprTree(coalesceExpr2)

assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 0)
}
}

Expand Down