diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 1dfff412d9a8e..364f546ffcbe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -164,10 +164,35 @@ class EquivalentExpressions {
   }
 
   /**
-   * Returns all the equivalent sets of expressions.
+   * Returns all the equivalent sets of expressions which appear more than given `repeatTimes`
+   * times.
    */
-  def getAllEquivalentExprs: Seq[Seq[Expression]] = {
-    equivalenceMap.values.map(_.toSeq).toSeq
+  def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = {
+    equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq
+      .sortBy(_.head)(new ExpressionContainmentOrdering)
+  }
+
+  /**
+   * Orders `Expression`s by the height of their expression trees, shortest first. Every
+   * expression is strictly taller than each of its own descendants, so sorting with this
+   * ordering puts a child subexpression before any parent that contains it. That lets us
+   * replace child expressions inside parent expressions with subexpression evaluation.
+   *
+   * Height is used instead of a direct containment test because it yields a consistent
+   * ordering: unrelated expressions of equal height compare as equal, rather than each
+   * being reported as smaller than the other. An inconsistent comparator violates the
+   * `Ordering` contract, and a sort using it may fail or misplace elements.
+   *
+   * NOTE(review): height is measured on the expression as written. Two semantically equal
+   * expressions with different tree shapes may differ in height; confirm this is acceptable.
+   */
+  class ExpressionContainmentOrdering extends Ordering[Expression] {
+    // Number of nodes on the longest root-to-leaf path; a childless expression has height 1.
+    private def height(e: Expression): Int =
+      e.children.foldLeft(0)((tallest, child) => math.max(tallest, height(child))) + 1
+
+    override def compare(x: Expression, y: Expression): Int =
+      Integer.compare(height(x), height(y))
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala index 0f224fefe3911..7886b657932c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala @@ -91,7 +91,7 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) { val proxyMap = new IdentityHashMap[Expression, ExpressionProxy] - val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) commonExprs.foreach { e => val expr = e.head val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f1fc718432c56..b0d9c36023ab2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1048,7 +1048,7 @@ class CodegenContext extends Logging { // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. 
- val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) lazy val commonExprVals = commonExprs.map(_.head.genCode(this)) lazy val nonSplitExprCode = { @@ -1133,7 +1133,7 @@ class CodegenContext extends Logging { // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. - val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) commonExprs.foreach { e => val expr = e.head val fnName = freshName("subExpr") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 9bfe69b1709d2..bdb08de6d7276 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -47,7 +47,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel test("Expression Equivalence - basic") { val equivalence = new EquivalentExpressions - assert(equivalence.getAllEquivalentExprs.isEmpty) + assert(equivalence.getAllEquivalentExprs().isEmpty) val oneA = Literal(1) val oneB = Literal(1) @@ -72,10 +72,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA)) assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB)) assert(equivalence.getEquivalentExprs(twoA).isEmpty) - assert(equivalence.getAllEquivalentExprs.size == 1) - assert(equivalence.getAllEquivalentExprs.head.size == 3) - assert(equivalence.getAllEquivalentExprs.head.contains(oneA)) - assert(equivalence.getAllEquivalentExprs.head.contains(oneB)) + 
assert(equivalence.getAllEquivalentExprs().size == 1) + assert(equivalence.getAllEquivalentExprs().head.size == 3) + assert(equivalence.getAllEquivalentExprs().head.contains(oneA)) + assert(equivalence.getAllEquivalentExprs().head.contains(oneB)) val add1 = Add(oneA, oneB) val add2 = Add(oneA, oneB) @@ -83,7 +83,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExpr(add1) equivalence.addExpr(add2) - assert(equivalence.getAllEquivalentExprs.size == 2) + assert(equivalence.getAllEquivalentExprs().size == 2) assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1)) assert(equivalence.getEquivalentExprs(add2).size == 2) assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2)) @@ -103,8 +103,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(add2) // Should only have one equivalence for `one + two` - assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1) - assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4) + assert(equivalence.getAllEquivalentExprs(1).size == 1) + assert(equivalence.getAllEquivalentExprs(1).head.size == 4) // Set up the expressions // one * two, @@ -122,7 +122,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence.addExprTree(sum) // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found - assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3) + assert(equivalence.getAllEquivalentExprs(1).size == 3) assert(equivalence.getEquivalentExprs(mul).size == 3) assert(equivalence.getEquivalentExprs(mul2).size == 3) assert(equivalence.getEquivalentExprs(sqrt).size == 2) @@ -134,7 +134,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence = new EquivalentExpressions equivalence.addExpr(sum) equivalence.addExpr(sum) - assert(equivalence.getAllEquivalentExprs.isEmpty) + 
assert(equivalence.getAllEquivalentExprs().isEmpty) } test("Children of CodegenFallback") { @@ -146,8 +146,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence = new EquivalentExpressions equivalence.addExprTree(add) // the `two` inside `fallback` should not be added - assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) - assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode + assert(equivalence.getAllEquivalentExprs(1).size == 0) + assert(equivalence.getAllEquivalentExprs().count(_.size == 1) == 3) // add, two, explode } test("Children of conditional expressions: If") { @@ -159,35 +159,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence1.addExprTree(ifExpr1) // `add` is in both two branches of `If` and predicate. - assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1) - assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add)) + assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1) + assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add, add)) // one-time expressions: only ifExpr and its predicate expression - assert(equivalence1.getAllEquivalentExprs.count(_.size == 1) == 2) - assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr1))) - assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(condition))) + assert(equivalence1.getAllEquivalentExprs().count(_.size == 1) == 2) + assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1))) + assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(condition))) // Repeated `add` is only in one branch, so we don't count it. 
val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add)) val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(ifExpr2) - assert(equivalence2.getAllEquivalentExprs.count(_.size > 1) == 0) - assert(equivalence2.getAllEquivalentExprs.count(_.size == 1) == 3) + assert(equivalence2.getAllEquivalentExprs(1).size == 0) + assert(equivalence2.getAllEquivalentExprs().count(_.size == 1) == 3) val ifExpr3 = If(condition, ifExpr1, ifExpr1) val equivalence3 = new EquivalentExpressions equivalence3.addExprTree(ifExpr3) // `add`: 2, `condition`: 2 - assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 2) - assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).contains(Seq(add, add))) + assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 2) + assert(equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(add, add))) assert( - equivalence3.getAllEquivalentExprs.filter(_.size == 2).contains(Seq(condition, condition))) + equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(condition, condition))) // `ifExpr1`, `ifExpr3` - assert(equivalence3.getAllEquivalentExprs.count(_.size == 1) == 2) - assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr1))) - assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr3))) + assert(equivalence3.getAllEquivalentExprs().count(_.size == 1) == 2) + assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1))) + assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr3))) } test("Children of conditional expressions: CaseWhen") { @@ -202,8 +202,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence1.addExprTree(caseWhenExpr1) // `add2` is repeatedly in all conditions. 
- assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1) - assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2)) + assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1) + assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2)) val conditions2 = (GreaterThan(add1, Literal(3)), add1) :: (GreaterThan(add2, Literal(4)), add1) :: @@ -214,8 +214,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence2.addExprTree(caseWhenExpr2) // `add1` is repeatedly in all branch values, and first predicate. - assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 1) - assert(equivalence2.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add1, add1)) + assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 1) + assert(equivalence2.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add1, add1)) // Negative case. `add1` or `add2` is not commonly used in all predicates/branch values. val conditions3 = (GreaterThan(add1, Literal(3)), add2) :: @@ -225,7 +225,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val caseWhenExpr3 = CaseWhen(conditions3, None) val equivalence3 = new EquivalentExpressions equivalence3.addExprTree(caseWhenExpr3) - assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 0) + assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 0) } test("Children of conditional expressions: Coalesce") { @@ -240,8 +240,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel equivalence1.addExprTree(coalesceExpr1) // `add2` is repeatedly in all conditions. 
- assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1) - assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2)) + assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1) + assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2)) // Negative case. `add1` and `add2` both are not used in all branches. val conditions2 = GreaterThan(add1, Literal(3)) :: @@ -252,7 +252,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(coalesceExpr2) - assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 0) + assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 0) } test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") { @@ -309,6 +309,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel CodeGenerator.compile(code) } } + + test("SPARK-35439: Children subexpr should come first than parent subexpr") { + val add = Add(Literal(1), Literal(2)) + + val equivalence1 = new EquivalentExpressions + + equivalence1.addExprTree(add) + assert(equivalence1.getAllEquivalentExprs().head === Seq(add)) + + equivalence1.addExprTree(Add(Literal(3), add)) + assert(equivalence1.getAllEquivalentExprs() === + Seq(Seq(add, add), Seq(Add(Literal(3), add)))) + + equivalence1.addExprTree(Add(Literal(3), add)) + assert(equivalence1.getAllEquivalentExprs() === + Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add)))) + + val equivalence2 = new EquivalentExpressions + + equivalence2.addExprTree(Add(Literal(3), add)) + assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add), Seq(Add(Literal(3), add)))) + + equivalence2.addExprTree(add) + assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add, add), Seq(Add(Literal(3), add)))) + + equivalence2.addExprTree(Add(Literal(3), add)) + 
assert(equivalence2.getAllEquivalentExprs() === + Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add)))) + } } case class CodegenFallbackExpression(child: Expression)