-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-33337][SQL] Support subexpression elimination in branches of conditional expressions #30245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
db0cfcc
cd3776c
9182e3d
cc0648a
16314a9
33f3bd3
b415728
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,11 +65,46 @@ class EquivalentExpressions { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Adds only expressions which are common in each of given expressions, in a recursive way. | ||
| * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, | ||
| * the common expression `(c + 1)` will be added into `equivalenceMap`. | ||
| */ | ||
| def addCommonExprs(exprs: Seq[Expression], addFunc: Expression => Boolean = addExpr): Unit = { | ||
|
||
| var exprSetForAll = ExpressionSet() | ||
|
||
|
|
||
| addExprTree(exprs.head, (expr: Expression) => { | ||
| if (exprSetForAll.contains(expr)) { | ||
| true | ||
| } else { | ||
| exprSetForAll += expr | ||
| false | ||
| } | ||
| }) | ||
|
|
||
| exprs.tail.foreach { expr => | ||
| var exprSet = ExpressionSet() | ||
| addExprTree(expr, (expr: Expression) => { | ||
| if (exprSet.contains(expr)) { | ||
| true | ||
| } else { | ||
| exprSet += expr | ||
| false | ||
| } | ||
| }) | ||
| exprSetForAll = exprSetForAll.intersect(exprSet) | ||
|
||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to handle
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For expression We can merge two blocks, but in the block we need to check if current expression is head expression and do different logic based on the check. I prefer current one since it looks simpler. |
||
|
|
||
| exprSetForAll.foreach(addFunc) | ||
| } | ||
|
|
||
| /** | ||
| * Adds the expression to this data structure recursively. Stops if a matching expression | ||
| * is found. That is, if `expr` has already been added, its children are not added. | ||
| */ | ||
| def addExprTree(expr: Expression): Unit = { | ||
| def addExprTree( | ||
| expr: Expression, | ||
| addFunc: Expression => Boolean = addExpr): Unit = { | ||
| val skip = expr.isInstanceOf[LeafExpression] || | ||
| // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the | ||
| // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. | ||
|
|
@@ -96,8 +131,21 @@ class EquivalentExpressions { | |
| case other => other.children | ||
| } | ||
|
|
||
| if (!skip && !addExpr(expr)) { | ||
| childrenToRecurse.foreach(addExprTree) | ||
| // For some special expressions we cannot just recurse into all of its children, but we can | ||
| // recursively add the common expressions shared between all of its children. | ||
| def commonChildrenToRecurse: Seq[Seq[Expression]] = expr match { | ||
|
||
| case i: If => Seq(Seq(i.trueValue, i.falseValue)) | ||
| case c: CaseWhen => | ||
| val conditions = c.branches.tail.map(_._1) | ||
| val values = c.branches.map(_._2) ++ c.elseValue | ||
| Seq(conditions, values) | ||
| case c: Coalesce => Seq(c.children.tail) | ||
| case _ => Nil | ||
| } | ||
|
|
||
| if (!skip && !addFunc(expr)) { | ||
| childrenToRecurse.foreach(addExprTree(_, addFunc)) | ||
| commonChildrenToRecurse.filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc)) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -146,20 +146,109 @@ class SubexpressionEliminationSuite extends SparkFunSuite { | |
| equivalence.addExprTree(add) | ||
| // the `two` inside `fallback` should not be added | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode | ||
| } | ||
|
|
||
| test("Children of conditional expressions") { | ||
| val condition = And(Literal(true), Literal(false)) | ||
| test("Children of conditional expressions: If") { | ||
viirya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| val add = Add(Literal(1), Literal(2)) | ||
| val ifExpr = If(condition, add, add) | ||
| val condition = GreaterThan(add, Literal(3)) | ||
|
|
||
| val equivalence = new EquivalentExpressions | ||
| equivalence.addExprTree(ifExpr) | ||
| // the `add` inside `If` should not be added | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) | ||
| // only ifExpr and its predicate expression | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2) | ||
| val ifExpr1 = If(condition, add, add) | ||
| val equivalence1 = new EquivalentExpressions | ||
| equivalence1.addExprTree(ifExpr1) | ||
|
|
||
| // `add` is in both two branches of `If` and predicate. | ||
| assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1) | ||
| assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add)) | ||
| // one-time expressions: only ifExpr and its predicate expression | ||
| assert(equivalence1.getAllEquivalentExprs.count(_.size == 1) == 2) | ||
| assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we use contains method? HashMap can not guarantee the order
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I will create a follow-up for making sure it will not possibly flaky. Thanks.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Created #30371.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you, @leoluan2009 and @viirya . The follow-up is merged to reduce the flakiness. |
||
| assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).last == Seq(condition)) | ||
|
|
||
| // Repeated `add` is only in one branch, so we don't count it. | ||
| val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add)) | ||
| val equivalence2 = new EquivalentExpressions | ||
| equivalence2.addExprTree(ifExpr2) | ||
|
|
||
| assert(equivalence2.getAllEquivalentExprs.count(_.size > 1) == 0) | ||
| assert(equivalence2.getAllEquivalentExprs.count(_.size == 1) == 3) | ||
|
|
||
| val ifExpr3 = If(condition, ifExpr1, ifExpr1) | ||
| val equivalence3 = new EquivalentExpressions | ||
| equivalence3.addExprTree(ifExpr3) | ||
|
|
||
| // `add`: 2, `condition`: 2 | ||
| assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 2) | ||
| assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add)) | ||
| assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).last == Seq(condition, condition)) | ||
|
|
||
| // `ifExpr1`, `ifExpr3` | ||
| assert(equivalence3.getAllEquivalentExprs.count(_.size == 1) == 2) | ||
| assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1)) | ||
| assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).last == Seq(ifExpr3)) | ||
| } | ||
|
|
||
| test("Children of conditional expressions: CaseWhen") { | ||
| val add1 = Add(Literal(1), Literal(2)) | ||
| val add2 = Add(Literal(2), Literal(3)) | ||
| val conditions1 = (GreaterThan(add2, Literal(3)), add1) :: | ||
| (GreaterThan(add2, Literal(4)), add1) :: | ||
| (GreaterThan(add2, Literal(5)), add1) :: Nil | ||
|
|
||
| val caseWhenExpr1 = CaseWhen(conditions1, None) | ||
| val equivalence1 = new EquivalentExpressions | ||
| equivalence1.addExprTree(caseWhenExpr1) | ||
|
|
||
| // `add2` is repeatedly in all conditions. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We treat the first condition specially because it is definitely run. So it counts one for For |
||
| assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1) | ||
| assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2)) | ||
|
|
||
| val conditions2 = (GreaterThan(add1, Literal(3)), add1) :: | ||
| (GreaterThan(add2, Literal(4)), add1) :: | ||
| (GreaterThan(add2, Literal(5)), add1) :: Nil | ||
|
|
||
| val caseWhenExpr2 = CaseWhen(conditions2, None) | ||
| val equivalence2 = new EquivalentExpressions | ||
| equivalence2.addExprTree(caseWhenExpr2) | ||
|
|
||
| // `add2` is repeatedly in all branch values, and first predicate. | ||
| assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 1) | ||
| assert(equivalence2.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add1, add1)) | ||
|
|
||
| val conditions3 = (GreaterThan(add1, Literal(3)), add2) :: | ||
| (GreaterThan(add2, Literal(4)), add1) :: | ||
| (GreaterThan(add2, Literal(5)), add1) :: Nil | ||
|
|
||
| val caseWhenExpr3 = CaseWhen(conditions3, None) | ||
| val equivalence3 = new EquivalentExpressions | ||
| equivalence3.addExprTree(caseWhenExpr3) | ||
| assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 0) | ||
| } | ||
|
|
||
| test("Children of conditional expressions: Coalesce") { | ||
| val add1 = Add(Literal(1), Literal(2)) | ||
| val add2 = Add(Literal(2), Literal(3)) | ||
| val conditions1 = GreaterThan(add2, Literal(3)) :: | ||
| GreaterThan(add2, Literal(4)) :: | ||
| GreaterThan(add2, Literal(5)) :: Nil | ||
|
|
||
| val coalesceExpr1 = Coalesce(conditions1) | ||
| val equivalence1 = new EquivalentExpressions | ||
| equivalence1.addExprTree(coalesceExpr1) | ||
|
|
||
| // `add2` is repeatedly in all conditions. | ||
| assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1) | ||
| assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2)) | ||
|
|
||
| val conditions2 = GreaterThan(add1, Literal(3)) :: | ||
| GreaterThan(add2, Literal(4)) :: | ||
| GreaterThan(add2, Literal(5)) :: Nil | ||
|
|
||
| val coalesceExpr2 = Coalesce(conditions2) | ||
| val equivalence2 = new EquivalentExpressions | ||
| equivalence2.addExprTree(coalesceExpr2) | ||
|
|
||
| assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 0) | ||
| } | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.