@@ -47,7 +47,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
4747
4848 test(" Expression Equivalence - basic" ) {
4949 val equivalence = new EquivalentExpressions
50- assert(equivalence.getAllEquivalentExprs.isEmpty)
50+ assert(equivalence.getAllEquivalentExprs() .isEmpty)
5151
5252 val oneA = Literal (1 )
5353 val oneB = Literal (1 )
@@ -72,18 +72,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
7272 assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA))
7373 assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB))
7474 assert(equivalence.getEquivalentExprs(twoA).isEmpty)
75- assert(equivalence.getAllEquivalentExprs.size == 1 )
76- assert(equivalence.getAllEquivalentExprs.head.size == 3 )
77- assert(equivalence.getAllEquivalentExprs.head.contains(oneA))
78- assert(equivalence.getAllEquivalentExprs.head.contains(oneB))
75+ assert(equivalence.getAllEquivalentExprs() .size == 1 )
76+ assert(equivalence.getAllEquivalentExprs() .head.size == 3 )
77+ assert(equivalence.getAllEquivalentExprs() .head.contains(oneA))
78+ assert(equivalence.getAllEquivalentExprs() .head.contains(oneB))
7979
8080 val add1 = Add (oneA, oneB)
8181 val add2 = Add (oneA, oneB)
8282
8383 equivalence.addExpr(add1)
8484 equivalence.addExpr(add2)
8585
86- assert(equivalence.getAllEquivalentExprs.size == 2 )
86+ assert(equivalence.getAllEquivalentExprs() .size == 2 )
8787 assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1))
8888 assert(equivalence.getEquivalentExprs(add2).size == 2 )
8989 assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2))
@@ -103,8 +103,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
103103 equivalence.addExprTree(add2)
104104
105105 // Should only have one equivalence for `one + two`
106- assert(equivalence.getAllEquivalentExprs.count(_.size > 1 ) == 1 )
107- assert(equivalence.getAllEquivalentExprs.filter(_.size > 1 ).head.size == 4 )
106+ assert(equivalence.getAllEquivalentExprs( 1 ).size == 1 )
107+ assert(equivalence.getAllEquivalentExprs( 1 ).head.size == 4 )
108108
109109 // Set up the expressions
110110 // one * two,
@@ -122,7 +122,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
122122 equivalence.addExprTree(sum)
123123
124124 // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
125- assert(equivalence.getAllEquivalentExprs.count(_.size > 1 ) == 3 )
125+ assert(equivalence.getAllEquivalentExprs( 1 ).size == 3 )
126126 assert(equivalence.getEquivalentExprs(mul).size == 3 )
127127 assert(equivalence.getEquivalentExprs(mul2).size == 3 )
128128 assert(equivalence.getEquivalentExprs(sqrt).size == 2 )
@@ -134,7 +134,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
134134 val equivalence = new EquivalentExpressions
135135 equivalence.addExpr(sum)
136136 equivalence.addExpr(sum)
137- assert(equivalence.getAllEquivalentExprs.isEmpty)
137+ assert(equivalence.getAllEquivalentExprs() .isEmpty)
138138 }
139139
140140 test(" Children of CodegenFallback" ) {
@@ -146,8 +146,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
146146 val equivalence = new EquivalentExpressions
147147 equivalence.addExprTree(add)
148148 // the `two` inside `fallback` should not be added
149- assert(equivalence.getAllEquivalentExprs.count(_.size > 1 ) == 0 )
150- assert(equivalence.getAllEquivalentExprs.count(_.size == 1 ) == 3 ) // add, two, explode
149+ assert(equivalence.getAllEquivalentExprs( 1 ).size == 0 )
150+ assert(equivalence.getAllEquivalentExprs() .count(_.size == 1 ) == 3 ) // add, two, explode
151151 }
152152
153153 test(" Children of conditional expressions: If" ) {
@@ -159,35 +159,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
159159 equivalence1.addExprTree(ifExpr1)
160160
161161 // `add` is in both two branches of `If` and predicate.
162- assert(equivalence1.getAllEquivalentExprs.count(_.size == 2 ) == 1 )
163- assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2 ).head == Seq (add, add))
162+ assert(equivalence1.getAllEquivalentExprs() .count(_.size == 2 ) == 1 )
163+ assert(equivalence1.getAllEquivalentExprs() .filter(_.size == 2 ).head == Seq (add, add))
164164 // one-time expressions: only ifExpr and its predicate expression
165- assert(equivalence1.getAllEquivalentExprs.count(_.size == 1 ) == 2 )
166- assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1 ).contains(Seq (ifExpr1)))
167- assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1 ).contains(Seq (condition)))
165+ assert(equivalence1.getAllEquivalentExprs() .count(_.size == 1 ) == 2 )
166+ assert(equivalence1.getAllEquivalentExprs() .filter(_.size == 1 ).contains(Seq (ifExpr1)))
167+ assert(equivalence1.getAllEquivalentExprs() .filter(_.size == 1 ).contains(Seq (condition)))
168168
169169 // Repeated `add` is only in one branch, so we don't count it.
170170 val ifExpr2 = If (condition, Add (Literal (1 ), Literal (3 )), Add (add, add))
171171 val equivalence2 = new EquivalentExpressions
172172 equivalence2.addExprTree(ifExpr2)
173173
174- assert(equivalence2.getAllEquivalentExprs.count(_.size > 1 ) == 0 )
175- assert(equivalence2.getAllEquivalentExprs.count(_.size == 1 ) == 3 )
174+ assert(equivalence2.getAllEquivalentExprs( 1 ).size == 0 )
175+ assert(equivalence2.getAllEquivalentExprs() .count(_.size == 1 ) == 3 )
176176
177177 val ifExpr3 = If (condition, ifExpr1, ifExpr1)
178178 val equivalence3 = new EquivalentExpressions
179179 equivalence3.addExprTree(ifExpr3)
180180
181181 // `add`: 2, `condition`: 2
182- assert(equivalence3.getAllEquivalentExprs.count(_.size == 2 ) == 2 )
183- assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2 ).contains(Seq (add, add)))
182+ assert(equivalence3.getAllEquivalentExprs() .count(_.size == 2 ) == 2 )
183+ assert(equivalence3.getAllEquivalentExprs() .filter(_.size == 2 ).contains(Seq (add, add)))
184184 assert(
185- equivalence3.getAllEquivalentExprs.filter(_.size == 2 ).contains(Seq (condition, condition)))
185+ equivalence3.getAllEquivalentExprs() .filter(_.size == 2 ).contains(Seq (condition, condition)))
186186
187187 // `ifExpr1`, `ifExpr3`
188- assert(equivalence3.getAllEquivalentExprs.count(_.size == 1 ) == 2 )
189- assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1 ).contains(Seq (ifExpr1)))
190- assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1 ).contains(Seq (ifExpr3)))
188+ assert(equivalence3.getAllEquivalentExprs() .count(_.size == 1 ) == 2 )
189+ assert(equivalence3.getAllEquivalentExprs() .filter(_.size == 1 ).contains(Seq (ifExpr1)))
190+ assert(equivalence3.getAllEquivalentExprs() .filter(_.size == 1 ).contains(Seq (ifExpr3)))
191191 }
192192
193193 test(" Children of conditional expressions: CaseWhen" ) {
@@ -202,8 +202,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
202202 equivalence1.addExprTree(caseWhenExpr1)
203203
204204 // `add2` is repeatedly in all conditions.
205- assert(equivalence1.getAllEquivalentExprs.count(_.size == 2 ) == 1 )
206- assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2 ).head == Seq (add2, add2))
205+ assert(equivalence1.getAllEquivalentExprs() .count(_.size == 2 ) == 1 )
206+ assert(equivalence1.getAllEquivalentExprs() .filter(_.size == 2 ).head == Seq (add2, add2))
207207
208208 val conditions2 = (GreaterThan (add1, Literal (3 )), add1) ::
209209 (GreaterThan (add2, Literal (4 )), add1) ::
@@ -214,8 +214,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
214214 equivalence2.addExprTree(caseWhenExpr2)
215215
216216 // `add1` is repeatedly in all branch values, and first predicate.
217- assert(equivalence2.getAllEquivalentExprs.count(_.size == 2 ) == 1 )
218- assert(equivalence2.getAllEquivalentExprs.filter(_.size == 2 ).head == Seq (add1, add1))
217+ assert(equivalence2.getAllEquivalentExprs() .count(_.size == 2 ) == 1 )
218+ assert(equivalence2.getAllEquivalentExprs() .filter(_.size == 2 ).head == Seq (add1, add1))
219219
220220 // Negative case. `add1` or `add2` is not commonly used in all predicates/branch values.
221221 val conditions3 = (GreaterThan (add1, Literal (3 )), add2) ::
@@ -225,7 +225,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
225225 val caseWhenExpr3 = CaseWhen (conditions3, None )
226226 val equivalence3 = new EquivalentExpressions
227227 equivalence3.addExprTree(caseWhenExpr3)
228- assert(equivalence3.getAllEquivalentExprs.count(_.size == 2 ) == 0 )
228+ assert(equivalence3.getAllEquivalentExprs() .count(_.size == 2 ) == 0 )
229229 }
230230
231231 test(" Children of conditional expressions: Coalesce" ) {
@@ -240,8 +240,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
240240 equivalence1.addExprTree(coalesceExpr1)
241241
242242 // `add2` is repeatedly in all conditions.
243- assert(equivalence1.getAllEquivalentExprs.count(_.size == 2 ) == 1 )
244- assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2 ).head == Seq (add2, add2))
243+ assert(equivalence1.getAllEquivalentExprs() .count(_.size == 2 ) == 1 )
244+ assert(equivalence1.getAllEquivalentExprs() .filter(_.size == 2 ).head == Seq (add2, add2))
245245
246246 // Negative case. `add1` and `add2` both are not used in all branches.
247247 val conditions2 = GreaterThan (add1, Literal (3 )) ::
@@ -252,7 +252,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
252252 val equivalence2 = new EquivalentExpressions
253253 equivalence2.addExprTree(coalesceExpr2)
254254
255- assert(equivalence2.getAllEquivalentExprs.count(_.size == 2 ) == 0 )
255+ assert(equivalence2.getAllEquivalentExprs() .count(_.size == 2 ) == 0 )
256256 }
257257
258258 test(" SPARK-34723: Correct parameter type for subexpression elimination under whole-stage" ) {
@@ -309,6 +309,35 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
309309 CodeGenerator .compile(code)
310310 }
311311 }
312+
313+ test(" SPARK-35439: Children subexpr should come first than parent subexpr" ) {
314+ val add = Add (Literal (1 ), Literal (2 ))
315+
316+ val equivalence1 = new EquivalentExpressions
317+
318+ equivalence1.addExprTree(add)
319+ assert(equivalence1.getAllEquivalentExprs().head === Seq (add))
320+
321+ equivalence1.addExprTree(Add (Literal (3 ), add))
322+ assert(equivalence1.getAllEquivalentExprs() ===
323+ Seq (Seq (add, add), Seq (Add (Literal (3 ), add))))
324+
325+ equivalence1.addExprTree(Add (Literal (3 ), add))
326+ assert(equivalence1.getAllEquivalentExprs() ===
327+ Seq (Seq (add, add), Seq (Add (Literal (3 ), add), Add (Literal (3 ), add))))
328+
329+ val equivalence2 = new EquivalentExpressions
330+
331+ equivalence2.addExprTree(Add (Literal (3 ), add))
332+ assert(equivalence2.getAllEquivalentExprs() === Seq (Seq (add), Seq (Add (Literal (3 ), add))))
333+
334+ equivalence2.addExprTree(add)
335+ assert(equivalence2.getAllEquivalentExprs() === Seq (Seq (add, add), Seq (Add (Literal (3 ), add))))
336+
337+ equivalence2.addExprTree(Add (Literal (3 ), add))
338+ assert(equivalence2.getAllEquivalentExprs() ===
339+ Seq (Seq (add, add), Seq (Add (Literal (3 ), add), Add (Literal (3 ), add))))
340+ }
312341}
313342
314343case class CodegenFallbackExpression (child : Expression )
0 commit comments