Skip to content

Commit a086dd5

Browse files
committed
Track conditionally evaluated expressions to resolve them as subexpressions in cases where they are already being evaluated
1 parent 38bc351 commit a086dd5

6 files changed

Lines changed: 243 additions & 96 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala

Lines changed: 117 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,22 @@ class EquivalentExpressions(
6161
private def updateExprInMap(
6262
expr: Expression,
6363
map: mutable.HashMap[ExpressionEquals, ExpressionStats],
64-
useCount: Int = 1): Boolean = {
64+
useCount: Int = 1,
65+
conditional: Boolean = false): Boolean = {
6566
if (expr.deterministic) {
6667
val wrapper = ExpressionEquals(expr)
6768
map.get(wrapper) match {
6869
case Some(stats) =>
69-
stats.useCount += useCount
70-
if (stats.useCount > 0) {
70+
val count = if (conditional) {
71+
stats.conditionalUseCount += useCount
72+
stats.conditionalUseCount
73+
} else {
74+
stats.useCount += useCount
75+
stats.useCount
76+
}
77+
if (count > 0) {
7178
true
72-
} else if (stats.useCount == 0) {
79+
} else if (count == 0) {
7380
map -= wrapper
7481
false
7582
} else {
@@ -79,7 +86,12 @@ class EquivalentExpressions(
7986
}
8087
case _ =>
8188
if (useCount > 0) {
82-
map.put(wrapper, ExpressionStats(expr)(useCount))
89+
val stats = if (conditional) {
90+
ExpressionStats(expr)(useCount = 0, conditionalUseCount = useCount)
91+
} else {
92+
ExpressionStats(expr)(useCount)
93+
}
94+
map.put(wrapper, stats)
8395
}
8496
false
8597
}
@@ -89,44 +101,47 @@ class EquivalentExpressions(
89101
}
90102

91103
/**
92-
* Adds or removes only expressions which are common in each of given expressions, in a recursive
104+
* Returns a list of expressions which are common in each of given expressions, in a recursive
93105
* way.
94106
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common
95-
* expression `(c + 1)` will be added into `equivalenceMap`.
107+
* expression `(c + 1)` will be returned.
96108
*
97109
* Note that as we don't know in advance if any child node of an expression will be common across
98110
* all given expressions, we compute local equivalence maps for all given expressions and filter
99111
* only the common nodes.
100-
* Those common nodes are then removed from the local map and added to the final map of
112+
* Those common nodes are then removed from the local map and added to the final list of
101113
* expressions.
114+
*
115+
* Conditional expressions are not considered because we are simply looking for expressions
116+
* evaluated once in each parent expression.
102117
*/
103-
private def updateCommonExprs(
104-
exprs: Seq[Expression],
105-
map: mutable.HashMap[ExpressionEquals, ExpressionStats],
106-
useCount: Int): Unit = {
118+
private def getCommonExprs(exprs: Seq[Expression]): Seq[ExpressionEquals] = {
107119
assert(exprs.length > 1)
108120
var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
109-
updateExprTree(exprs.head, localEquivalenceMap)
121+
updateExprTree(exprs.head, localEquivalenceMap, conditionalsEnabled = false)
110122

111123
exprs.tail.foreach { expr =>
112124
val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
113-
updateExprTree(expr, otherLocalEquivalenceMap)
125+
updateExprTree(expr, otherLocalEquivalenceMap, conditionalsEnabled = false)
114126
localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
115127
otherLocalEquivalenceMap.contains(key)
116128
}
117129
}
118130

131+
val commonExpressions = mutable.ListBuffer.empty[ExpressionEquals]
132+
119133
// Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`.
120134
// The remaining highest expression in `localEquivalenceMap` is also common expression so loop
121135
// until `localEquivalenceMap` is not empty.
122136
var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
123137
while (statsOption.nonEmpty) {
124138
val stats = statsOption.get
125-
updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount)
126-
updateExprTree(stats.expr, map, useCount)
139+
updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount, conditionalsEnabled = false)
140+
commonExpressions += ExpressionEquals(stats.expr)
127141

128142
statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
129143
}
144+
commonExpressions.toSeq
130145
}
131146

132147
private def skipForShortcut(expr: Expression): Expression = {
@@ -143,21 +158,17 @@ class EquivalentExpressions(
143158
}
144159
}
145160

146-
// There are some special expressions that we should not recurse into all of its children.
147-
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
148-
// 2. ConditionalExpression: use its children that will always be evaluated.
149-
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
150-
case _: CodegenFallback => Nil
151-
case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut)
152-
case other => skipForShortcut(other).children
153-
}
154-
155-
// For some special expressions we cannot just recurse into all of its children, but we can
156-
// recursively add the common expressions shared between all of its children.
157-
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
158-
case _: CodegenFallback => Nil
159-
case c: ConditionalExpression => c.branchGroups
160-
case _ => Nil
161+
/**
162+
* There are some expressions that need special handling:
163+
 * 1. CodegenFallback: Its children will not be used to generate code (call eval() instead).
164+
* 2. ConditionalExpression: use its children that will always be evaluated.
165+
*/
166+
private def childrenToRecurse(expr: Expression): RecurseChildren = expr match {
167+
case _: CodegenFallback => RecurseChildren(Nil)
168+
case c: ConditionalExpression =>
169+
RecurseChildren(c.alwaysEvaluatedInputs.map(skipForShortcut), c.branchGroups,
170+
c.conditionallyEvaluatedInputs)
171+
case other => RecurseChildren(skipForShortcut(other).children)
161172
}
162173

163174
private def supportedExpression(e: Expression): Boolean = {
@@ -184,13 +195,48 @@ class EquivalentExpressions(
184195
private def updateExprTree(
185196
expr: Expression,
186197
map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap,
187-
useCount: Int = 1): Unit = {
188-
val skip = useCount == 0 || expr.isInstanceOf[LeafExpression]
198+
useCount: Int = 1,
199+
conditionalsEnabled: Boolean = SQLConf.get.subexpressionEliminationConditionalsEnabled,
200+
conditional: Boolean = false,
201+
skipExpressions: Set[ExpressionEquals] = Set.empty[ExpressionEquals]
202+
): Unit = {
203+
val skip = useCount == 0 ||
204+
expr.isInstanceOf[LeafExpression] ||
205+
skipExpressions.contains(ExpressionEquals(expr))
189206

190-
if (!skip && !updateExprInMap(expr, map, useCount)) {
207+
if (!skip && !updateExprInMap(expr, map, useCount, conditional)) {
191208
val uc = useCount.sign
192-
childrenToRecurse(expr).foreach(updateExprTree(_, map, uc))
193-
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc))
209+
val recurseChildren = childrenToRecurse(expr)
210+
recurseChildren.alwaysChildren.foreach { child =>
211+
updateExprTree(child, map, uc, conditionalsEnabled, conditional, skipExpressions)
212+
}
213+
214+
/**
215+
* If the `commonExpressions` already appears in the equivalence map, calling
216+
* `updateExprTree` will increase the `useCount` and mark it as a common subexpression.
217+
 * Otherwise, `updateExprTree` will recursively add `commonExpressions` and their descendants to
218+
* the equivalence map, in case they also appear in other places. For example,
219+
* `If(a + b > 1, a + b + c, a + b + c)`, `a + b` also appears in the condition and should
220+
* be treated as common subexpression.
221+
*/
222+
val commonExpressions = recurseChildren.commonChildren.flatMap { exprs =>
223+
if (exprs.nonEmpty) {
224+
getCommonExprs(exprs)
225+
} else {
226+
Nil
227+
}
228+
}
229+
commonExpressions.foreach { ce =>
230+
updateExprTree(ce.e, map, uc, conditionalsEnabled, conditional, skipExpressions)
231+
}
232+
233+
if (conditionalsEnabled) {
234+
// Add all conditional expressions, skipping those that were already counted as common
235+
// expressions.
236+
recurseChildren.conditionalChildren.foreach { cc =>
237+
updateExprTree(cc, map, uc, true, true, commonExpressions.toSet)
238+
}
239+
}
194240
}
195241
}
196242

@@ -208,7 +254,7 @@ class EquivalentExpressions(
208254

209255
// Exposed for testing.
210256
private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = {
211-
equivalenceMap.filter(_._2.useCount > count).toSeq.sortBy(_._1.height).map(_._2)
257+
equivalenceMap.filter(_._2.getUseCount() > count).toSeq.sortBy(_._1.height).map(_._2)
212258
}
213259

214260
/**
@@ -225,8 +271,11 @@ class EquivalentExpressions(
225271
def debugString(all: Boolean = false): String = {
226272
val sb = new java.lang.StringBuilder()
227273
sb.append("Equivalent expressions:\n")
228-
equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats =>
229-
sb.append(" ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n')
274+
equivalenceMap.values.filter(stats => all || stats.getUseCount() > 1).foreach { stats =>
275+
sb.append(" ")
276+
.append(s"${stats.expr}: useCount = ${stats.useCount} ")
277+
.append(s"conditionalUseCount = ${stats.conditionalUseCount}")
278+
.append('\n')
230279
}
231280
sb.toString()
232281
}
@@ -255,4 +304,32 @@ case class ExpressionEquals(e: Expression) {
255304
* Instead of appending to a mutable list/buffer of Expressions, just update the "flattened"
256305
* useCount in this wrapper in-place.
257306
*/
258-
case class ExpressionStats(expr: Expression)(var useCount: Int)
307+
case class ExpressionStats(expr: Expression)(
308+
var useCount: Int = 1,
309+
var conditionalUseCount: Int = 0) {
310+
def getUseCount(): Int = if (useCount > 0) {
311+
useCount + conditionalUseCount
312+
} else {
313+
0
314+
}
315+
}
316+
317+
/**
318+
* A wrapper for the different types of children of expressions.
319+
*
320+
* `alwaysChildren` are child expressions that will always be evaluated and should be considered
321+
* for subexpressions.
322+
*
323+
 * `commonChildren` are children such that each of the children might be evaluated, but at least one
324+
* will definitely be evaluated. If there are any common expressions among them, those expressions
325+
* will definitely be evaluated and should be considered for subexpressions.
326+
*
327+
* `conditionalChildren` are children that are conditionally evaluated, such as in If, CaseWhen,
328+
* or Coalesce expressions, and should only be considered for subexpressions if they are evaluated
329+
* non-conditionally elsewhere.
330+
*/
331+
case class RecurseChildren(
332+
alwaysChildren: Seq[Expression],
333+
commonChildren: Seq[Seq[Expression]] = Nil,
334+
conditionalChildren: Seq[Expression] = Nil
335+
)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,12 @@ trait ConditionalExpression extends Expression {
572572
* so that we can eagerly evaluate the common expressions of a group.
573573
*/
574574
def branchGroups: Seq[Seq[Expression]]
575+
576+
/**
577+
* Returns children expressions which are conditionally evaluated. If the same expression
578+
* will be always evaluated elsewhere, we can make it a subexpression.
579+
*/
580+
def conditionallyEvaluatedInputs: Seq[Expression]
575581
}
576582

577583
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
6565

6666
override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue))
6767

68+
override def conditionallyEvaluatedInputs: Seq[Expression] = Seq(trueValue, falseValue)
69+
6870
final override val nodePatterns : Seq[TreePattern] = Seq(IF)
6971

7072
override def checkInputDataTypes(): TypeCheckResult = {
@@ -241,29 +243,12 @@ case class CaseWhen(
241243
}
242244

243245
override def branchGroups: Seq[Seq[Expression]] = {
244-
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
245-
// because a subexpression in conditions will be run no matter which condition is matched
246-
// if it is shared among conditions, but it doesn't need to be shared in values. Similarly,
247-
// a subexpression among values doesn't need to be in conditions because no matter which
248-
// condition is true, it will be evaluated.
249-
val conditions = if (branches.length > 1) {
250-
branches.map(_._1)
251-
} else {
252-
// If there is only one branch, the first condition is already covered by
253-
// `alwaysEvaluatedInputs` and we should exclude it here.
254-
Nil
255-
}
256-
// For an expression to be in all branch values of a CaseWhen statement, it must also be in
257-
// the elseValue.
258-
val values = if (elseValue.nonEmpty) {
259-
branches.map(_._2) ++ elseValue
260-
} else {
261-
Nil
262-
}
263-
264-
Seq(conditions, values)
246+
// If there's an else value then we will definitely evaluate at least one branch value
247+
if (elseValue.isDefined) Seq(branches.map(_._2) ++ elseValue) else Nil
265248
}
266249

250+
override def conditionallyEvaluatedInputs: Seq[Expression] = children.tail
251+
267252
override def eval(input: InternalRow): Any = {
268253
var i = 0
269254
val size = branches.size

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,9 @@ case class Coalesce(children: Seq[Expression])
7575
withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1))
7676
}
7777

78-
override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) {
79-
// If there is only one child, the first child is already covered by
80-
// `alwaysEvaluatedInputs` and we should exclude it here.
81-
Seq(children)
82-
} else {
83-
Nil
84-
}
78+
override def branchGroups: Seq[Seq[Expression]] = Nil
79+
80+
override def conditionallyEvaluatedInputs: Seq[Expression] = children.tail
8581

8682
override def eval(input: InternalRow): Any = {
8783
var result: Any = null
@@ -348,7 +344,9 @@ case class NaNvl(left: Expression, right: Expression)
348344
copy(left = alwaysEvaluatedInputs.head)
349345
}
350346

351-
override def branchGroups: Seq[Seq[Expression]] = Seq(children)
347+
override def branchGroups: Seq[Seq[Expression]] = Nil
348+
349+
override def conditionallyEvaluatedInputs: Seq[Expression] = right :: Nil
352350

353351
override def eval(input: InternalRow): Any = {
354352
val value = left.eval(input)

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,6 +1251,14 @@ object SQLConf {
12511251
.booleanConf
12521252
.createWithDefault(false)
12531253

1254+
val SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED =
1255+
buildConf("spark.sql.subexpressionElimination.conditionals.enabled")
1256+
.internal()
1257+
.doc("When true, common conditional subexpressions will be eliminated.")
1258+
.version("4.0.0")
1259+
.booleanConf
1260+
.createWithDefault(false)
1261+
12541262
val CASE_SENSITIVE = buildConf(SqlApiConfHelper.CASE_SENSITIVE_KEY)
12551263
.internal()
12561264
.doc("Whether the query analyzer should be case sensitive or not. " +
@@ -7306,6 +7314,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
73067314
def subexpressionEliminationSkipForShotcutExpr: Boolean =
73077315
getConf(SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR)
73087316

7317+
def subexpressionEliminationConditionalsEnabled: Boolean =
7318+
getConf(SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED)
7319+
73097320
def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
73107321

73117322
def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS)

0 commit comments

Comments
 (0)