Skip to content

Commit 9e1b204

Browse files
viirya authored and dongjoon-hyun committed
[SPARK-35410][SQL] SubExpr elimination should not include redundant children exprs in conditional expression
### What changes were proposed in this pull request? This patch fixes a bug in how subexpression elimination handles common expressions inside conditional expressions such as `CaseWhen`. For example, previously we found the common expressions among the conditions of `CaseWhen`, but their child expressions were also counted as common expressions. We should not count these child expressions as common expressions. ### Why are the changes needed? If the redundant child expressions are also counted as common expressions, they will be evaluated redundantly and the subexpression elimination opportunity will be missed. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added tests. Closes #32559 from viirya/SPARK-35410. Authored-by: Liang-Chi Hsieh <viirya@gmail.com> Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
1 parent d5868eb commit 9e1b204

3 files changed

Lines changed: 52 additions & 2 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,35 @@ class EquivalentExpressions {
8383
* Adds only expressions which are common in each of given expressions, in a recursive way.
8484
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`,
8585
* the common expression `(c + 1)` will be added into `equivalenceMap`.
86+
*
87+
* Note that as we don't know in advance if any child node of an expression will be common
88+
* across all given expressions, we count all child nodes when looking through the given
89+
* expressions. But when we call `addExprTree` to add common expressions into the map, we
90+
* will recursively add the child nodes. So we need to filter out the child expressions first.
91+
* For example, if `((a + b) + c)` and `(a + b)` are common expressions, we only add
92+
* `((a + b) + c)`.
8693
*/
8794
private def addCommonExprs(
8895
exprs: Seq[Expression],
8996
addFunc: Expression => Boolean = addExpr): Unit = {
9097
val exprSetForAll = mutable.Set[Expr]()
9198
addExprTree(exprs.head, addExprToSet(_, exprSetForAll))
9299

93-
val commonExprSet = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
100+
val candidateExprs = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
94101
val otherExprSet = mutable.Set[Expr]()
95102
addExprTree(expr, addExprToSet(_, otherExprSet))
96103
exprSet.intersect(otherExprSet)
97104
}
98105

99-
commonExprSet.foreach(expr => addFunc(expr.e))
106+
// Not all expressions in the set should be added. We should filter out the related
107+
// children nodes.
108+
val commonExprSet = candidateExprs.filter { candidateExpr =>
109+
candidateExprs.forall { expr =>
110+
expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty
111+
}
112+
}
113+
114+
commonExprSet.foreach(expr => addExprTree(expr.e, addFunc))
100115
}
101116

102117
// There are some special expressions that we should not recurse into all of its children.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
310310
}
311311
}
312312

313+
test("SPARK-35410: SubExpr elimination should not include redundant child exprs " +
314+
"for conditional expressions") {
315+
val add1 = Add(Literal(1), Literal(2))
316+
val add2 = Add(Literal(2), Literal(3))
317+
val add3 = Add(add1, add2)
318+
val condition = (GreaterThan(add3, Literal(3)), add3) :: Nil
319+
320+
val caseWhenExpr = CaseWhen(condition, None)
321+
val equivalence = new EquivalentExpressions
322+
equivalence.addExprTree(caseWhenExpr)
323+
324+
val commonExprs = equivalence.getAllEquivalentExprs(1)
325+
assert(commonExprs.size == 1)
326+
assert(commonExprs.head === Seq(add3, add3))
327+
}
328+
313329
test("SPARK-35439: Children subexpr should come first than parent subexpr") {
314330
val add = Add(Literal(1), Literal(2))
315331

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2861,6 +2861,25 @@ class DataFrameSuite extends QueryTest
28612861
checkAnswer(result, Row(0, 0, 0, 0, 100))
28622862
}
28632863
}
2864+
2865+
test("SPARK-35410: SubExpr elimination should not include redundant child exprs " +
2866+
"for conditional expressions") {
2867+
val accum = sparkContext.longAccumulator("call")
2868+
val simpleUDF = udf((s: String) => {
2869+
accum.add(1)
2870+
s
2871+
})
2872+
val df1 = spark.range(5).select(when(functions.length(simpleUDF($"id")) > 0,
2873+
functions.length(simpleUDF($"id"))))
2874+
df1.collect()
2875+
assert(accum.value == 5)
2876+
2877+
val nondeterministicUDF = simpleUDF.asNondeterministic()
2878+
val df2 = spark.range(5).select(when(functions.length(nondeterministicUDF($"id")) > 0,
2879+
functions.length(nondeterministicUDF($"id"))))
2880+
df2.collect()
2881+
assert(accum.value == 15)
2882+
}
28642883
}
28652884

28662885
case class GroupByKey(a: Int, b: Int)

0 commit comments

Comments
 (0)