Skip to content

Commit db0cfcc

Browse files
committed
Support subexpression elimination in branches of conditional expressions.
1 parent 56c623e commit db0cfcc

File tree

3 files changed

+88
-14
lines changed

3 files changed

+88
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,46 @@ class EquivalentExpressions {
6565
}
6666
}
6767

68+
/**
69+
* Adds only expressions which are common in each of given expressions, in a recursive way.
70+
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`,
71+
* the common expression `(c + 1)` will be added into `equivalenceMap`.
72+
*/
73+
def addCommonExprs(exprs: Seq[Expression], addFunc: Expression => Boolean = addExpr): Unit = {
74+
var exprSetForAll = ExpressionSet()
75+
76+
addExprTree(exprs.head, (expr: Expression) => {
77+
if (exprSetForAll.contains(expr)) {
78+
true
79+
} else {
80+
exprSetForAll += expr
81+
false
82+
}
83+
})
84+
85+
exprs.tail.foreach { expr =>
86+
var exprSet = ExpressionSet()
87+
addExprTree(expr, (expr: Expression) => {
88+
if (exprSet.contains(expr)) {
89+
true
90+
} else {
91+
exprSet += expr
92+
false
93+
}
94+
})
95+
exprSetForAll = exprSetForAll.intersect(exprSet)
96+
}
97+
98+
exprSetForAll.foreach(addFunc)
99+
}
100+
68101
/**
69102
* Adds the expression to this data structure recursively. Stops if a matching expression
70103
* is found. That is, if `expr` has already been added, its children are not added.
71104
*/
72-
def addExprTree(expr: Expression): Unit = {
105+
def addExprTree(
106+
expr: Expression,
107+
addFunc: Expression => Boolean = addExpr): Unit = {
73108
val skip = expr.isInstanceOf[LeafExpression] ||
74109
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
75110
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
@@ -96,8 +131,21 @@ class EquivalentExpressions {
96131
case other => other.children
97132
}
98133

99-
if (!skip && !addExpr(expr)) {
100-
childrenToRecurse.foreach(addExprTree)
134+
// For some special expressions we cannot just recurse into all of its children, but we can
135+
// recursively add the common expressions shared between all of its children.
136+
def commonChildrenToRecurse: Seq[Seq[Expression]] = expr match {
137+
case i: If => Seq(Seq(i.trueValue, i.falseValue))
138+
case c: CaseWhen =>
139+
val conditions = c.branches.tail.map(_._1)
140+
val values = c.branches.map(_._2) ++ c.elseValue
141+
Seq(conditions, values)
142+
case c: Coalesce => Seq(c.children.tail)
143+
case _ => Nil
144+
}
145+
146+
if (!skip && !addFunc(expr)) {
147+
childrenToRecurse.foreach(addExprTree(_, addFunc))
148+
commonChildrenToRecurse.filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc))
101149
}
102150
}
103151

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,7 @@ class CodegenContext extends Logging {
10441044
val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
10451045

10461046
// Add each expression tree and compute the common subexpressions.
1047-
expressions.foreach(equivalentExpressions.addExprTree)
1047+
expressions.foreach(equivalentExpressions.addExprTree(_))
10481048

10491049
// Get all the expressions that appear at least twice and set up the state for subexpression
10501050
// elimination.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,17 +149,43 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
149149
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
150150
}
151151

152-
test("Children of conditional expressions") {
153-
val condition = And(Literal(true), Literal(false))
152+
test("Children of conditional expressions: If") {
154153
val add = Add(Literal(1), Literal(2))
155-
val ifExpr = If(condition, add, add)
156-
157-
val equivalence = new EquivalentExpressions
158-
equivalence.addExprTree(ifExpr)
159-
// the `add` inside `If` should not be added
160-
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
161-
// only ifExpr and its predicate expression
162-
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2)
154+
val condition = GreaterThan(add, Literal(3))
155+
156+
val ifExpr1 = If(condition, add, add)
157+
val equivalence1 = new EquivalentExpressions
158+
equivalence1.addExprTree(ifExpr1)
159+
160+
// `add` is in both two branches of `If` and predicate.
161+
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
162+
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add))
163+
// one-time expressions: only ifExpr and its predicate expression
164+
assert(equivalence1.getAllEquivalentExprs.count(_.size == 1) == 2)
165+
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1))
166+
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).last == Seq(condition))
167+
168+
// Repeated `add` is only in one branch, so we don't count it.
169+
val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add))
170+
val equivalence2 = new EquivalentExpressions
171+
equivalence2.addExprTree(ifExpr2)
172+
173+
assert(equivalence2.getAllEquivalentExprs.count(_.size > 1) == 0)
174+
assert(equivalence2.getAllEquivalentExprs.count(_.size == 1) == 3)
175+
176+
val ifExpr3 = If(condition, ifExpr1, ifExpr1)
177+
val equivalence3 = new EquivalentExpressions
178+
equivalence3.addExprTree(ifExpr3)
179+
180+
// `add`: 2, `condition`: 2
181+
assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 2)
182+
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add))
183+
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).last == Seq(condition, condition))
184+
185+
// `ifExpr1`, `ifExpr3`
186+
assert(equivalence3.getAllEquivalentExprs.count(_.size == 1) == 2)
187+
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1))
188+
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).last == Seq(ifExpr3))
163189
}
164190
}
165191

0 commit comments

Comments
 (0)