@@ -76,22 +76,16 @@ object ExprCode {
7676/**
7777 * State used for subexpression elimination.
7878 *
79- * @param code The sequence of statements required to evaluate the subexpression.
80- * @param isNull A term that holds a boolean value representing whether the expression evaluated
81- * to null.
82- * @param value A term for a value of a common sub-expression. Not valid if `isNull`
83- * is set to `true`.
84- * @param childrenSubExprs The sequence of subexpressions as the children expressions. Before
85- * evaluating this subexpression, we should evaluate all children
86- * subexpressions first. This is used if we want to selectively evaluate
87- * particular subexpressions, instead of all at once. In the case, we need
88- * to make sure we evaluate all children subexpressions too.
79+ * @param eval The source code for evaluating the subexpression.
80+ * @param children The sequence of subexpressions as the children expressions. Before
81+ * evaluating this subexpression, we should evaluate all children
82+ * subexpressions first. This is used if we want to selectively evaluate
83+ * particular subexpressions, instead of all at once. In that case, we need
84+ * to make sure we evaluate all children subexpressions too.
8985 */
9086case class SubExprEliminationState (
91- var code : Block ,
92- isNull : ExprValue ,
93- value : ExprValue ,
94- childrenSubExprs : Seq [SubExprEliminationState ] = Seq .empty)
87+ eval : ExprCode ,
88+ children : Seq [SubExprEliminationState ] = Seq .empty)
9589
9690/**
9791 * Codes and common subexpressions mapping used for subexpression elimination.
@@ -1046,20 +1040,47 @@ class CodegenContext extends Logging {
10461040 val code = new StringBuilder ()
10471041
10481042 subExprStates.foreach { state =>
1049- val currentCode = evaluateSubExprEliminationState(state.childrenSubExprs ) + " \n " + state.code
1043+ val currentCode = evaluateSubExprEliminationState(state.children ) + " \n " + state.eval .code
10501044 code.append(currentCode + " \n " )
1051- state.code = EmptyBlock
1045+ state.eval. code = EmptyBlock
10521046 }
10531047
10541048 code.toString()
10551049 }
10561050
10571051 /**
1058- * Checks and sets up the state and codegen for subexpression elimination. This finds the
1059- * common subexpressions, generates the code snippets that evaluate those expressions and
1060- * populates the mapping of common subexpressions to the generated code snippets. The generated
1061- * code snippets will be returned and should be inserted into generated codes before these
1062- * common subexpressions actually are used first time.
1052+ * Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen.
1053+ *
1054+ * This finds the common subexpressions, generates the code snippets that evaluate those
1055+ * expressions and populates the mapping of common subexpressions to the generated code snippets.
1056+ *
1057+ * The generated code snippet for subexpression is wrapped in `SubExprEliminationState`, which
1058+ * contains an `ExprCode` and the children `SubExprEliminationState` if any. The `ExprCode`
1059+ * includes Java source code, result variable name and is-null variable name of the subexpression.
1060+ *
1061+ * Besides, this also returns a sequence of `ExprCode` which are expression codes that need to
1062+ * be evaluated (as their input parameters) before evaluating subexpressions.
1063+ *
1064+ * To evaluate the returned subexpressions, please call `evaluateSubExprEliminationState` with
1065+ * the `SubExprEliminationState`s to be evaluated. While generating the code, it will clean up
1066+ * the states to avoid duplicate evaluation.
1067+ *
1068+ * The details of subexpression generation:
1069+ * 1. Get the subexpression set. See `EquivalentExpressions`.
1070+ * 2. Generate code of subexpressions as a whole block of code (non-split case).
1071+ * 3. Check if the total length of the above block is larger than the split-threshold. If so,
1072+ * try to split it in step 4, otherwise return the non-split code block.
1073+ * 4. Check if parameter lengths of all subexpressions satisfy the JVM limitation. If so,
1074+ * try to split, otherwise return the non-split code block.
1075+ * 5. For each subexpression, generate a function and put the code into it. To evaluate the
1076+ * subexpression, just call the function.
1077+ *
1078+ * The explanation of subexpression codegen:
1079+ * 1. Wrapping in a `withSubExprEliminationExprs` call with the current subexpression map. Each
1080+ * subexpression may depend on other subexpressions (children). So when generating code
1081+ * for subexpressions, we iterate over each subexpression and put the mapping between
1082+ * (subexpression -> `SubExprEliminationState`) into the map. So in the next subexpression
1083+ * evaluation, we can look for generated subexpressions and do replacement.
10631084 */
10641085 def subexpressionEliminationForWholeStageCodegen (expressions : Seq [Expression ]): SubExprCodes = {
10651086 // Create a clear EquivalentExpressions and SubExprEliminationState mapping
@@ -1086,8 +1107,7 @@ class CodegenContext extends Logging {
10861107 childrenSubExprs += subExprEliminationExprs(e)
10871108 case _ =>
10881109 }
1089- val state = SubExprEliminationState (eval.code, eval.isNull, eval.value,
1090- childrenSubExprs.toSeq.reverse)
1110+ val state = SubExprEliminationState (eval, childrenSubExprs.toSeq.reverse)
10911111 exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state))
10921112 allStates += state
10931113 Seq (eval)
@@ -1105,7 +1125,7 @@ class CodegenContext extends Logging {
11051125 (inputVars.toSeq, exprCodes.toSeq)
11061126 }.unzip
11071127
1108- val needSplit = nonSplitCode.map(_.code.length).sum > SQLConf .get.methodSplitThreshold
1128+ val needSplit = nonSplitCode.map(_.eval. code.length).sum > SQLConf .get.methodSplitThreshold
11091129 val (subExprsMap, exprCodes) = if (needSplit) {
11101130 if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
11111131 val localSubExprEliminationExprs =
@@ -1154,7 +1174,8 @@ class CodegenContext extends Logging {
11541174
11551175 val inputVariables = inputVars.map(_.variableName).mkString(" , " )
11561176 val code = code " ${addNewFunction(fnName, fn)}( $inputVariables); "
1157- val state = SubExprEliminationState (code, isNull, JavaCode .global(value, expr.dataType),
1177+ val state = SubExprEliminationState (
1178+ ExprCode (code, isNull, JavaCode .global(value, expr.dataType)),
11581179 childrenSubExprs.toSeq.reverse)
11591180 exprs.foreach(localSubExprEliminationExprs.put(_, state))
11601181 }
@@ -1219,9 +1240,9 @@ class CodegenContext extends Logging {
12191240 val subExprCode = s " ${addNewFunction(fnName, fn)}( $INPUT_ROW); "
12201241 subexprFunctions += s " ${addNewFunction(fnName, fn)}( $INPUT_ROW); "
12211242 val state = SubExprEliminationState (
1222- code " $subExprCode" ,
1223- JavaCode .isNullGlobal(isNull),
1224- JavaCode .global(value, expr.dataType))
1243+ ExprCode ( code " $subExprCode" ,
1244+ JavaCode .isNullGlobal(isNull),
1245+ JavaCode .global(value, expr.dataType) ))
12251246 subExprEliminationExprs ++= e.map(_ -> state).toMap
12261247 }
12271248 }
@@ -1820,9 +1841,8 @@ object CodeGenerator extends Logging {
18201841 while (stack.nonEmpty) {
18211842 stack.pop() match {
18221843 case e if subExprs.contains(e) =>
1823- val SubExprEliminationState (_, isNull, value, _) = subExprs(e)
1824- collectLocalVariable(value)
1825- collectLocalVariable(isNull)
1844+ collectLocalVariable(subExprs(e).eval.value)
1845+ collectLocalVariable(subExprs(e).eval.isNull)
18261846
18271847 case ref : BoundReference if ctx.currentVars != null &&
18281848 ctx.currentVars(ref.ordinal) != null =>
0 commit comments