From 4111a04e3699078a6db3d347fecf0e5d11483a3c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 16 May 2021 00:39:12 -0700 Subject: [PATCH 1/4] SubExpr elimination should not include redundant child exprs. --- .../expressions/EquivalentExpressions.scala | 11 +++++++++-- .../SubexpressionEliminationSuite.scala | 16 ++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 16 ++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 1dfff412d9a8..617e5c0f98f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -90,13 +90,20 @@ class EquivalentExpressions { val exprSetForAll = mutable.Set[Expr]() addExprTree(exprs.head, addExprToSet(_, exprSetForAll)) - val commonExprSet = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) => + val candidateExprs = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) => val otherExprSet = mutable.Set[Expr]() addExprTree(expr, addExprToSet(_, otherExprSet)) exprSet.intersect(otherExprSet) } - commonExprSet.foreach(expr => addFunc(expr.e)) + // Not all expressions in the set should be added. We should filter out the subexprs. + val commonExprSet = candidateExprs.filter { candidateExpr => + candidateExprs.forall { expr => + expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty + } + } + + commonExprSet.foreach(expr => addExprTree(expr.e, addFunc)) } // There are some special expressions that we should not recurse into all of its children. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 9bfe69b1709d..595399e467d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -309,6 +309,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel CodeGenerator.compile(code) } } + + test("SPARK-35410: SubExpr elimination should not include redundant child exprs " + + "for conditional expressions") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + val add3 = Add(add1, add2) + val condition = (GreaterThan(add3, Literal(3)), add3) :: Nil + + val caseWhenExpr = CaseWhen(condition, None) + val equivalence = new EquivalentExpressions + equivalence.addExprTree(caseWhenExpr) + + val commonExprs = equivalence.getAllEquivalentExprs.filter(_.size > 1) + assert(commonExprs.size == 1) + assert(commonExprs.head === Seq(add3, add3)) + } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d7d85d43544e..9f22770ca0ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2861,8 +2861,24 @@ class DataFrameSuite extends QueryTest checkAnswer(result, Row(0, 0, 0, 0, 100)) } } + + test("SPARK-35410: SubExpr elimination should not include redundant child exprs " + + "for conditional expressions") { + val myUdf = udf((s: String) => { + KeepCall.countOfCalls += 1 + s + }) + val df = spark.range(5).select(when(functions.length(myUdf($"id")) > 0, + 
functions.length(myUdf($"id")))) + df.collect() + assert(KeepCall.countOfCalls == 5) + } } case class GroupByKey(a: Int, b: Int) case class Bar2(s: String) + +object KeepCall { + var countOfCalls = 0 +} From ddb911eab7718af3dc8f88d5c910de16fcb4be28 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 16 May 2021 09:47:26 -0700 Subject: [PATCH 2/4] Update test. --- .../org/apache/spark/sql/DataFrameSuite.scala | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9f22770ca0ed..19c905326042 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2864,21 +2864,24 @@ class DataFrameSuite extends QueryTest test("SPARK-35410: SubExpr elimination should not include redundant child exprs " + "for conditional expressions") { - val myUdf = udf((s: String) => { - KeepCall.countOfCalls += 1 + val accum = sparkContext.longAccumulator("call") + val simpleUDF = udf((s: String) => { + accum.add(1) s }) - val df = spark.range(5).select(when(functions.length(myUdf($"id")) > 0, - functions.length(myUdf($"id")))) - df.collect() - assert(KeepCall.countOfCalls == 5) + val df1 = spark.range(5).select(when(functions.length(simpleUDF($"id")) > 0, + functions.length(simpleUDF($"id")))) + df1.collect() + assert(accum.value == 5) + + val nondeterministicUDF = simpleUDF.asNondeterministic() + val df2 = spark.range(5).select(when(functions.length(nondeterministicUDF($"id")) > 0, + functions.length(nondeterministicUDF($"id")))) + df2.collect() + assert(accum.value == 15) } } case class GroupByKey(a: Int, b: Int) case class Bar2(s: String) - -object KeepCall { - var countOfCalls = 0 -} From 01a8c02fe92a9614fca72f242c708870cde7e1b0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 16 May 2021 21:50:25 -0700 Subject: 
[PATCH 3/4] Revise comment. --- .../sql/catalyst/expressions/EquivalentExpressions.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 617e5c0f98f5..63e43c68e4bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -82,7 +82,10 @@ class EquivalentExpressions { /** * Adds only expressions which are common in each of given expressions, in a recursive way. * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, - * the common expression `(c + 1)` will be added into `equivalenceMap`. + * the common expression `(c + 1)` will be added into `equivalenceMap`. Note that if an + * expression and its child expressions are all commonly occurred in each of given expressions, + * we filter out the child expressions. For example, if `((a + b) + c)` and `(a + b)` are + * common expressions, we only add `((a + b) + c)`. */ private def addCommonExprs( exprs: Seq[Expression], From 4278e705b4b85b7c185467eabc64d60e4371489b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 18 May 2021 00:12:08 -0700 Subject: [PATCH 4/4] Update comment. 
--- .../expressions/EquivalentExpressions.scala | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 63e43c68e4bc..be8ba96a49c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -82,10 +82,14 @@ class EquivalentExpressions { /** * Adds only expressions which are common in each of given expressions, in a recursive way. * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, - * the common expression `(c + 1)` will be added into `equivalenceMap`. Note that if an - * expression and its child expressions are all commonly occurred in each of given expressions, - * we filter out the child expressions. For example, if `((a + b) + c)` and `(a + b)` are - * common expressions, we only add `((a + b) + c)`. + * the common expression `(c + 1)` will be added into `equivalenceMap`. + * + * Note that as we don't know in advance if any child node of an expression will be common + * across all given expressions, we count all child nodes when looking through the given + * expressions. But when we call `addExprTree` to add common expressions into the map, we + * will recursively add the child nodes. So we need to filter the child expressions first. + * For example, if `((a + b) + c)` and `(a + b)` are common expressions, we only add + * `((a + b) + c)`. */ private def addCommonExprs( exprs: Seq[Expression], @@ -99,7 +103,8 @@ class EquivalentExpressions { exprSet.intersect(otherExprSet) } - // Not all expressions in the set should be added. We should filter out the subexprs. + // Not all expressions in the set should be added. 
We should filter out the related + // child nodes. val commonExprSet = candidateExprs.filter { candidateExpr => candidateExprs.forall { expr => expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty