Skip to content

Commit 4574b30

Browse files
committed
Simplify SubExprEliminationState and add more doc.
1 parent c774797 commit 4574b30

4 files changed

Lines changed: 63 additions & 37 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ abstract class Expression extends TreeNode[Expression] {
139139
ctx.subExprEliminationExprs.get(this).map { subExprState =>
140140
// This expression is repeated which means that the code to evaluate it has already been added
141141
// as a function before. In that case, we just re-use it.
142-
ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value)
142+
ExprCode(
143+
ctx.registerComment(this.toString),
144+
subExprState.eval.isNull,
145+
subExprState.eval.value)
143146
}.getOrElse {
144147
val isNull = ctx.freshName("isNull")
145148
val value = ctx.freshName("value")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,16 @@ object ExprCode {
7676
/**
7777
* State used for subexpression elimination.
7878
*
79-
* @param code The sequence of statements required to evaluate the subexpression.
80-
* @param isNull A term that holds a boolean value representing whether the expression evaluated
81-
* to null.
82-
* @param value A term for a value of a common sub-expression. Not valid if `isNull`
83-
* is set to `true`.
84-
* @param childrenSubExprs The sequence of subexpressions as the children expressions. Before
85-
* evaluating this subexpression, we should evaluate all children
86-
* subexpressions first. This is used if we want to selectively evaluate
87-
* particular subexpressions, instead of all at once. In the case, we need
88-
* to make sure we evaluate all children subexpressions too.
79+
* @param eval The source code for evaluating the subexpression.
80+
* @param children The sequence of subexpressions as the children expressions. Before
81+
* evaluating this subexpression, we should evaluate all children
82+
* subexpressions first. This is used if we want to selectively evaluate
83+
* particular subexpressions, instead of all at once. In that case, we need
84+
* to make sure we evaluate all children subexpressions too.
8985
*/
9086
case class SubExprEliminationState(
91-
var code: Block,
92-
isNull: ExprValue,
93-
value: ExprValue,
94-
childrenSubExprs: Seq[SubExprEliminationState] = Seq.empty)
87+
eval: ExprCode,
88+
children: Seq[SubExprEliminationState] = Seq.empty)
9589

9690
/**
9791
* Codes and common subexpressions mapping used for subexpression elimination.
@@ -1046,20 +1040,47 @@ class CodegenContext extends Logging {
10461040
val code = new StringBuilder()
10471041

10481042
subExprStates.foreach { state =>
1049-
val currentCode = evaluateSubExprEliminationState(state.childrenSubExprs) + "\n" + state.code
1043+
val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code
10501044
code.append(currentCode + "\n")
1051-
state.code = EmptyBlock
1045+
state.eval.code = EmptyBlock
10521046
}
10531047

10541048
code.toString()
10551049
}
10561050

10571051
/**
1058-
* Checks and sets up the state and codegen for subexpression elimination. This finds the
1059-
* common subexpressions, generates the code snippets that evaluate those expressions and
1060-
* populates the mapping of common subexpressions to the generated code snippets. The generated
1061-
* code snippets will be returned and should be inserted into generated codes before these
1062-
* common subexpressions actually are used first time.
1052+
* Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen.
1053+
*
1054+
* This finds the common subexpressions, generates the code snippets that evaluate those
1055+
* expressions and populates the mapping of common subexpressions to the generated code snippets.
1056+
*
1057+
* The generated code snippet for subexpression is wrapped in `SubExprEliminationState`, which
1058+
* contains an `ExprCode` and the children `SubExprEliminationState` if any. The `ExprCode`
1059+
* includes java source code, result variable name and is-null variable name of the subexpression.
1060+
*
1061+
* Besides, this also returns a sequence of `ExprCode` which are expression codes that need to
1062+
* be evaluated (as their input parameters) before evaluating subexpressions.
1063+
*
1064+
* To evaluate the returned subexpressions, please call `evaluateSubExprEliminationState` with
1065+
* the `SubExprEliminationState`s to be evaluated. While generating the code, it will clean up
1066+
* the states to avoid duplicate evaluation.
1067+
*
1068+
* The details of subexpression generation:
1069+
* 1. Gets subexpression set. See `EquivalentExpressions`.
1070+
* 2. Generate code of subexpressions as a whole block of code (non-split case)
1071+
* 3. Check if the total length of the above block is larger than the split-threshold. If so,
1072+
* try to split it in step 4, otherwise return the non-split code block.
1073+
* 4. Check if parameter lengths of all subexpressions satisfy the JVM limitation, if so,
1074+
* try to split, otherwise return the non-split code block.
1075+
* 5. For each subexpression, generate a function and put the code into it. To evaluate the
1076+
* subexpression, just call the function.
1077+
*
1078+
* The explanation of subexpression codegen:
1079+
* 1. Wrapping in `withSubExprEliminationExprs` call with current subexpression map. Each
1080+
* subexpression may depend on other subexpressions (children). So when generating code
1081+
* for subexpressions, we iterate over each subexpression and put the mapping between
1082+
* (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression
1083+
* evaluation, we can look for generated subexpressions and do replacement.
10631084
*/
10641085
def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
10651086
// Create a clear EquivalentExpressions and SubExprEliminationState mapping
@@ -1086,8 +1107,7 @@ class CodegenContext extends Logging {
10861107
childrenSubExprs += subExprEliminationExprs(e)
10871108
case _ =>
10881109
}
1089-
val state = SubExprEliminationState(eval.code, eval.isNull, eval.value,
1090-
childrenSubExprs.toSeq.reverse)
1110+
val state = SubExprEliminationState(eval, childrenSubExprs.toSeq.reverse)
10911111
exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state))
10921112
allStates += state
10931113
Seq(eval)
@@ -1105,7 +1125,7 @@ class CodegenContext extends Logging {
11051125
(inputVars.toSeq, exprCodes.toSeq)
11061126
}.unzip
11071127

1108-
val needSplit = nonSplitCode.map(_.code.length).sum > SQLConf.get.methodSplitThreshold
1128+
val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold
11091129
val (subExprsMap, exprCodes) = if (needSplit) {
11101130
if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
11111131
val localSubExprEliminationExprs =
@@ -1154,7 +1174,8 @@ class CodegenContext extends Logging {
11541174

11551175
val inputVariables = inputVars.map(_.variableName).mkString(", ")
11561176
val code = code"${addNewFunction(fnName, fn)}($inputVariables);"
1157-
val state = SubExprEliminationState(code, isNull, JavaCode.global(value, expr.dataType),
1177+
val state = SubExprEliminationState(
1178+
ExprCode(code, isNull, JavaCode.global(value, expr.dataType)),
11581179
childrenSubExprs.toSeq.reverse)
11591180
exprs.foreach(localSubExprEliminationExprs.put(_, state))
11601181
}
@@ -1219,9 +1240,9 @@ class CodegenContext extends Logging {
12191240
val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
12201241
subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
12211242
val state = SubExprEliminationState(
1222-
code"$subExprCode",
1223-
JavaCode.isNullGlobal(isNull),
1224-
JavaCode.global(value, expr.dataType))
1243+
ExprCode(code"$subExprCode",
1244+
JavaCode.isNullGlobal(isNull),
1245+
JavaCode.global(value, expr.dataType)))
12251246
subExprEliminationExprs ++= e.map(_ -> state).toMap
12261247
}
12271248
}
@@ -1820,9 +1841,8 @@ object CodeGenerator extends Logging {
18201841
while (stack.nonEmpty) {
18211842
stack.pop() match {
18221843
case e if subExprs.contains(e) =>
1823-
val SubExprEliminationState(_, isNull, value, _) = subExprs(e)
1824-
collectLocalVariable(value)
1825-
collectLocalVariable(isNull)
1844+
collectLocalVariable(subExprs(e).eval.value)
1845+
collectLocalVariable(subExprs(e).eval.isNull)
18261846

18271847
case ref: BoundReference if ctx.currentVars != null &&
18281848
ctx.currentVars(ref.ordinal) != null =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,16 +463,17 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
463463
val add1 = Add(ref, ref)
464464
val add2 = Add(add1, add1)
465465
val dummy = SubExprEliminationState(
466-
EmptyBlock,
467-
JavaCode.variable("dummy", BooleanType),
468-
JavaCode.variable("dummy", BooleanType))
466+
ExprCode(EmptyBlock,
467+
JavaCode.variable("dummy", BooleanType),
468+
JavaCode.variable("dummy", BooleanType)))
469469

470470
// raw testing of basic functionality
471471
{
472472
val ctx = new CodegenContext
473473
val e = ref.genCode(ctx)
474474
// before
475-
ctx.subExprEliminationExprs += ref -> SubExprEliminationState(EmptyBlock, e.isNull, e.value)
475+
ctx.subExprEliminationExprs += ref -> SubExprEliminationState(
476+
ExprCode(EmptyBlock, e.isNull, e.value))
476477
assert(ctx.subExprEliminationExprs.contains(ref))
477478
// call withSubExprEliminationExprs
478479
ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) {

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,9 @@ case class HashAggregateExec(
258258
aggBufferUpdatingExprs: Seq[Seq[Expression]],
259259
aggCodeBlocks: Seq[Block],
260260
subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = {
261-
val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
261+
val exprValsInSubExprs = subExprs.flatMap { case (_, s) =>
262+
s.eval.value :: s.eval.isNull :: Nil
263+
}
262264
if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
263265
// `SimpleExprValue`s cannot be used as an input variable for split functions, so
264266
// we give up splitting functions if it exists in `subExprs`.

0 commit comments

Comments
 (0)