Skip to content

Commit dbf0b50

Browse files
committed
[SPARK-35560][SQL] Remove redundant subexpression evaluation in nested subexpressions
### What changes were proposed in this pull request? This patch proposes to improve subexpression evaluation under whole-stage codegen for the cases of nested subexpressions. ### Why are the changes needed? In the cases of nested subexpressions, whole-stage codegen's subexpression elimination will do redundant subexpression evaluation. We should reduce it. For example, if we have two sub-exprs: 1. `simpleUDF($"id")` 2. `functions.length(simpleUDF($"id"))` We should only evaluate `simpleUDF($"id")` once, i.e. ```java subExpr1 = simpleUDF($"id"); subExpr2 = functions.length(subExpr1); ``` Snippets of generated codes: Before: ```java /* 040 */ private int project_subExpr_1(long project_expr_0_0) { /* 041 */ boolean project_isNull_6 = false; /* 042 */ UTF8String project_value_6 = null; /* 043 */ if (!false) { /* 044 */ project_value_6 = UTF8String.fromString(String.valueOf(project_expr_0_0)); /* 045 */ } /* 046 */ /* 047 */ Object project_arg_1 = null; /* 048 */ if (project_isNull_6) { /* 049 */ project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(null); /* 050 */ } else { /* 051 */ project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(project_value_6); /* 052 */ } /* 053 */ /* 054 */ UTF8String project_result_1 = null; /* 055 */ try { /* 056 */ project_result_1 = (UTF8String)((scala.Function1[]) references[3] /* converters */)[1].apply(((scala.Function1) references[4] /* udf */).apply(project_arg_1) ); /* 057 */ } catch (Throwable e) { /* 058 */ throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError( /* 059 */ "DataFrameSuite$$Lambda$6418/1507986601", "string", "string", e); /* 060 */ } /* 061 */ /* 062 */ boolean project_isNull_5 = project_result_1 == null; /* 063 */ UTF8String project_value_5 = null; /* 064 */ if (!project_isNull_5) { /* 065 */ project_value_5 = project_result_1; /* 066 */ } /* 067 */ boolean project_isNull_4 = project_isNull_5; /* 068 */ int project_value_4 = -1; /* 069 */ /* 070 */ if (!project_isNull_5) { /* 071 */ project_value_4 = (project_value_5).numChars(); /* 072 */ } /* 073 */ project_subExprIsNull_1 = project_isNull_4; /* 074 */ return project_value_4; /* 075 */ } ... /* 149 */ private UTF8String project_subExpr_0(long project_expr_0_0) { /* 150 */ boolean project_isNull_2 = false; /* 151 */ UTF8String project_value_2 = null; /* 152 */ if (!false) { /* 153 */ project_value_2 = UTF8String.fromString(String.valueOf(project_expr_0_0)); /* 154 */ } /* 155 */ /* 156 */ Object project_arg_0 = null; /* 157 */ if (project_isNull_2) { /* 158 */ project_arg_0 = ((scala.Function1[]) references[1] /* converters */)[0].apply(null); /* 159 */ } else { /* 160 */ project_arg_0 = ((scala.Function1[]) references[1] /* converters */)[0].apply(project_value_2); /* 161 */ } /* 162 */ /* 163 */ UTF8String project_result_0 = null; /* 164 */ try { /* 165 */ project_result_0 = (UTF8String)((scala.Function1[]) references[1] /* converters */)[1].apply(((scala.Function1) references[2] /* udf */).apply(project_arg_0) ); /* 166 */ } catch (Throwable e) { /* 167 */ throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError( /* 168 */ "DataFrameSuite$$Lambda$6418/1507986601", "string", "string", e); /* 169 */ } /* 170 */ /* 171 */ boolean project_isNull_1 = project_result_0 == null; /* 172 */ UTF8String project_value_1 = null; /* 173 */ if (!project_isNull_1) { /* 174 */ project_value_1 = project_result_0; /* 175 */ } /* 176 */ project_subExprIsNull_0 = project_isNull_1; /* 177 */ return project_value_1; /* 178 */ } ``` After: ```java /* 041 */ private void project_subExpr_1(long project_expr_0_0) { /* 042 */ boolean project_isNull_8 = project_subExprIsNull_0; /* 043 */ int project_value_8 = -1; /* 044 */ /* 045 */ if (!project_subExprIsNull_0) { /* 046 */ project_value_8 = (project_mutableStateArray_0[0]).numChars(); /* 047 */ } /* 048 */ project_subExprIsNull_1 = project_isNull_8; /* 049 */ project_subExprValue_0 = project_value_8; /* 050 */ } /* 056 */ ... /* 123 */ /* 124 */ private void project_subExpr_0(long project_expr_0_0) { /* 125 */ boolean project_isNull_6 = false; /* 126 */ UTF8String project_value_6 = null; /* 127 */ if (!false) { /* 128 */ project_value_6 = UTF8String.fromString(String.valueOf(project_expr_0_0)); /* 129 */ } /* 130 */ /* 131 */ Object project_arg_1 = null; /* 132 */ if (project_isNull_6) { /* 133 */ project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(null); /* 134 */ } else { /* 135 */ project_arg_1 = ((scala.Function1[]) references[3] /* converters */)[0].apply(project_value_6); /* 136 */ } /* 137 */ /* 138 */ UTF8String project_result_1 = null; /* 139 */ try { /* 140 */ project_result_1 = (UTF8String)((scala.Function1[]) references[3] /* converters */)[1].apply(((scala.Function1) references[4] /* udf */).apply(project_arg_1) ); /* 141 */ } catch (Throwable e) { /* 142 */ throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError( /* 143 */ "DataFrameSuite$$Lambda$6430/2004847941", "string", "string", e); /* 144 */ } /* 145 */ /* 146 */ boolean project_isNull_5 = project_result_1 == null; /* 147 */ UTF8String project_value_5 = null; /* 148 */ if (!project_isNull_5) { /* 149 */ project_value_5 = project_result_1; /* 150 */ } /* 151 */ project_subExprIsNull_0 = project_isNull_5; /* 152 */ project_mutableStateArray_0[0] = project_value_5; /* 153 */ } ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test. Closes #32699 from viirya/improve-subexpr. Authored-by: Liang-Chi Hsieh <[email protected]> Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent 9d0d4ed commit dbf0b50

File tree

2 files changed

+55
-19
lines changed

2 files changed

+55
-19
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,21 +1039,25 @@ class CodegenContext extends Logging {
10391039
def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
10401040
// Create a clear EquivalentExpressions and SubExprEliminationState mapping
10411041
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
1042-
val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
1042+
val localSubExprEliminationExprsForNonSplit =
1043+
mutable.HashMap.empty[Expression, SubExprEliminationState]
10431044

10441045
// Add each expression tree and compute the common subexpressions.
10451046
expressions.foreach(equivalentExpressions.addExprTree(_))
10461047

10471048
// Get all the expressions that appear at least twice and set up the state for subexpression
10481049
// elimination.
10491050
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
1050-
lazy val commonExprVals = commonExprs.map(_.head.genCode(this))
10511051

1052-
lazy val nonSplitExprCode = {
1053-
commonExprs.zip(commonExprVals).map { case (exprs, eval) =>
1054-
// Generate the code for this expression tree.
1055-
val state = SubExprEliminationState(eval.isNull, eval.value)
1056-
exprs.foreach(localSubExprEliminationExprs.put(_, state))
1052+
val nonSplitExprCode = {
1053+
commonExprs.map { exprs =>
1054+
val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
1055+
val eval = exprs.head.genCode(this)
1056+
// Generate the code for this expression tree.
1057+
val state = SubExprEliminationState(eval.isNull, eval.value)
1058+
exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state))
1059+
Seq(eval)
1060+
}.head
10571061
eval.code.toString
10581062
}
10591063
}
@@ -1068,11 +1072,19 @@ class CodegenContext extends Logging {
10681072
}.unzip
10691073

10701074
val splitThreshold = SQLConf.get.methodSplitThreshold
1071-
val codes = if (commonExprVals.map(_.code.length).sum > splitThreshold) {
1075+
1076+
val (codes, subExprsMap, exprCodes) = if (nonSplitExprCode.map(_.length).sum > splitThreshold) {
10721077
if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
1073-
commonExprs.zipWithIndex.map { case (exprs, i) =>
1078+
val localSubExprEliminationExprs =
1079+
mutable.HashMap.empty[Expression, SubExprEliminationState]
1080+
1081+
val splitCodes = commonExprs.zipWithIndex.map { case (exprs, i) =>
10741082
val expr = exprs.head
1075-
val eval = commonExprVals(i)
1083+
val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) {
1084+
Seq(expr.genCode(this))
1085+
}.head
1086+
1087+
val value = addMutableState(javaType(expr.dataType), "subExprValue")
10761088

10771089
val isNullLiteral = eval.isNull match {
10781090
case TrueLiteral | FalseLiteral => true
@@ -1090,34 +1102,33 @@ class CodegenContext extends Logging {
10901102
val inputVars = inputVarsForAllFuncs(i)
10911103
val argList =
10921104
inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}")
1093-
val returnType = javaType(expr.dataType)
10941105
val fn =
10951106
s"""
1096-
|private $returnType $fnName(${argList.mkString(", ")}) {
1107+
|private void $fnName(${argList.mkString(", ")}) {
10971108
| ${eval.code}
10981109
| $isNullEvalCode
1099-
| return ${eval.value};
1110+
| $value = ${eval.value};
11001111
|}
11011112
""".stripMargin
11021113

1103-
val value = freshName("subExprValue")
1104-
val state = SubExprEliminationState(isNull, JavaCode.variable(value, expr.dataType))
1114+
val state = SubExprEliminationState(isNull, JavaCode.global(value, expr.dataType))
11051115
exprs.foreach(localSubExprEliminationExprs.put(_, state))
11061116
val inputVariables = inputVars.map(_.variableName).mkString(", ")
1107-
s"$returnType $value = ${addNewFunction(fnName, fn)}($inputVariables);"
1117+
s"${addNewFunction(fnName, fn)}($inputVariables);"
11081118
}
1119+
(splitCodes, localSubExprEliminationExprs, exprCodesNeedEvaluate)
11091120
} else {
11101121
if (Utils.isTesting) {
11111122
throw QueryExecutionErrors.failedSplitSubExpressionError(MAX_JVM_METHOD_PARAMS_LENGTH)
11121123
} else {
11131124
logInfo(QueryExecutionErrors.failedSplitSubExpressionMsg(MAX_JVM_METHOD_PARAMS_LENGTH))
1114-
nonSplitExprCode
1125+
(nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty)
11151126
}
11161127
}
11171128
} else {
1118-
nonSplitExprCode
1129+
(nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty)
11191130
}
1120-
SubExprCodes(codes, localSubExprEliminationExprs.toMap, exprCodesNeedEvaluate.flatten)
1131+
SubExprCodes(codes, subExprsMap.toMap, exprCodes.flatten)
11211132
}
11221133

11231134
/**

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2882,6 +2882,31 @@ class DataFrameSuite extends QueryTest
28822882
df2.collect()
28832883
assert(accum.value == 15)
28842884
}
2885+
2886+
test("SPARK-35560: Remove redundant subexpression evaluation in nested subexpressions") {
2887+
Seq(1, Int.MaxValue).foreach { splitThreshold =>
2888+
withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> splitThreshold.toString) {
2889+
val accum = sparkContext.longAccumulator("call")
2890+
val simpleUDF = udf((s: String) => {
2891+
accum.add(1)
2892+
s
2893+
})
2894+
2895+
// Common exprs:
2896+
// 1. simpleUDF($"id")
2897+
// 2. functions.length(simpleUDF($"id"))
2898+
// We should only evaluate `simpleUDF($"id")` once, i.e.
2899+
// subExpr1 = simpleUDF($"id");
2900+
// subExpr2 = functions.length(subExpr1);
2901+
val df = spark.range(5).select(
2902+
when(functions.length(simpleUDF($"id")) === 1, lower(simpleUDF($"id")))
2903+
.when(functions.length(simpleUDF($"id")) === 0, upper(simpleUDF($"id")))
2904+
.otherwise(simpleUDF($"id")).as("output"))
2905+
df.collect()
2906+
assert(accum.value == 5)
2907+
}
2908+
}
2909+
}
28852910
}
28862911

28872912
case class GroupByKey(a: Int, b: Int)

0 commit comments

Comments
 (0)