Skip to content

Commit 066944c

Browse files
committed
[SPARK-35439][SQL] Children subexpr should come first than parent subexpr
### What changes were proposed in this pull request? This patch sorts equivalent expressions based on their child-parent relation. ### Why are the changes needed? `EquivalentExpressions` maintains a map of equivalent expressions. It is `HashMap` now so the insertion order is not guaranteed to be preserved later. Subexpression elimination relies on retrieving subexpressions from the map. If there is child-parent relationships among the subexpressions, we want the child expressions come first than parent expressions, so we can replace child expressions in parent expressions with subexpression evaluation. For example, we have two different expressions `Add(Literal(1), Literal(2))` and `Add(Literal(3), add)`. Case 1: child subexpr comes first. ```scala addExprTree(add) addExprTree(Add(Literal(3), add)) addExprTree(Add(Literal(3), add)) ``` Case 2: parent subexpr comes first. For this case, we need to sort equivalent expressions. ``` addExprTree(Add(Literal(3), add)) => We add `Add(Literal(3), add)` into the map first, then add `add` into the map addExprTree(add) addExprTree(Add(Literal(3), add)) ``` As we are going to sort equivalent expressions at all, we don't need `LinkedHashMap` but just do sorting. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added tests. Closes #32586 from viirya/use-listhashmap. Authored-by: Liang-Chi Hsieh <viirya@gmail.com> Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
1 parent cc05daa commit 066944c

4 files changed

Lines changed: 91 additions & 39 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,33 @@ class EquivalentExpressions {
164164
}
165165

166166
/**
167-
* Returns all the equivalent sets of expressions.
167+
* Returns all the equivalent sets of expressions which appear more than given `repeatTimes`
168+
* times.
168169
*/
169-
def getAllEquivalentExprs: Seq[Seq[Expression]] = {
170-
equivalenceMap.values.map(_.toSeq).toSeq
170+
def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = {
171+
equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq
172+
.sortBy(_.head)(new ExpressionContainmentOrdering)
173+
}
174+
175+
/**
176+
* Orders `Expression` by parent/child relations. The child expression is smaller
177+
* than parent expression. If there is child-parent relationships among the subexpressions,
178+
* we want the child expressions come first than parent expressions, so we can replace
179+
* child expressions in parent expressions with subexpression evaluation. Note that
180+
* this is not for general expression ordering. For example, two irrelevant expressions
181+
* will be considered as e1 < e2 and e2 < e1 by this ordering. But for the usage here,
182+
* the order of irrelevant expressions does not matter.
183+
*/
184+
class ExpressionContainmentOrdering extends Ordering[Expression] {
185+
override def compare(x: Expression, y: Expression): Int = {
186+
if (x.semanticEquals(y)) {
187+
0
188+
} else if (x.find(_.semanticEquals(y)).isDefined) {
189+
1
190+
} else {
191+
-1
192+
}
193+
}
171194
}
172195

173196
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
9191

9292
val proxyMap = new IdentityHashMap[Expression, ExpressionProxy]
9393

94-
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
94+
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
9595
commonExprs.foreach { e =>
9696
val expr = e.head
9797
val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,7 @@ class CodegenContext extends Logging {
10461046

10471047
// Get all the expressions that appear at least twice and set up the state for subexpression
10481048
// elimination.
1049-
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
1049+
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
10501050
lazy val commonExprVals = commonExprs.map(_.head.genCode(this))
10511051

10521052
lazy val nonSplitExprCode = {
@@ -1131,7 +1131,7 @@ class CodegenContext extends Logging {
11311131

11321132
// Get all the expressions that appear at least twice and set up the state for subexpression
11331133
// elimination.
1134-
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
1134+
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
11351135
commonExprs.foreach { e =>
11361136
val expr = e.head
11371137
val fnName = freshName("subExpr")

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala

Lines changed: 62 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
4747

4848
test("Expression Equivalence - basic") {
4949
val equivalence = new EquivalentExpressions
50-
assert(equivalence.getAllEquivalentExprs.isEmpty)
50+
assert(equivalence.getAllEquivalentExprs().isEmpty)
5151

5252
val oneA = Literal(1)
5353
val oneB = Literal(1)
@@ -72,18 +72,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
7272
assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA))
7373
assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB))
7474
assert(equivalence.getEquivalentExprs(twoA).isEmpty)
75-
assert(equivalence.getAllEquivalentExprs.size == 1)
76-
assert(equivalence.getAllEquivalentExprs.head.size == 3)
77-
assert(equivalence.getAllEquivalentExprs.head.contains(oneA))
78-
assert(equivalence.getAllEquivalentExprs.head.contains(oneB))
75+
assert(equivalence.getAllEquivalentExprs().size == 1)
76+
assert(equivalence.getAllEquivalentExprs().head.size == 3)
77+
assert(equivalence.getAllEquivalentExprs().head.contains(oneA))
78+
assert(equivalence.getAllEquivalentExprs().head.contains(oneB))
7979

8080
val add1 = Add(oneA, oneB)
8181
val add2 = Add(oneA, oneB)
8282

8383
equivalence.addExpr(add1)
8484
equivalence.addExpr(add2)
8585

86-
assert(equivalence.getAllEquivalentExprs.size == 2)
86+
assert(equivalence.getAllEquivalentExprs().size == 2)
8787
assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1))
8888
assert(equivalence.getEquivalentExprs(add2).size == 2)
8989
assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2))
@@ -103,8 +103,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
103103
equivalence.addExprTree(add2)
104104

105105
// Should only have one equivalence for `one + two`
106-
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1)
107-
assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4)
106+
assert(equivalence.getAllEquivalentExprs(1).size == 1)
107+
assert(equivalence.getAllEquivalentExprs(1).head.size == 4)
108108

109109
// Set up the expressions
110110
// one * two,
@@ -122,7 +122,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
122122
equivalence.addExprTree(sum)
123123

124124
// (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
125-
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3)
125+
assert(equivalence.getAllEquivalentExprs(1).size == 3)
126126
assert(equivalence.getEquivalentExprs(mul).size == 3)
127127
assert(equivalence.getEquivalentExprs(mul2).size == 3)
128128
assert(equivalence.getEquivalentExprs(sqrt).size == 2)
@@ -134,7 +134,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
134134
val equivalence = new EquivalentExpressions
135135
equivalence.addExpr(sum)
136136
equivalence.addExpr(sum)
137-
assert(equivalence.getAllEquivalentExprs.isEmpty)
137+
assert(equivalence.getAllEquivalentExprs().isEmpty)
138138
}
139139

140140
test("Children of CodegenFallback") {
@@ -146,8 +146,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
146146
val equivalence = new EquivalentExpressions
147147
equivalence.addExprTree(add)
148148
// the `two` inside `fallback` should not be added
149-
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
150-
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
149+
assert(equivalence.getAllEquivalentExprs(1).size == 0)
150+
assert(equivalence.getAllEquivalentExprs().count(_.size == 1) == 3) // add, two, explode
151151
}
152152

153153
test("Children of conditional expressions: If") {
@@ -159,35 +159,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
159159
equivalence1.addExprTree(ifExpr1)
160160

161161
// `add` is in both two branches of `If` and predicate.
162-
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
163-
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add))
162+
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
163+
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add, add))
164164
// one-time expressions: only ifExpr and its predicate expression
165-
assert(equivalence1.getAllEquivalentExprs.count(_.size == 1) == 2)
166-
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr1)))
167-
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(condition)))
165+
assert(equivalence1.getAllEquivalentExprs().count(_.size == 1) == 2)
166+
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1)))
167+
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(condition)))
168168

169169
// Repeated `add` is only in one branch, so we don't count it.
170170
val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add))
171171
val equivalence2 = new EquivalentExpressions
172172
equivalence2.addExprTree(ifExpr2)
173173

174-
assert(equivalence2.getAllEquivalentExprs.count(_.size > 1) == 0)
175-
assert(equivalence2.getAllEquivalentExprs.count(_.size == 1) == 3)
174+
assert(equivalence2.getAllEquivalentExprs(1).size == 0)
175+
assert(equivalence2.getAllEquivalentExprs().count(_.size == 1) == 3)
176176

177177
val ifExpr3 = If(condition, ifExpr1, ifExpr1)
178178
val equivalence3 = new EquivalentExpressions
179179
equivalence3.addExprTree(ifExpr3)
180180

181181
// `add`: 2, `condition`: 2
182-
assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 2)
183-
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).contains(Seq(add, add)))
182+
assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 2)
183+
assert(equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(add, add)))
184184
assert(
185-
equivalence3.getAllEquivalentExprs.filter(_.size == 2).contains(Seq(condition, condition)))
185+
equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(condition, condition)))
186186

187187
// `ifExpr1`, `ifExpr3`
188-
assert(equivalence3.getAllEquivalentExprs.count(_.size == 1) == 2)
189-
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr1)))
190-
assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr3)))
188+
assert(equivalence3.getAllEquivalentExprs().count(_.size == 1) == 2)
189+
assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1)))
190+
assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr3)))
191191
}
192192

193193
test("Children of conditional expressions: CaseWhen") {
@@ -202,8 +202,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
202202
equivalence1.addExprTree(caseWhenExpr1)
203203

204204
// `add2` is repeatedly in all conditions.
205-
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
206-
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2))
205+
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
206+
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2))
207207

208208
val conditions2 = (GreaterThan(add1, Literal(3)), add1) ::
209209
(GreaterThan(add2, Literal(4)), add1) ::
@@ -214,8 +214,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
214214
equivalence2.addExprTree(caseWhenExpr2)
215215

216216
// `add1` is repeatedly in all branch values, and first predicate.
217-
assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 1)
218-
assert(equivalence2.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add1, add1))
217+
assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 1)
218+
assert(equivalence2.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add1, add1))
219219

220220
// Negative case. `add1` or `add2` is not commonly used in all predicates/branch values.
221221
val conditions3 = (GreaterThan(add1, Literal(3)), add2) ::
@@ -225,7 +225,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
225225
val caseWhenExpr3 = CaseWhen(conditions3, None)
226226
val equivalence3 = new EquivalentExpressions
227227
equivalence3.addExprTree(caseWhenExpr3)
228-
assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 0)
228+
assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 0)
229229
}
230230

231231
test("Children of conditional expressions: Coalesce") {
@@ -240,8 +240,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
240240
equivalence1.addExprTree(coalesceExpr1)
241241

242242
// `add2` is repeatedly in all conditions.
243-
assert(equivalence1.getAllEquivalentExprs.count(_.size == 2) == 1)
244-
assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add2, add2))
243+
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
244+
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2))
245245

246246
// Negative case. `add1` and `add2` both are not used in all branches.
247247
val conditions2 = GreaterThan(add1, Literal(3)) ::
@@ -252,7 +252,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
252252
val equivalence2 = new EquivalentExpressions
253253
equivalence2.addExprTree(coalesceExpr2)
254254

255-
assert(equivalence2.getAllEquivalentExprs.count(_.size == 2) == 0)
255+
assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 0)
256256
}
257257

258258
test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") {
@@ -309,6 +309,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
309309
CodeGenerator.compile(code)
310310
}
311311
}
312+
313+
test("SPARK-35439: Children subexpr should come first than parent subexpr") {
314+
val add = Add(Literal(1), Literal(2))
315+
316+
val equivalence1 = new EquivalentExpressions
317+
318+
equivalence1.addExprTree(add)
319+
assert(equivalence1.getAllEquivalentExprs().head === Seq(add))
320+
321+
equivalence1.addExprTree(Add(Literal(3), add))
322+
assert(equivalence1.getAllEquivalentExprs() ===
323+
Seq(Seq(add, add), Seq(Add(Literal(3), add))))
324+
325+
equivalence1.addExprTree(Add(Literal(3), add))
326+
assert(equivalence1.getAllEquivalentExprs() ===
327+
Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add))))
328+
329+
val equivalence2 = new EquivalentExpressions
330+
331+
equivalence2.addExprTree(Add(Literal(3), add))
332+
assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add), Seq(Add(Literal(3), add))))
333+
334+
equivalence2.addExprTree(add)
335+
assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add, add), Seq(Add(Literal(3), add))))
336+
337+
equivalence2.addExprTree(Add(Literal(3), add))
338+
assert(equivalence2.getAllEquivalentExprs() ===
339+
Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add))))
340+
}
312341
}
313342

314343
case class CodegenFallbackExpression(child: Expression)

0 commit comments

Comments
 (0)