Skip to content

Commit 63accf8

Browse files
committed
Evaluate common subexpression like lazy variable with a function approach.
1 parent 3f62e1b commit 63accf8

2 files changed

Lines changed: 106 additions & 12 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ case class SubExprEliminationState(isNull: String, value: String)
6767
/**
6868
* Codes and common subexpressions mapping used for subexpression elimination.
6969
*
70-
* @param codes Strings representing the codes that evaluate common subexpressions.
70+
* @param codes Strings representing the codes that reset the initialization status of
71+
* common subexpression evaluation.
7172
* @param states For each expression that is participating in subexpression elimination,
7273
* the state to use.
7374
*/
@@ -680,6 +681,47 @@ class CodegenContext {
680681
genCodes
681682
}
682683

684+
/**
685+
* A private helper function used to construct the parameter list for subexpression elimination
686+
* evaluation functions.
687+
*
688+
* @param expression The subexpression to evaluate.
689+
* @param caller Whether to build the call-site argument list rather than the typed declaration list.
690+
*/
691+
private def genFunctionParamsListForSubExprEliminate(
692+
expression: Expression,
693+
caller: Boolean): String = {
694+
val boundRefs = expression.collect {
695+
case b: BoundReference => b
696+
}.distinct
697+
if (currentVars == null) {
698+
if (caller) INPUT_ROW else s"InternalRow $INPUT_ROW"
699+
} else {
700+
val boundRefsInCurrentVars = boundRefs.filter(b => currentVars(b.ordinal) != null)
701+
val currentVarsParams = boundRefsInCurrentVars.map { bound =>
702+
val paramType = javaType(bound.dataType)
703+
val variable = currentVars(bound.ordinal).value
704+
val isNull = currentVars(bound.ordinal).isNull
705+
if (caller) {
706+
if (isNull == "false") variable else s"$variable, $isNull"
707+
} else {
708+
if (isNull == "false") {
709+
s"$paramType $variable"
710+
} else {
711+
s"$paramType $variable, boolean $isNull"
712+
}
713+
}
714+
}
715+
716+
if (boundRefsInCurrentVars.size == boundRefs.size) {
717+
currentVarsParams.mkString(", ")
718+
} else {
719+
val rowParam = if (caller) INPUT_ROW else s"InternalRow $INPUT_ROW"
720+
(Seq(rowParam) ++ currentVarsParams).mkString(", ")
721+
}
722+
}
723+
}
724+
683725
/**
684726
* Checks and sets up the state and codegen for subexpression elimination. This finds the
685727
* common subexpressions, generates the code snippets that evaluate those expressions and
@@ -700,11 +742,63 @@ class CodegenContext {
700742
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
701743
val codes = commonExprs.map { e =>
702744
val expr = e.head
745+
val fnName = freshName("evalSubExpr")
746+
val isNull = s"${fnName}IsNull"
747+
val value = s"${fnName}Value"
748+
val isInitialized = s"${fnName}IsInitialized"
749+
750+
val functionParams = genFunctionParamsListForSubExprEliminate(expr, false)
751+
val callerParams = genFunctionParamsListForSubExprEliminate(expr, true)
752+
703753
// Generate the code for this expression tree.
704754
val code = expr.genCode(this)
705-
val state = SubExprEliminationState(code.isNull, code.value)
755+
val returnType = javaType(expr.dataType)
756+
val fn =
757+
s"""
758+
|private void $fnName($functionParams) {
759+
| ${code.code.trim}
760+
| $isNull = ${code.isNull};
761+
| $value = ${code.value};
762+
| $isInitialized = true;
763+
|}
764+
""".stripMargin
765+
766+
val valueFnName = s"${fnName}ForValue"
767+
val valueFn =
768+
s"""
769+
|private $returnType $valueFnName($functionParams) {
770+
| if (!$isInitialized) {
771+
| $fnName($callerParams);
772+
| }
773+
| return $value;
774+
|}
775+
""".stripMargin
776+
777+
val isNullFnName = s"${fnName}ForIsNull"
778+
val isNullFn =
779+
s"""
780+
|private boolean $isNullFnName($functionParams) {
781+
| if (!$isInitialized) {
782+
| $fnName($callerParams);
783+
| }
784+
| return $isNull;
785+
|}
786+
""".stripMargin
787+
788+
addNewFunction(fnName, fn)
789+
addNewFunction(valueFnName, valueFn)
790+
addNewFunction(isNullFnName, isNullFn)
791+
792+
addMutableState("boolean", isNull, s"$isNull = false;")
793+
addMutableState("boolean", isInitialized, s"$isInitialized = false;")
794+
addMutableState(returnType, value, s"$value = ${defaultValue(expr.dataType)};")
795+
796+
val state = SubExprEliminationState(
797+
isNull = s"$isNullFnName($callerParams)",
798+
value = s"$valueFnName($callerParams)")
799+
706800
e.foreach(subExprEliminationExprs.put(_, state))
707-
code.code.trim
801+
s"$isInitialized = false;"
708802
}
709803
SubExprCodes(codes, subExprEliminationExprs.toMap)
710804
}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ case class HashAggregateExec(
253253
ctx.currentVars = bufVars ++ input
254254
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
255255
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
256-
val effectiveCodes = subExprs.codes.mkString("\n")
256+
val resetSubExprEvaluation = subExprs.codes.mkString("\n")
257257
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
258258
boundUpdateExpr.map(_.genCode(ctx))
259259
}
@@ -266,8 +266,8 @@ case class HashAggregateExec(
266266
}
267267
s"""
268268
| // do aggregate
269-
| // common sub-expressions
270-
| $effectiveCodes
269+
| // reset the initialization status for common sub-expressions
270+
| $resetSubExprEvaluation
271271
| // evaluate aggregate function
272272
| ${evaluateVariables(aggVals)}
273273
| // update aggregation buffer
@@ -758,7 +758,7 @@ case class HashAggregateExec(
758758
ctx.INPUT_ROW = fastRowBuffer
759759
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
760760
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
761-
val effectiveCodes = subExprs.codes.mkString("\n")
761+
val resetSubExprEvaluation = subExprs.codes.mkString("\n")
762762
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
763763
boundUpdateExpr.map(_.genCode(ctx))
764764
}
@@ -768,8 +768,8 @@ case class HashAggregateExec(
768768
}
769769
Option(
770770
s"""
771-
|// common sub-expressions
772-
|$effectiveCodes
771+
|// reset the initialization status for common sub-expressions
772+
|$resetSubExprEvaluation
773773
|// evaluate aggregate function
774774
|${evaluateVariables(fastRowEvals)}
775775
|// update fast row
@@ -814,7 +814,7 @@ case class HashAggregateExec(
814814
ctx.INPUT_ROW = unsafeRowBuffer
815815
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
816816
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
817-
val effectiveCodes = subExprs.codes.mkString("\n")
817+
val resetSubExprEvaluation = subExprs.codes.mkString("\n")
818818
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
819819
boundUpdateExpr.map(_.genCode(ctx))
820820
}
@@ -823,8 +823,8 @@ case class HashAggregateExec(
823823
ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
824824
}
825825
s"""
826-
|// common sub-expressions
827-
|$effectiveCodes
826+
|// reset the initialization status for common sub-expressions
827+
|$resetSubExprEvaluation
828828
|// evaluate aggregate function
829829
|${evaluateVariables(unsafeRowBufferEvals)}
830830
|// update unsafe row buffer

0 commit comments

Comments
 (0)