Skip to content

Commit 2c16267

Browse files
mgaido91 authored and cloud-fan committed
[SPARK-22669][SQL] Avoid unnecessary function calls in code generation
## What changes were proposed in this pull request? In many parts of the code generation codebase, we split the generated code to avoid exceptions caused by the 64KB method size limit. This generates a lot of methods which are called every time, even though sometimes this is not needed. As pointed out in apache#19752 (comment), this is a non-negligible overhead which can be avoided. This PR applies the same approach used in apache#19752 to the other places where this was feasible. ## How was this patch tested? Existing UTs. Author: Marco Gaido <mgaido@hortonworks.com> Closes apache#19860 from mgaido91/SPARK-22669.
1 parent f23dddf commit 2c16267

2 files changed

Lines changed: 140 additions & 69 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

Lines changed: 93 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -75,23 +75,51 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
7575
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
7676
ctx.addMutableState(ctx.javaType(dataType), ev.value)
7777

78+
// all the evals are meant to be in a do { ... } while (false); loop
7879
val evals = children.map { e =>
7980
val eval = e.genCode(ctx)
8081
s"""
81-
if (${ev.isNull}) {
82-
${eval.code}
83-
if (!${eval.isNull}) {
84-
${ev.isNull} = false;
85-
${ev.value} = ${eval.value};
86-
}
87-
}
88-
"""
82+
|${eval.code}
83+
|if (!${eval.isNull}) {
84+
| ${ev.isNull} = false;
85+
| ${ev.value} = ${eval.value};
86+
| continue;
87+
|}
88+
""".stripMargin
8989
}
90+
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
91+
evals.mkString("\n")
92+
} else {
93+
ctx.splitExpressions(evals, "coalesce",
94+
("InternalRow", ctx.INPUT_ROW) :: Nil,
95+
makeSplitFunction = {
96+
func =>
97+
s"""
98+
|do {
99+
| $func
100+
|} while (false);
101+
""".stripMargin
102+
},
103+
foldFunctions = { funcCalls =>
104+
funcCalls.map { funcCall =>
105+
s"""
106+
|$funcCall;
107+
|if (!${ev.isNull}) {
108+
| continue;
109+
|}
110+
""".stripMargin
111+
}.mkString
112+
})
113+
}
90114

91-
ev.copy(code = s"""
92-
${ev.isNull} = true;
93-
${ev.value} = ${ctx.defaultValue(dataType)};
94-
${ctx.splitExpressions(evals)}""")
115+
ev.copy(code =
116+
s"""
117+
|${ev.isNull} = true;
118+
|${ev.value} = ${ctx.defaultValue(dataType)};
119+
|do {
120+
| $code
121+
|} while (false);
122+
""".stripMargin)
95123
}
96124
}
97125

@@ -358,53 +386,70 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
358386

359387
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
360388
val nonnull = ctx.freshName("nonnull")
389+
// all evals are meant to be inside a do { ... } while (false); loop
361390
val evals = children.map { e =>
362391
val eval = e.genCode(ctx)
363392
e.dataType match {
364393
case DoubleType | FloatType =>
365394
s"""
366-
if ($nonnull < $n) {
367-
${eval.code}
368-
if (!${eval.isNull} && !Double.isNaN(${eval.value})) {
369-
$nonnull += 1;
370-
}
371-
}
372-
"""
395+
|if ($nonnull < $n) {
396+
| ${eval.code}
397+
| if (!${eval.isNull} && !Double.isNaN(${eval.value})) {
398+
| $nonnull += 1;
399+
| }
400+
|} else {
401+
| continue;
402+
|}
403+
""".stripMargin
373404
case _ =>
374405
s"""
375-
if ($nonnull < $n) {
376-
${eval.code}
377-
if (!${eval.isNull}) {
378-
$nonnull += 1;
379-
}
380-
}
381-
"""
406+
|if ($nonnull < $n) {
407+
| ${eval.code}
408+
| if (!${eval.isNull}) {
409+
| $nonnull += 1;
410+
| }
411+
|} else {
412+
| continue;
413+
|}
414+
""".stripMargin
382415
}
383416
}
384417

385418
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
386-
evals.mkString("\n")
387-
} else {
388-
ctx.splitExpressions(
389-
expressions = evals,
390-
funcName = "atLeastNNonNulls",
391-
arguments = ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil,
392-
returnType = "int",
393-
makeSplitFunction = { body =>
394-
s"""
395-
$body
396-
return $nonnull;
397-
"""
398-
},
399-
foldFunctions = { funcCalls =>
400-
funcCalls.map(funcCall => s"$nonnull = $funcCall;").mkString("\n")
401-
}
402-
)
403-
}
419+
evals.mkString("\n")
420+
} else {
421+
ctx.splitExpressions(
422+
expressions = evals,
423+
funcName = "atLeastNNonNulls",
424+
arguments = ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_INT, nonnull) :: Nil,
425+
returnType = ctx.JAVA_INT,
426+
makeSplitFunction = { body =>
427+
s"""
428+
|do {
429+
| $body
430+
|} while (false);
431+
|return $nonnull;
432+
""".stripMargin
433+
},
434+
foldFunctions = { funcCalls =>
435+
funcCalls.map(funcCall =>
436+
s"""
437+
|$nonnull = $funcCall;
438+
|if ($nonnull >= $n) {
439+
| continue;
440+
|}
441+
""".stripMargin).mkString("\n")
442+
}
443+
)
444+
}
404445

405-
ev.copy(code = s"""
406-
int $nonnull = 0;
407-
$code
408-
boolean ${ev.value} = $nonnull >= $n;""", isNull = "false")
446+
ev.copy(code =
447+
s"""
448+
|${ctx.JAVA_INT} $nonnull = 0;
449+
|do {
450+
| $code
451+
|} while (false);
452+
|${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
453+
""".stripMargin, isNull = "false")
409454
}
410455
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 47 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -234,36 +234,62 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
234234
}
235235

236236
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
237+
val javaDataType = ctx.javaType(value.dataType)
237238
val valueGen = value.genCode(ctx)
238239
val listGen = list.map(_.genCode(ctx))
239240
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value)
240241
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
241242
val valueArg = ctx.freshName("valueArg")
243+
// All the blocks are meant to be inside a do { ... } while (false); loop.
244+
// The evaluation of variables can be stopped when we find a matching value.
242245
val listCode = listGen.map(x =>
243246
s"""
244-
if (!${ev.value}) {
245-
${x.code}
246-
if (${x.isNull}) {
247-
${ev.isNull} = true;
248-
} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
249-
${ev.isNull} = false;
250-
${ev.value} = true;
247+
|${x.code}
248+
|if (${x.isNull}) {
249+
| ${ev.isNull} = true;
250+
|} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
251+
| ${ev.isNull} = false;
252+
| ${ev.value} = true;
253+
| continue;
254+
|}
255+
""".stripMargin)
256+
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
257+
listCode.mkString("\n")
258+
} else {
259+
ctx.splitExpressions(
260+
expressions = listCode,
261+
funcName = "valueIn",
262+
arguments = ("InternalRow", ctx.INPUT_ROW) :: (javaDataType, valueArg) :: Nil,
263+
makeSplitFunction = { body =>
264+
s"""
265+
|do {
266+
| $body
267+
|} while (false);
268+
""".stripMargin
269+
},
270+
foldFunctions = { funcCalls =>
271+
funcCalls.map(funcCall =>
272+
s"""
273+
|$funcCall;
274+
|if (${ev.value}) {
275+
| continue;
276+
|}
277+
""".stripMargin).mkString("\n")
251278
}
252-
}
253-
""")
254-
val listCodes = ctx.splitExpressions(
255-
expressions = listCode,
256-
funcName = "valueIn",
257-
extraArguments = (ctx.javaType(value.dataType), valueArg) :: Nil)
258-
ev.copy(code = s"""
259-
${valueGen.code}
260-
${ev.value} = false;
261-
${ev.isNull} = ${valueGen.isNull};
262-
if (!${ev.isNull}) {
263-
${ctx.javaType(value.dataType)} $valueArg = ${valueGen.value};
264-
$listCodes
279+
)
265280
}
266-
""")
281+
ev.copy(code =
282+
s"""
283+
|${valueGen.code}
284+
|${ev.value} = false;
285+
|${ev.isNull} = ${valueGen.isNull};
286+
|if (!${ev.isNull}) {
287+
| $javaDataType $valueArg = ${valueGen.value};
288+
| do {
289+
| $code
290+
| } while (false);
291+
|}
292+
""".stripMargin)
267293
}
268294

269295
override def sql: String = {

0 commit comments

Comments
 (0)