@@ -67,7 +67,8 @@ case class SubExprEliminationState(isNull: String, value: String)
6767/**
6868 * Codes and common subexpressions mapping used for subexpression elimination.
6969 *
70- * @param codes Strings representing the codes that evaluate common subexpressions.
70+ * @param codes Strings representing the codes that reset the initialization status of
71+ * common subexpression evaluation.
7172 * @param states Foreach expression that is participating in subexpression elimination,
7273 * the state to use.
7374 */
@@ -680,6 +681,47 @@ class CodegenContext {
680681 genCodes
681682 }
682683
684+ /**
685+ * A private helper function used to construct the parameter list for subexpression elimination
686+ * evaluation functions
687+ *
688+ * @param expression The subexpression to evaluate.
689+ * @param caller Indicating to construct parameter list for function caller.
690+ */
691+ private def genFunctionParamsListForSubExprEliminate (
692+ expression : Expression ,
693+ caller : Boolean ): String = {
694+ val boundRefs = expression.collect {
695+ case b : BoundReference => b
696+ }.distinct
697+ if (currentVars == null ) {
698+ if (caller) INPUT_ROW else s " InternalRow $INPUT_ROW"
699+ } else {
700+ val boundRefsInCurrentVars = boundRefs.filter(b => currentVars(b.ordinal) != null )
701+ val currentVarsParams = boundRefsInCurrentVars.map { bound =>
702+ val paramType = javaType(bound.dataType)
703+ val variable = currentVars(bound.ordinal).value
704+ val isNull = currentVars(bound.ordinal).isNull
705+ if (caller) {
706+ if (isNull == " false" ) variable else s " $variable, $isNull"
707+ } else {
708+ if (isNull == " false" ) {
709+ s " $paramType $variable"
710+ } else {
711+ s " $paramType $variable, boolean $isNull"
712+ }
713+ }
714+ }
715+
716+ if (boundRefsInCurrentVars.size == boundRefs.size) {
717+ currentVarsParams.mkString(" , " )
718+ } else {
719+ val rowParam = if (caller) INPUT_ROW else s " InternalRow $INPUT_ROW"
720+ (Seq (rowParam) ++ currentVarsParams).mkString(" , " )
721+ }
722+ }
723+ }
724+
683725 /**
684726 * Checks and sets up the state and codegen for subexpression elimination. This finds the
685727 * common subexpressions, generates the code snippets that evaluate those expressions and
@@ -700,11 +742,63 @@ class CodegenContext {
700742 val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1 )
701743 val codes = commonExprs.map { e =>
702744 val expr = e.head
745+ val fnName = freshName(" evalSubExpr" )
746+ val isNull = s " ${fnName}IsNull "
747+ val value = s " ${fnName}Value "
748+ val isInitialized = s " ${fnName}IsInitialized "
749+
750+ val functionParams = genFunctionParamsListForSubExprEliminate(expr, false )
751+ val callerParams = genFunctionParamsListForSubExprEliminate(expr, true )
752+
703753 // Generate the code for this expression tree.
704754 val code = expr.genCode(this )
705- val state = SubExprEliminationState (code.isNull, code.value)
755+ val returnType = javaType(expr.dataType)
756+ val fn =
757+ s """
758+ |private void $fnName( $functionParams) {
759+ | ${code.code.trim}
760+ | $isNull = ${code.isNull};
761+ | $value = ${code.value};
762+ | $isInitialized = true;
763+ |}
764+ """ .stripMargin
765+
766+ val valueFnName = s " ${fnName}ForValue "
767+ val valueFn =
768+ s """
769+ |private $returnType $valueFnName( $functionParams) {
770+ | if (! $isInitialized) {
771+ | $fnName( $callerParams);
772+ | }
773+ | return $value;
774+ |}
775+ """ .stripMargin
776+
777+ val isNullFnName = s " ${fnName}ForIsNull "
778+ val isNullFn =
779+ s """
780+ |private boolean $isNullFnName( $functionParams) {
781+ | if (! $isInitialized) {
782+ | $fnName( $callerParams);
783+ | }
784+ | return $isNull;
785+ |}
786+ """ .stripMargin
787+
788+ addNewFunction(fnName, fn)
789+ addNewFunction(valueFnName, valueFn)
790+ addNewFunction(isNullFnName, isNullFn)
791+
792+ addMutableState(" boolean" , isNull, s " $isNull = false; " )
793+ addMutableState(" boolean" , isInitialized, s " $isInitialized = false; " )
794+ addMutableState(returnType, value, s " $value = ${defaultValue(expr.dataType)}; " )
795+
796+ val state = SubExprEliminationState (
797+ isNull = s " $isNullFnName( $callerParams) " ,
798+ value = s " $valueFnName( $callerParams) " )
799+
706800 e.foreach(subExprEliminationExprs.put(_, state))
707- code.code.trim
801+ s " $isInitialized = false; "
708802 }
709803 SubExprCodes (codes, subExprEliminationExprs.toMap)
710804 }
0 commit comments