@@ -61,15 +61,22 @@ class EquivalentExpressions(
6161 private def updateExprInMap (
6262 expr : Expression ,
6363 map : mutable.HashMap [ExpressionEquals , ExpressionStats ],
64- useCount : Int = 1 ): Boolean = {
64+ useCount : Int = 1 ,
65+ conditional : Boolean = false ): Boolean = {
6566 if (expr.deterministic) {
6667 val wrapper = ExpressionEquals (expr)
6768 map.get(wrapper) match {
6869 case Some (stats) =>
69- stats.useCount += useCount
70- if (stats.useCount > 0 ) {
70+ val count = if (conditional) {
71+ stats.conditionalUseCount += useCount
72+ stats.conditionalUseCount
73+ } else {
74+ stats.useCount += useCount
75+ stats.useCount
76+ }
77+ if (count > 0 ) {
7178 true
72- } else if (stats.useCount == 0 ) {
79+ } else if (count == 0 ) {
7380 map -= wrapper
7481 false
7582 } else {
@@ -79,7 +86,12 @@ class EquivalentExpressions(
7986 }
8087 case _ =>
8188 if (useCount > 0 ) {
82- map.put(wrapper, ExpressionStats (expr)(useCount))
89+ val stats = if (conditional) {
90+ ExpressionStats (expr)(useCount = 0 , conditionalUseCount = useCount)
91+ } else {
92+ ExpressionStats (expr)(useCount)
93+ }
94+ map.put(wrapper, stats)
8395 }
8496 false
8597 }
@@ -89,44 +101,47 @@ class EquivalentExpressions(
89101 }
90102
91103 /**
92- * Adds or removes only expressions which are common in each of given expressions, in a recursive
104+ * Returns a list of expressions which are common in each of given expressions, in a recursive
93105 * way.
94106 * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common
95- * expression `(c + 1)` will be added into `equivalenceMap` .
107+ * expression `(c + 1)` will be returned .
96108 *
97109 * Note that as we don't know in advance if any child node of an expression will be common across
98110 * all given expressions, we compute local equivalence maps for all given expressions and filter
99111 * only the common nodes.
100- * Those common nodes are then removed from the local map and added to the final map of
112+ * Those common nodes are then removed from the local map and added to the final list of
101113 * expressions.
114+ *
115+ * Conditional expressions are not considered because we are simply looking for expressions
116+ * evaluated once in each parent expression.
102117 */
103- private def updateCommonExprs (
104- exprs : Seq [Expression ],
105- map : mutable.HashMap [ExpressionEquals , ExpressionStats ],
106- useCount : Int ): Unit = {
118+ private def getCommonExprs (exprs : Seq [Expression ]): Seq [ExpressionEquals ] = {
107119 assert(exprs.length > 1 )
108120 var localEquivalenceMap = mutable.HashMap .empty[ExpressionEquals , ExpressionStats ]
109- updateExprTree(exprs.head, localEquivalenceMap)
121+ updateExprTree(exprs.head, localEquivalenceMap, conditionalsEnabled = false )
110122
111123 exprs.tail.foreach { expr =>
112124 val otherLocalEquivalenceMap = mutable.HashMap .empty[ExpressionEquals , ExpressionStats ]
113- updateExprTree(expr, otherLocalEquivalenceMap)
125+ updateExprTree(expr, otherLocalEquivalenceMap, conditionalsEnabled = false )
114126 localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
115127 otherLocalEquivalenceMap.contains(key)
116128 }
117129 }
118130
131+ val commonExpressions = mutable.ListBuffer .empty[ExpressionEquals ]
132+
119133 // Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`.
120134 // The remaining highest expression in `localEquivalenceMap` is also common expression so loop
121135 // until `localEquivalenceMap` is not empty.
122136 var statsOption = Some (localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
123137 while (statsOption.nonEmpty) {
124138 val stats = statsOption.get
125- updateExprTree(stats.expr, localEquivalenceMap, - stats.useCount)
126- updateExprTree (stats.expr, map, useCount )
139+ updateExprTree(stats.expr, localEquivalenceMap, - stats.useCount, conditionalsEnabled = false )
140+ commonExpressions += ExpressionEquals (stats.expr)
127141
128142 statsOption = Some (localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
129143 }
144+ commonExpressions.toSeq
130145 }
131146
132147 private def skipForShortcut (expr : Expression ): Expression = {
@@ -143,21 +158,17 @@ class EquivalentExpressions(
143158 }
144159 }
145160
146- // There are some special expressions that we should not recurse into all of its children.
147- // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
148- // 2. ConditionalExpression: use its children that will always be evaluated.
149- private def childrenToRecurse (expr : Expression ): Seq [Expression ] = expr match {
150- case _ : CodegenFallback => Nil
151- case c : ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut)
152- case other => skipForShortcut(other).children
153- }
154-
155- // For some special expressions we cannot just recurse into all of its children, but we can
156- // recursively add the common expressions shared between all of its children.
157- private def commonChildrenToRecurse (expr : Expression ): Seq [Seq [Expression ]] = expr match {
158- case _ : CodegenFallback => Nil
159- case c : ConditionalExpression => c.branchGroups
160- case _ => Nil
161+ /**
162+ * There are some expressions that need special handling:
163+ * 1. CodegenFallback: It's children will not be used to generate code (call eval() instead).
164+ * 2. ConditionalExpression: use its children that will always be evaluated.
165+ */
166+ private def childrenToRecurse (expr : Expression ): RecurseChildren = expr match {
167+ case _ : CodegenFallback => RecurseChildren (Nil )
168+ case c : ConditionalExpression =>
169+ RecurseChildren (c.alwaysEvaluatedInputs.map(skipForShortcut), c.branchGroups,
170+ c.conditionallyEvaluatedInputs)
171+ case other => RecurseChildren (skipForShortcut(other).children)
161172 }
162173
163174 private def supportedExpression (e : Expression ): Boolean = {
@@ -184,13 +195,48 @@ class EquivalentExpressions(
184195 private def updateExprTree (
185196 expr : Expression ,
186197 map : mutable.HashMap [ExpressionEquals , ExpressionStats ] = equivalenceMap,
187- useCount : Int = 1 ): Unit = {
188- val skip = useCount == 0 || expr.isInstanceOf [LeafExpression ]
198+ useCount : Int = 1 ,
199+ conditionalsEnabled : Boolean = SQLConf .get.subexpressionEliminationConditionalsEnabled,
200+ conditional : Boolean = false ,
201+ skipExpressions : Set [ExpressionEquals ] = Set .empty[ExpressionEquals ]
202+ ): Unit = {
203+ val skip = useCount == 0 ||
204+ expr.isInstanceOf [LeafExpression ] ||
205+ skipExpressions.contains(ExpressionEquals (expr))
189206
190- if (! skip && ! updateExprInMap(expr, map, useCount)) {
207+ if (! skip && ! updateExprInMap(expr, map, useCount, conditional )) {
191208 val uc = useCount.sign
192- childrenToRecurse(expr).foreach(updateExprTree(_, map, uc))
193- commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc))
209+ val recurseChildren = childrenToRecurse(expr)
210+ recurseChildren.alwaysChildren.foreach { child =>
211+ updateExprTree(child, map, uc, conditionalsEnabled, conditional, skipExpressions)
212+ }
213+
214+ /**
215+ * If the `commonExpressions` already appears in the equivalence map, calling
216+ * `updateExprTree` will increase the `useCount` and mark it as a common subexpression.
217+ * Otherwise, `updateExprTree` will recursively add `commonExpressions` and its descendant to
218+ * the equivalence map, in case they also appear in other places. For example,
219+ * `If(a + b > 1, a + b + c, a + b + c)`, `a + b` also appears in the condition and should
220+ * be treated as common subexpression.
221+ */
222+ val commonExpressions = recurseChildren.commonChildren.flatMap { exprs =>
223+ if (exprs.nonEmpty) {
224+ getCommonExprs(exprs)
225+ } else {
226+ Nil
227+ }
228+ }
229+ commonExpressions.foreach { ce =>
230+ updateExprTree(ce.e, map, uc, conditionalsEnabled, conditional, skipExpressions)
231+ }
232+
233+ if (conditionalsEnabled) {
234+ // Add all conditional expressions, skipping those that were already counted as common
235+ // expressions.
236+ recurseChildren.conditionalChildren.foreach { cc =>
237+ updateExprTree(cc, map, uc, true , true , commonExpressions.toSet)
238+ }
239+ }
194240 }
195241 }
196242
@@ -208,7 +254,7 @@ class EquivalentExpressions(
208254
209255 // Exposed for testing.
210256 private [sql] def getAllExprStates (count : Int = 0 ): Seq [ExpressionStats ] = {
211- equivalenceMap.filter(_._2.useCount > count).toSeq.sortBy(_._1.height).map(_._2)
257+ equivalenceMap.filter(_._2.getUseCount() > count).toSeq.sortBy(_._1.height).map(_._2)
212258 }
213259
214260 /**
@@ -225,8 +271,11 @@ class EquivalentExpressions(
225271 def debugString (all : Boolean = false ): String = {
226272 val sb = new java.lang.StringBuilder ()
227273 sb.append(" Equivalent expressions:\n " )
228- equivalenceMap.values.filter(stats => all || stats.useCount > 1 ).foreach { stats =>
229- sb.append(" " ).append(s " ${stats.expr}: useCount = ${stats.useCount}" ).append('\n ' )
274+ equivalenceMap.values.filter(stats => all || stats.getUseCount() > 1 ).foreach { stats =>
275+ sb.append(" " )
276+ .append(s " ${stats.expr}: useCount = ${stats.useCount} " )
277+ .append(s " conditionalUseCount = ${stats.conditionalUseCount}" )
278+ .append('\n ' )
230279 }
231280 sb.toString()
232281 }
@@ -255,4 +304,32 @@ case class ExpressionEquals(e: Expression) {
255304 * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened"
256305 * useCount in this wrapper in-place.
257306 */
258- case class ExpressionStats (expr : Expression )(var useCount : Int )
307+ case class ExpressionStats (expr : Expression )(
308+ var useCount : Int = 1 ,
309+ var conditionalUseCount : Int = 0 ) {
310+ def getUseCount (): Int = if (useCount > 0 ) {
311+ useCount + conditionalUseCount
312+ } else {
313+ 0
314+ }
315+ }
316+
317+ /**
318+ * A wrapper for the different types of children of expressions.
319+ *
320+ * `alwaysChildren` are child expressions that will always be evaluated and should be considered
321+ * for subexpressions.
322+ *
323+ * `commonChildren` are children such that each of the children might be evaluated, but at last once
324+ * will definitely be evaluated. If there are any common expressions among them, those expressions
325+ * will definitely be evaluated and should be considered for subexpressions.
326+ *
327+ * `conditionalChildren` are children that are conditionally evaluated, such as in If, CaseWhen,
328+ * or Coalesce expressions, and should only be considered for subexpressions if they are evaluated
329+ * non-conditionally elsewhere.
330+ */
331+ case class RecurseChildren (
332+ alwaysChildren : Seq [Expression ],
333+ commonChildren : Seq [Seq [Expression ]] = Nil ,
334+ conditionalChildren : Seq [Expression ] = Nil
335+ )
0 commit comments