From f866b110a27e4405c1e51fd0684d6d22e19ed2a6 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 28 Aug 2017 21:46:00 +0900 Subject: [PATCH 1/4] Split aggregation into small functions --- .../expressions/codegen/CodeGenerator.scala | 24 +++ .../expressions/CodeGenerationSuite.scala | 18 ++ .../aggregate/HashAggregateExec.scala | 185 +++++++++++++++--- .../execution/WholeStageCodegenSuite.scala | 16 +- 4 files changed, 215 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5c9e604a8d29..cb94f31e5a3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.io.ByteArrayInputStream +import java.lang.Character._ import java.util.{Map => JavaMap} import scala.collection.JavaConverters._ @@ -1103,6 +1104,29 @@ class CodegenContext { } } +object CodegenContext { + + private val javaKeywords = Set( + "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class", "const", + "continue", "default", "do", "double", "else", "extends", "false", "final", "finally", "float", + "for", "goto", "if", "implements", "import", "instanceof", "int", "interface", "long", "native", + "new", "null", "package", "private", "protected", "public", "return", "short", "static", + "strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true", + "try", "void", "volatile", "while" + ) + + /** + * Returns true if the given `str` is a valid java identifier. + */ + def isJavaIdentifier(str: String): Boolean = str match { + case null | "" => + false + case _ => + !javaKeywords.contains(str) && isJavaIdentifierStart(str.charAt(0)) && + (1 until str.length).forall(i => isJavaIdentifierPart(str.charAt(i))) + } +} + /** * A wrapper for generated class, defines a `generate` method so that we can pass extra objects * into generated class. 
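The `isJavaIdentifier` helper added above is what later lets HashAggregateExec tell generated variable names (e.g. `agg_value1`), which must be passed into the split methods as parameters, apart from literal fragments (e.g. `true`, `-1`, or `"abc"`) that `ExprCode.value` and `ExprCode.isNull` can also hold. A minimal standalone sketch of the same check (plain Scala, with a trimmed keyword set for brevity; the patch above carries the full list):

    import java.lang.Character.{isJavaIdentifierPart, isJavaIdentifierStart}

    object IdentifierCheck {
      // Trimmed set for illustration; the patch registers all Java keywords
      // plus the literals "true", "false" and "null".
      private val javaKeywords =
        Set("boolean", "double", "false", "int", "long", "null", "true", "while")

      def isJavaIdentifier(str: String): Boolean = str match {
        case null | "" => false
        case s =>
          !javaKeywords.contains(s) &&
            isJavaIdentifierStart(s.charAt(0)) &&
            s.drop(1).forall(isJavaIdentifierPart)
      }

      def main(args: Array[String]): Unit = {
        assert(isJavaIdentifier("agg_value1"))   // a generated variable name
        assert(!isJavaIdentifier("true"))        // boolean literal, not a variable
        assert(!isJavaIdentifier("390239"))      // int literal
        assert(!isJavaIdentifier("\"literal\"")) // quoted string literal
      }
    }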
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 40bf29bb3b57..f67eefab4d1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -394,4 +394,22 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { Map("add" -> Literal(1))).genCode(ctx) assert(ctx.mutableStates.isEmpty) } + + test("SPARK-21870 check if CodegenContext.isJavaIdentifier works correctly") { + import CodegenContext.isJavaIdentifier + // positive cases + assert(isJavaIdentifier("agg_value")) + assert(isJavaIdentifier("agg_value1")) + assert(isJavaIdentifier("bhj_value4")) + assert(isJavaIdentifier("smj_value6")) + assert(isJavaIdentifier("rdd_value7")) + assert(isJavaIdentifier("scan_isNull")) + assert(isJavaIdentifier("test")) + // negative cases + assert(!isJavaIdentifier("true")) + assert(!isJavaIdentifier("false")) + assert(!isJavaIdentifier("390239")) + assert(!isJavaIdentifier(""""literal"""")) + assert(!isJavaIdentifier(""""double"""")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9cadd13999e7..6bb932b150ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable + import org.apache.spark.TaskContext import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD @@ -257,6 +259,85 @@ case class HashAggregateExec( """.stripMargin } + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { + // `argSet` collects all the pairs of variable types and names; the first element in each pair + // is the type name and the second is the variable name. 
+ val argSet = mutable.Set[(String, String)]() + val stack = mutable.Stack[Expression](aggExpr) + while (stack.nonEmpty) { + stack.pop() match { + case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { + argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { + argSet += (("boolean", exprCode.isNull)) + } + // Since the children may have common expressions, we push them here + stack.pushAll(e.children) + case ref: BoundReference + if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { + argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { + argSet += (("boolean", isNull)) + } + case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) + case e => + stack.pushAll(e.children) + } + } + + argSet.toSet + } + + // Splits aggregate code into small functions because the JVM does not JIT-compile functions that are too long + private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { + aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameter slots in a non-static Java method is 254, and a parameter of + // type long or double takes two slots. So, this method gives up + // splitting the code if the number of parameters goes over 127. + val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq + + // This is for testing/benchmarking only + val maxParamNumInJavaMethod = + sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match { + case null | "" => 127 + case param => param.toInt + } + if (args.size <= maxParamNumInJavaMethod) { + val doAggVal = ctx.freshName(s"doAggregateVal_${aggExpr.prettyName}") + val argList = args.map(a => s"${a._1} ${a._2}").mkString(", ") + val doAggValFuncName = ctx.addNewFunction(doAggVal, + s""" + | private void $doAggVal($argList) throws java.io.IOException { + | ${evalAndUpdateCodes(i)} + | } + """.stripMargin) + + val inputVariables = args.map(_._2).mkString(", ") + s"$doAggValFuncName($inputVariables);" + } else { + evalAndUpdateCodes(i) + } + } + } + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output val updateExpr = aggregateExpressions.flatMap { e => e.mode match { case Partial | Complete => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions case PartialMerge | Final => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } - ctx.currentVars = bufVars ++ input + + // We need to copy the aggregation buffer to local variables first because each aggregate + // function directly updates the buffer when it finishes. 
+ val localBufVars = bufVars.zip(updateExpr).map { case (ev, e) => + val isNull = ctx.freshName("localBufIsNull") + val value = ctx.freshName("localBufValue") + val initLocalVars = s""" + | boolean $isNull = ${ev.isNull}; + | ${ctx.javaType(e.dataType)} $value = ${ev.value}; + """.stripMargin + ExprCode(initLocalVars, isNull, value) + } + + val initLocalBufVar = evaluateVariables(localBufVars) + + ctx.currentVars = localBufVars ++ input val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } - // aggregate buffer should be updated atomic - val updates = aggVals.zipWithIndex.map { case (ev, i) => + + val evalAndUpdateCodes = aggVals.zipWithIndex.map { case (ev, i) => s""" + | // evaluate aggregate function + | ${ev.code} + | // update aggregation buffer | ${bufVars(i).isNull} = ${ev.isNull}; | ${bufVars(i).value} = ${ev.value}; """.stripMargin } + + val updateAggValCode = splitAggregateExpressions( + ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states) + s""" | // do aggregate + | // copy the aggregation buffer to local variables + | $initLocalBufVar | // common sub-expressions | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(aggVals)} - | // update aggregation buffer - | ${updates.mkString("\n").trim} + | // process aggregate functions to update aggregation buffer + | ${updateAggValCode.mkString("\n")} """.stripMargin } @@ -825,52 +928,86 @@ case class HashAggregateExec( ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input val updateRowInRegularHashMap: String = { - ctx.INPUT_ROW = unsafeRowBuffer + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. 
+ val localRowBuffer = ctx.freshName("localUnsafeRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();" + + ctx.INPUT_ROW = localRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + + val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + s""" + | // evaluate aggregate function + | ${ev.code} + | // update unsafe row buffer + | $updateColumnCode + """.stripMargin } + + val updateAggValCode = splitAggregateExpressions( + ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, + Seq(("InternalRow", unsafeRowBuffer))) + s""" - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate function - |${evaluateVariables(unsafeRowBufferEvals)} - |// update unsafe row buffer - |${updateUnsafeRowBuffer.mkString("\n").trim} + | // do aggregate + | // copy the aggregation row buffer to a local row + | $initLocalRowBuffer + | // common sub-expressions + | $effectiveCodes + | // process aggregate functions to update aggregation buffer + | ${updateAggValCode.mkString("\n")} """.stripMargin } val updateRowInHashMap: String = { if (isFastHashMapEnabled) { - ctx.INPUT_ROW = fastRowBuffer + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. + val localRowBuffer = ctx.freshName("localFastRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();" + + ctx.INPUT_ROW = localRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } - val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => + + val evalAndUpdateCodes = fastRowEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn( - fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) + val updateColumnCode = ctx.updateColumn( + fastRowBuffer, dt, i, ev, updateExpr(i).nullable) + s""" + | // evaluate aggregate function + | ${ev.code} + | // update fast row + | $updateColumnCode + """.stripMargin } + val updateAggValCode = splitAggregateExpressions( + ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, + Seq(("InternalRow", fastRowBuffer))) + // If fast hash map is on, we first generate code to update row in fast hash map, if the + // previous lookup hit the fast hash map. Otherwise, update row in regular hash map. 
s""" |if ($fastRowBuffer != null) { + | // copy the aggregation row buffer to a local row + | $initLocalRowBuffer | // common sub-expressions | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(fastRowEvals)} - | // update fast row - | ${updateFastRow.mkString("\n").trim} + | // process aggregate functions to update aggregation buffer + | ${updateAggValCode.mkString("\n")} |} else { | $updateRowInRegularHashMap |} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca578c4..1143f6c50138 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -211,11 +211,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("SPARK-21871 check if we can get large code size when compiling too long functions") { val codeWithShortFunctions = genGroupByCode(3) - val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) - assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) + val (_, smallCodeSize) = CodeGenerator.compile(codeWithShortFunctions) val codeWithLongFunctions = genGroupByCode(20) - val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) - assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) + val (_, largeCodeSize) = CodeGenerator.compile(codeWithLongFunctions) + // Just check that longer functions yield a larger max code size + assert(largeCodeSize > smallCodeSize) } test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { @@ -236,4 +236,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21870 check the case where the number of parameters goes over the limit") { + withSQLConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod" -> "2") { + sql("CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES (1, 1, 1) AS t(a, b, c)") + val df = sql("SELECT SUM(a + b + c) AS sum FROM t") + assert(df.collect === Seq(Row(3))) + } + } } From 1ed010f49cfdb9e0bfc3bdf65e974bda191f76e0 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 12 Dec 2017 00:45:54 +0900 Subject: [PATCH 2/4] Add new conf for maxParamNumInJavaMethod --- .../org/apache/spark/sql/internal/SQLConf.scala | 13 +++++++++++++ .../sql/execution/aggregate/HashAggregateExec.scala | 13 +++---------- .../sql/execution/WholeStageCodegenSuite.scala | 2 +- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1121444cc938..cae5640b2f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -607,6 +607,17 @@ object SQLConf { .intConf .createWithDefault(100) + val MAX_PARAM_NUM_IN_JAVA_METHOD = + buildConf("spark.sql.codegen.maxParamNumInJavaMethod") + .internal() + .doc("The maximum number of parameters in codegened Java functions. When the number of " + + "parameters exceeds this threshold, the code generator gives up splitting the function " + + "code. The default value is 127 because the maximum number of parameter slots in a " + + "non-static Java method is 254 and a parameter of type long or double takes two slots.") + .intConf + .createWithDefault(127) + val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback") .internal() .doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" + @@ -1156,6 +1167,8 @@ class SQLConf extends Serializable with Logging { def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) + def maxParamNumInJavaMethod: Int = getConf(MAX_PARAM_NUM_IN_JAVA_METHOD) + def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 6bb932b150ea..c5b35e829c74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -309,18 +309,11 @@ case class HashAggregateExec( subExprs: Map[Expression, SubExprEliminationState], otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { aggExprs.zipWithIndex.map { case (aggExpr, i) => - // The maximum number of parameter slots in a non-static Java method is 254, and a parameter of - // type long or double takes two slots. So, this method gives up - // splitting the code if the number of parameters goes over 127. val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq - // This is for testing/benchmarking only - val maxParamNumInJavaMethod = - sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match { - case null | "" => 127 - case param => param.toInt - } - if (args.size <= maxParamNumInJavaMethod) { + // This method gives up splitting the code if the number of parameters goes over + // `maxParamNumInJavaMethod`. 
if (args.size <= sqlContext.conf.maxParamNumInJavaMethod) { val doAggVal = ctx.freshName(s"doAggregateVal_${aggExpr.prettyName}") val argList = args.map(a => s"${a._1} ${a._2}").mkString(", ") val doAggValFuncName = ctx.addNewFunction(doAggVal, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 1143f6c50138..67fdb49e9101 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -238,7 +238,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("SPARK-21870 check the case where the number of parameters goes over the limit") { - withSQLConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod" -> "2") { + withSQLConf("spark.sql.codegen.maxParamNumInJavaMethod" -> "2") { sql("CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES (1, 1, 1) AS t(a, b, c)") val df = sql("SELECT SUM(a + b + c) AS sum FROM t") assert(df.collect === Seq(Row(3))) From 0e5d366dfce50ec3f879193f78c6a77213680926 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 12 Dec 2017 08:16:14 +0900 Subject: [PATCH 3/4] Add indents --- .../sql/execution/aggregate/HashAggregateExec.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index c5b35e829c74..8caeff8dbabe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -946,7 +946,10 @@ case class HashAggregateExec( } val updateAggValCode = splitAggregateExpressions( - ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, + ctx, + boundUpdateExpr, + evalAndUpdateCodes, + subExprs.states, Seq(("InternalRow", unsafeRowBuffer))) s""" @@ -988,7 +991,10 @@ case class HashAggregateExec( } val updateAggValCode = splitAggregateExpressions( - ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, + ctx, + boundUpdateExpr, + evalAndUpdateCodes, + subExprs.states, Seq(("InternalRow", fastRowBuffer))) // If fast hash map is on, we first generate code to update row in fast hash map, if the // previous lookup hit the fast hash map. Otherwise, update row in regular hash map. From 2555e5a3eb1223a19f3f523abc7c1e5a9251e528 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 12 Dec 2017 23:36:47 +0900 Subject: [PATCH 4/4] Add parameter names --- .../aggregate/HashAggregateExec.scala | 61 ++++++++++--------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8caeff8dbabe..83b0f807f541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -262,19 +262,19 @@ case class HashAggregateExec( // Extracts all the input variable references for a given `aggExpr`. This result will be used // to split aggregation into small functions. 
private def getInputVariableReferences( - ctx: CodegenContext, - aggExpr: Expression, + context: CodegenContext, + aggregateExpression: Expression, subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { // `argSet` collects all the pairs of variable types and names; the first element in each pair // is the type name and the second is the variable name. val argSet = mutable.Set[(String, String)]() - val stack = mutable.Stack[Expression](aggExpr) + val stack = mutable.Stack[Expression](aggregateExpression) while (stack.nonEmpty) { stack.pop() match { case e if subExprs.contains(e) => val exprCode = subExprs(e) if (CodegenContext.isJavaIdentifier(exprCode.value)) { - argSet += ((ctx.javaType(e.dataType), exprCode.value)) + argSet += ((context.javaType(e.dataType), exprCode.value)) } if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { argSet += (("boolean", exprCode.isNull)) @@ -282,17 +282,17 @@ case class HashAggregateExec( // Since the children may have common expressions, we push them here stack.pushAll(e.children) case ref: BoundReference - if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => - val value = ctx.currentVars(ref.ordinal).value - val isNull = ctx.currentVars(ref.ordinal).isNull + if context.currentVars != null && context.currentVars(ref.ordinal) != null => + val value = context.currentVars(ref.ordinal).value + val isNull = context.currentVars(ref.ordinal).isNull if (CodegenContext.isJavaIdentifier(value)) { - argSet += ((ctx.javaType(ref.dataType), value)) + argSet += ((context.javaType(ref.dataType), value)) } if (CodegenContext.isJavaIdentifier(isNull)) { argSet += (("boolean", isNull)) } case _: BoundReference => - argSet += (("InternalRow", ctx.INPUT_ROW)) + argSet += (("InternalRow", context.INPUT_ROW)) case e => stack.pushAll(e.children) } @@ -303,30 +303,30 @@ case class HashAggregateExec( // Splits aggregate code into small functions because the JVM does not JIT-compile functions that are too long private def splitAggregateExpressions( - ctx: CodegenContext, - aggExprs: Seq[Expression], - evalAndUpdateCodes: Seq[String], + context: CodegenContext, + aggregateExpressions: Seq[Expression], + codes: Seq[String], subExprs: Map[Expression, SubExprEliminationState], otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { - aggExprs.zipWithIndex.map { case (aggExpr, i) => - val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq + aggregateExpressions.zipWithIndex.map { case (aggExpr, i) => + val args = (getInputVariableReferences(context, aggExpr, subExprs) ++ otherArgs).toSeq // This method gives up splitting the code if the number of parameters goes over // `maxParamNumInJavaMethod`. 
if (args.size <= sqlContext.conf.maxParamNumInJavaMethod) { - val doAggVal = ctx.freshName(s"doAggregateVal_${aggExpr.prettyName}") + val doAggVal = context.freshName(s"doAggregateVal_${aggExpr.prettyName}") val argList = args.map(a => s"${a._1} ${a._2}").mkString(", ") - val doAggValFuncName = ctx.addNewFunction(doAggVal, + val doAggValFuncName = context.addNewFunction(doAggVal, s""" | private void $doAggVal($argList) throws java.io.IOException { - | ${evalAndUpdateCodes(i)} + | ${codes(i)} | } """.stripMargin) val inputVariables = args.map(_._2).mkString(", ") s"$doAggValFuncName($inputVariables);" } else { - evalAndUpdateCodes(i) + codes(i) } } } @@ -377,7 +377,10 @@ case class HashAggregateExec( } val updateAggValCode = splitAggregateExpressions( - ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states) + context = ctx, + aggregateExpressions = boundUpdateExpr, + codes = evalAndUpdateCodes, + subExprs = subExprs.states) s""" | // do aggregate @@ -946,11 +949,11 @@ case class HashAggregateExec( } val updateAggValCode = splitAggregateExpressions( - ctx, - boundUpdateExpr, - evalAndUpdateCodes, - subExprs.states, - Seq(("InternalRow", unsafeRowBuffer))) + context = ctx, + aggregateExpressions = boundUpdateExpr, + codes = evalAndUpdateCodes, + subExprs = subExprs.states, + otherArgs = Seq(("InternalRow", unsafeRowBuffer))) s""" | // do aggregate @@ -991,11 +994,11 @@ case class HashAggregateExec( } val updateAggValCode = splitAggregateExpressions( - ctx, - boundUpdateExpr, - evalAndUpdateCodes, - subExprs.states, - Seq(("InternalRow", fastRowBuffer))) + context = ctx, + aggregateExpressions = boundUpdateExpr, + codes = evalAndUpdateCodes, + subExprs = subExprs.states, + otherArgs = Seq(("InternalRow", fastRowBuffer))) // If fast hash map is on, we first generate code to update row in fast hash map, if the // previous lookup hit the fast hash map. Otherwise, update row in regular hash map.
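Taken together, the series turns the aggregate-update code from one flat block inlined into processNext() into one small private method per aggregate expression, so each generated method stays small enough for the JIT. For intuition, a hand-written sketch of the resulting shape for a simple SUM over a long column, written as the same kind of string template the planner code itself uses (all names here are invented stand-ins for ctx.freshName output; real codegen output differs):

    // Illustration only: the split method receives the local copies of the
    // buffer slots as parameters, while the buffer slots themselves remain
    // mutable fields of the generated class, so the method can write back.
    val generatedShape =
      """
        |private void doAggregateVal_sum(boolean localBufIsNull, long localBufValue,
        |    long inputadapter_value) throws java.io.IOException {
        |  // evaluate aggregate function
        |  long agg_value = localBufIsNull ? inputadapter_value : localBufValue + inputadapter_value;
        |  // update aggregation buffer
        |  agg_bufIsNull = false;
        |  agg_bufValue = agg_value;
        |}
        |
        |// ... and at the original inline site in processNext():
        |doAggregateVal_sum(agg_localBufIsNull, agg_localBufValue, inputadapter_value);
      """.stripMargin

Passing the local copies as parameters rather than reading the fields inside the method is what makes getInputVariableReferences necessary: every variable the split body references has to appear in the parameter list, and isJavaIdentifier filters out the literal fragments that need no parameter.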