From 8a4fd61d15ded1f0456a80883f20cc1482246887 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 28 Aug 2017 21:46:00 +0900 Subject: [PATCH 01/20] Split aggregation into small functions --- .../expressions/codegen/CodeGenerator.scala | 43 ++++ .../apache/spark/sql/internal/SQLConf.scala | 11 + .../aggregate/HashAggregateExec.scala | 226 ++++++++++++++++-- .../sql-tests/inputs/group-analytics.sql | 3 + .../sql-tests/inputs/group-by-ordinal.sql | 3 + .../resources/sql-tests/inputs/group-by.sql | 3 + .../sql-tests/inputs/grouping_set.sql | 3 + 7 files changed, 269 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 95fad412002e..236ba18168be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1612,6 +1612,49 @@ object CodeGenerator extends Logging { } } + /** + * Extracts all the input variables from references and subexpression elimination states + * for a given `expr`. This result will be used to split the generated code of + * expressions into multiple functions. + */ + def getLocalInputVariableValues( + context: CodegenContext, + expr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Seq[((String, DataType), Expression)] = { + val argMap = mutable.Map[(String, DataType), Expression]() + val stack = mutable.Stack[Expression](expr) + while (stack.nonEmpty) { + stack.pop() match { + case e if subExprs.contains(e) => + val exprCode = subExprs(e) + val SubExprEliminationState(isNull, value) = exprCode + if (value.isInstanceOf[VariableValue]) { + argMap += (value.code, e.dataType) -> e + } + if (isNull.isInstanceOf[VariableValue]) { + argMap += (isNull.code, BooleanType) -> e + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) + case ref: BoundReference + if context.currentVars != null && context.currentVars(ref.ordinal) != null => + val ExprCode(_, isNull, value) = context.currentVars(ref.ordinal) + if (value.isInstanceOf[VariableValue]) { + argMap += (value.code, ref.dataType) -> ref + } + if (isNull.isInstanceOf[VariableValue]) { + argMap += (isNull.code, BooleanType) -> ref + } + case ref: BoundReference => + argMap += (context.INPUT_ROW, ObjectType(classOf[InternalRow])) -> ref + case e => + stack.pushAll(e.children) + } + } + + argMap.toSeq + } + /** * Returns the name used in accessor and setter for a Java primitive type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 52990cb6a244..85e8763bdb60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1047,6 +1047,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val CODEGEN_SPLIT_AGGREGATE_FUNC = + buildConf("spark.sql.codegen.aggregate.splitAggregateFunc.enabled") + .internal() + .doc("When true, the code generator would aggregate code into individual methods " + + "instead of a single big method. 
This can be used to avoid oversized functions that " +
+        "can miss the opportunity of JIT optimization.")
+      .booleanConf
+      .createWithDefault(false)
+
   val MAX_NESTED_VIEW_DEPTH = buildConf("spark.sql.view.maxNestedViewDepth")
     .internal()
@@ -2310,6 +2319,8 @@ class SQLConf extends Serializable with Logging {
   def cartesianProductExecBufferSpillThreshold: Int =
     getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
 
+  def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC)
+
   def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
 
   def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4a95f7638133..4a8c296194e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
+import org.apache.spark.sql.types.{DataType, DecimalType, ObjectType, StringType, StructType}
 import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.Utils
 
@@ -255,6 +255,41 @@ case class HashAggregateExec(
     """.stripMargin
   }
 
+  // Splits aggregate code into small functions because most JVM implementations
+  // cannot compile overly long functions. Note that, unlike `CodeGenerator.splitExpressions`,
+  // we extract the input variables from references and subexpression elimination states
+  // for each aggregate expression and pass them to the split functions as arguments.
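+  //
+  // A rough sketch of the effect (hypothetical generated Java below; the method
+  // and variable names are illustrative, not the exact `freshName` output): with
+  //   SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=true
+  // a query such as `SELECT SUM(a), MAX(b) FROM t` updates its buffers through
+  // per-aggregate calls like
+  //   doAggregateVal_sum(agg_expr_1, agg_exprIsNull_1);
+  //   doAggregateVal_max(agg_expr_2, agg_exprIsNull_2);
+  // instead of one big method body. The missed "JIT optimization" mentioned in
+  // the SQLConf doc refers to HotSpot skipping JIT compilation of methods larger
+  // than 8000 bytecode bytes (its HugeMethodLimit; see -XX:-DontCompileHugeMethods).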
+ private def splitAggregateExpressions( + context: CodegenContext, + aggregateExpressions: Seq[Expression], + codes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + bufferInput: Option[(String, DataType)] = None): Seq[String] = { + aggregateExpressions.zipWithIndex.map { case (aggExpr, i) => + val inputVars = CodeGenerator.getLocalInputVariableValues(context, aggExpr, subExprs) + val args = inputVars.map(_._1) ++ bufferInput.map(_ :: Nil).getOrElse(Nil) + val paramLength = CodeGenerator.calculateParamLength(inputVars.map(_._2)) + + (if (bufferInput.isDefined) 1 else 0) + + // This method gives up splitting the code if the parameter length goes over the limit + if (CodeGenerator.isValidParamLength(paramLength)) { + val doAggVal = context.freshName(s"doAggregateVal_${aggExpr.prettyName}") + val argList = args.map(v => s"${CodeGenerator.javaType(v._2)} ${v._1}").mkString(", ") + val doAggValFuncName = context.addNewFunction(doAggVal, + s""" + | private void $doAggVal($argList) throws java.io.IOException { + | ${codes(i)} + | } + """.stripMargin) + + val inputVariables = args.map(_._1).mkString(", ") + s"$doAggValFuncName($inputVariables);" + } else { + codes(i) + } + } + } + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) @@ -267,29 +302,81 @@ case class HashAggregateExec( e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } - ctx.currentVars = bufVars ++ input - val boundUpdateExpr = bindReferences(updateExpr, inputAttrs) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - // aggregate buffer should be updated atomic - val updates = aggVals.zipWithIndex.map { case (ev, i) => + + if (!conf.codegenSplitAggregateFunc) { + ctx.currentVars = bufVars ++ input + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + // aggregate buffer should be updated atomic + val updates = aggVals.zipWithIndex.map { case (ev, i) => + s""" + | ${bufVars(i).isNull} = ${ev.isNull}; + | ${bufVars(i).value} = ${ev.value}; + """.stripMargin + } + s""" + | // do aggregate + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate function + | ${evaluateVariables(aggVals)} + | // update aggregation buffer + | ${updates.mkString("\n").trim} + """.stripMargin + } else { + // We need to copy the aggregation buffer to local variables first because each aggregate + // function directly updates the buffer when it finishes. 
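+      // A sketch of the generated Java for this copy (names are illustrative):
+      //   boolean localBufIsNull_0 = bufIsNull_0;
+      //   long localBufValue_0 = bufValue_0;
+      // The split functions then read the stable local copies, while their
+      // results are written back to the global buffer variables.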
+ val localBufVars = bufVars.zip(updateExpr).map { case (ev, e) => + val isNull = ctx.freshName("localBufIsNull") + val value = ctx.freshName("localBufValue") + val initLocalVars = code""" + | boolean $isNull = ${ev.isNull}; + | ${CodeGenerator.javaType(e.dataType)} $value = ${ev.value}; + """.stripMargin + ExprCode(initLocalVars, JavaCode.isNullVariable(isNull), + JavaCode.variable(value, e.dataType)) + } + + val initLocalBufVar = evaluateVariables(localBufVars) + + ctx.currentVars = localBufVars ++ input + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + + val evalAndUpdateCodes = aggVals.zipWithIndex.map { case (ev, i) => + s""" + | // evaluate aggregate function + | ${ev.code} + | // update aggregation buffer + | ${bufVars(i).isNull} = ${ev.isNull}; + | ${bufVars(i).value} = ${ev.value}; + """.stripMargin + } + + val updateAggValCode = splitAggregateExpressions( + context = ctx, + aggregateExpressions = boundUpdateExpr, + codes = evalAndUpdateCodes, + subExprs = subExprs.states) + s""" - | ${bufVars(i).isNull} = ${ev.isNull}; - | ${bufVars(i).value} = ${ev.value}; + | // do aggregate + | // copy aggregation buffer to the local + | $initLocalBufVar + | // common sub-expressions + | $effectiveCodes + | // process aggregate functions to update aggregation buffer + | ${updateAggValCode.mkString("\n")} """.stripMargin } - s""" - | // do aggregate - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(aggVals)} - | // update aggregation buffer - | ${updates.mkString("\n").trim} - """.stripMargin } private val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -824,7 +911,7 @@ case class HashAggregateExec( // generating input columns, we use `currentVars`. ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input - val updateRowInRegularHashMap: String = { + val updateRowInRegularHashMap: String = if (!conf.codegenSplitAggregateFunc) { ctx.INPUT_ROW = unsafeRowBuffer val boundUpdateExpr = bindReferences(updateExpr, inputAttr) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) @@ -844,9 +931,51 @@ case class HashAggregateExec( |// update unsafe row buffer |${updateUnsafeRowBuffer.mkString("\n").trim} """.stripMargin + } else { + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. 
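+      // Note the asymmetry in the code below: `ctx.INPUT_ROW` is pointed at the
+      // local copy, so the bound update expressions read from it, while
+      // `CodeGenerator.updateColumn` still writes into the original
+      // `unsafeRowBuffer`. The copy keeps reads stable while earlier split
+      // functions mutate the shared buffer.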
+ val localRowBuffer = ctx.freshName("localUnsafeRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();" + + ctx.INPUT_ROW = localRowBuffer + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + + val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + val updateColumnCode = + CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + s""" + | // evaluate aggregate function + | ${ev.code} + | // update unsafe row buffer + | $updateColumnCode + """.stripMargin + } + + val updateAggValCode = splitAggregateExpressions( + context = ctx, + aggregateExpressions = boundUpdateExpr, + codes = evalAndUpdateCodes, + subExprs = subExprs.states, + bufferInput = Some((unsafeRowBuffer, ObjectType(classOf[InternalRow])))) + + s""" + | // do aggregate + | // copy aggregation row buffer to the local + | $initLocalRowBuffer + | // common sub-expressions + | $effectiveCodes + | // process aggregate functions to update aggregation buffer + | ${updateAggValCode.mkString("\n")} + """.stripMargin } - val updateRowInHashMap: String = { + val updateRowInHashMap: String = if (!conf.codegenSplitAggregateFunc) { if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer @@ -892,6 +1021,57 @@ case class HashAggregateExec( } else { updateRowInRegularHashMap } + } else { + if (isFastHashMapEnabled) { + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. + val localRowBuffer = ctx.freshName("localFastRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();" + + ctx.INPUT_ROW = localRowBuffer + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + + val evalAndUpdateCodes = fastRowEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + val updateColumnCode = CodeGenerator.updateColumn( + fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) + s""" + | // evaluate aggregate function + | ${ev.code} + | // update fast row + | $updateColumnCode + """.stripMargin + } + + val updateAggValCode = splitAggregateExpressions( + context = ctx, + aggregateExpressions = boundUpdateExpr, + codes = evalAndUpdateCodes, + subExprs = subExprs.states, + bufferInput = Some((fastRowBuffer, ObjectType(classOf[InternalRow])))) + + // If fast hash map is on, we first generate code to update row in fast hash map, if the + // previous loop up hit fast hash map. Otherwise, update row in regular hash map. 
+ s""" + |if ($fastRowBuffer != null) { + | // copy aggregation row buffer to the local + | $initLocalRowBuffer + | // common sub-expressions + | $effectiveCodes + | // process aggregate functions to update aggregation buffer + | ${updateAggValCode.mkString("\n")} + |} else { + | $updateRowInRegularHashMap + |} + """.stripMargin + } else { + updateRowInRegularHashMap + } } val declareRowBuffer: String = if (isFastHashMapEnabled) { diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index 9721f8c60ebc..d102512646a3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -1,3 +1,6 @@ +--SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=true +--SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=false + CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES (1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2) AS testData(a, b); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 3144833b608b..7321a5b30f7d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -1,3 +1,6 @@ +--SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=true +--SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=false + -- group by ordinal positions create temporary view data as select * from values diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 66bc90914e0d..5d63825b2b2a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -1,3 +1,6 @@ +--SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=true +--SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=false + -- Test data. 
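-- A note on the paired --SET lines at the top of this file (and of the other
-- three test inputs touched by this patch): SQLQueryTestSuite appears to re-run
-- a file once per --SET line, so listing the flag as both true and false runs
-- every query with aggregate-code splitting enabled and disabled and checks
-- that the two runs produce identical results.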
CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES (1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null) diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql index 6bbde9f38d65..d9f9c9133257 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql @@ -1,3 +1,6 @@ +--SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=true +--SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=false + CREATE TEMPORARY VIEW grouping AS SELECT * FROM VALUES ("1", "2", "3", 1), ("4", "5", "6", 1), From c9901fce8f6ced33768074a85b13200b2db18f46 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 22 Aug 2019 14:25:07 +0900 Subject: [PATCH 02/20] Address reviews --- .../expressions/codegen/CodeGenerator.scala | 42 ++-- .../aggregate/HashAggregateExec.scala | 236 +++++++++--------- 2 files changed, 139 insertions(+), 139 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 236ba18168be..1897d920a6aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1618,35 +1618,35 @@ object CodeGenerator extends Logging { * expressions into multiple functions. */ def getLocalInputVariableValues( - context: CodegenContext, + ctx: CodegenContext, expr: Expression, - subExprs: Map[Expression, SubExprEliminationState]): Seq[((String, DataType), Expression)] = { - val argMap = mutable.Map[(String, DataType), Expression]() + subExprs: Map[Expression, SubExprEliminationState]): Seq[(VariableValue, Expression)] = { + val argMap = mutable.Map[VariableValue, Expression]() + // Collects local variables only in `argMap` + val collectLocalVariable = (ev: ExprValue, expr: Expression) => ev match { + case vv: VariableValue => argMap += vv -> expr + case _ => + } + val stack = mutable.Stack[Expression](expr) while (stack.nonEmpty) { stack.pop() match { case e if subExprs.contains(e) => - val exprCode = subExprs(e) - val SubExprEliminationState(isNull, value) = exprCode - if (value.isInstanceOf[VariableValue]) { - argMap += (value.code, e.dataType) -> e - } - if (isNull.isInstanceOf[VariableValue]) { - argMap += (isNull.code, BooleanType) -> e - } - // Since the children possibly has common expressions, we push them here + val SubExprEliminationState(isNull, value) = subExprs(e) + collectLocalVariable(value, e) + collectLocalVariable(isNull, e) + // Since the children possibly have common subexprs, we push them here stack.pushAll(e.children) + case ref: BoundReference - if context.currentVars != null && context.currentVars(ref.ordinal) != null => - val ExprCode(_, isNull, value) = context.currentVars(ref.ordinal) - if (value.isInstanceOf[VariableValue]) { - argMap += (value.code, ref.dataType) -> ref - } - if (isNull.isInstanceOf[VariableValue]) { - argMap += (isNull.code, BooleanType) -> ref - } + if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal) + collectLocalVariable(value, ref) + collectLocalVariable(isNull, ref) + case ref: BoundReference => - argMap += (context.INPUT_ROW, 
ObjectType(classOf[InternalRow])) -> ref + argMap += JavaCode.variable(ctx.INPUT_ROW, ObjectType(classOf[InternalRow])) -> ref + case e => stack.pushAll(e.children) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4a8c296194e2..66bc06f8716d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -260,32 +260,32 @@ case class HashAggregateExec( // we will extract input variables from references and subexpression elimination states // for each aggregate expression, then pass them to it. private def splitAggregateExpressions( - context: CodegenContext, + ctx: CodegenContext, aggregateExpressions: Seq[Expression], - codes: Seq[String], + aggEvalCodes: Seq[String], subExprs: Map[Expression, SubExprEliminationState], - bufferInput: Option[(String, DataType)] = None): Seq[String] = { - aggregateExpressions.zipWithIndex.map { case (aggExpr, i) => - val inputVars = CodeGenerator.getLocalInputVariableValues(context, aggExpr, subExprs) + bufferInput: Option[VariableValue] = None): Seq[String] = { + aggregateExpressions.zip(aggEvalCodes).map { case (aggExpr, aggEvalCode) => + val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, aggExpr, subExprs) val args = inputVars.map(_._1) ++ bufferInput.map(_ :: Nil).getOrElse(Nil) val paramLength = CodeGenerator.calculateParamLength(inputVars.map(_._2)) + (if (bufferInput.isDefined) 1 else 0) // This method gives up splitting the code if the parameter length goes over the limit if (CodeGenerator.isValidParamLength(paramLength)) { - val doAggVal = context.freshName(s"doAggregateVal_${aggExpr.prettyName}") - val argList = args.map(v => s"${CodeGenerator.javaType(v._2)} ${v._1}").mkString(", ") - val doAggValFuncName = context.addNewFunction(doAggVal, + val doAggVal = ctx.freshName(s"doAggregateVal_${aggExpr.prettyName}") + val argList = args.map(v => s"${v.javaType} ${v.variableName}").mkString(", ") + val doAggValFuncName = ctx.addNewFunction(doAggVal, s""" | private void $doAggVal($argList) throws java.io.IOException { - | ${codes(i)} + | $aggEvalCode | } """.stripMargin) - val inputVariables = args.map(_._1).mkString(", ") + val inputVariables = args.map(_.variableName).mkString(", ") s"$doAggValFuncName($inputVariables);" } else { - codes(i) + aggEvalCode } } } @@ -303,31 +303,7 @@ case class HashAggregateExec( } } - if (!conf.codegenSplitAggregateFunc) { - ctx.currentVars = bufVars ++ input - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - // aggregate buffer should be updated atomic - val updates = aggVals.zipWithIndex.map { case (ev, i) => - s""" - | ${bufVars(i).isNull} = ${ev.isNull}; - | ${bufVars(i).value} = ${ev.value}; - """.stripMargin - } - s""" - | // do aggregate - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(aggVals)} - | // update aggregation buffer - | ${updates.mkString("\n").trim} - """.stripMargin - } else { + if (conf.codegenSplitAggregateFunc) { // We need to copy the aggregation buffer to local 
variables first because each aggregate // function directly updates the buffer when it finishes. val localBufVars = bufVars.zip(updateExpr).map { case (ev, e) => @@ -351,20 +327,20 @@ case class HashAggregateExec( boundUpdateExpr.map(_.genCode(ctx)) } - val evalAndUpdateCodes = aggVals.zipWithIndex.map { case (ev, i) => + val evalAndUpdateCodes = aggVals.zip(bufVars).map { case (ev, bufVar) => s""" | // evaluate aggregate function | ${ev.code} | // update aggregation buffer - | ${bufVars(i).isNull} = ${ev.isNull}; - | ${bufVars(i).value} = ${ev.value}; + | ${bufVar.isNull} = ${ev.isNull}; + | ${bufVar.value} = ${ev.value}; """.stripMargin } val updateAggValCode = splitAggregateExpressions( - context = ctx, + ctx = ctx, aggregateExpressions = boundUpdateExpr, - codes = evalAndUpdateCodes, + aggEvalCodes = evalAndUpdateCodes, subExprs = subExprs.states) s""" @@ -373,9 +349,33 @@ case class HashAggregateExec( | $initLocalBufVar | // common sub-expressions | $effectiveCodes - | // process aggregate functions to update aggregation buffer + | // evaluate aggregate functions and update aggregation buffers | ${updateAggValCode.mkString("\n")} """.stripMargin + } else { + ctx.currentVars = bufVars ++ input + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + // aggregate buffer should be updated atomically + val updates = aggVals.zip(bufVars).map { case (ev, bufVar) => + s""" + | ${bufVar.isNull} = ${ev.isNull}; + | ${bufVar.value} = ${ev.value}; + """.stripMargin + } + s""" + | // do aggregate + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate function + | ${evaluateVariables(aggVals)} + | // update aggregation buffer + | ${updates.mkString("\n").trim} + """.stripMargin } } @@ -911,27 +911,7 @@ case class HashAggregateExec( // generating input columns, we use `currentVars`. ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input - val updateRowInRegularHashMap: String = if (!conf.codegenSplitAggregateFunc) { - ctx.INPUT_ROW = unsafeRowBuffer - val boundUpdateExpr = bindReferences(updateExpr, inputAttr) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) - } - s""" - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate function - |${evaluateVariables(unsafeRowBufferEvals)} - |// update unsafe row buffer - |${updateUnsafeRowBuffer.mkString("\n").trim} - """.stripMargin - } else { + val updateRowInRegularHashMap: String = if (conf.codegenSplitAggregateFunc) { // We need to copy the aggregation row buffer to a local row first because each aggregate // function directly updates the buffer when it finishes. 
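      // With `bufferInput` set (see the call below), the row buffer is appended
      // as one extra trailing parameter, so a split function is declared roughly
      // as follows (hypothetical generated Java):
      //   private void doAggregateVal_sum(long agg_expr_1, boolean agg_exprIsNull_1,
      //       InternalRow unsafeRowAggBuffer) throws java.io.IOException { ... }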
val localRowBuffer = ctx.freshName("localUnsafeRowBuffer") @@ -958,11 +938,11 @@ case class HashAggregateExec( } val updateAggValCode = splitAggregateExpressions( - context = ctx, + ctx = ctx, aggregateExpressions = boundUpdateExpr, - codes = evalAndUpdateCodes, + aggEvalCodes = evalAndUpdateCodes, subExprs = subExprs.states, - bufferInput = Some((unsafeRowBuffer, ObjectType(classOf[InternalRow])))) + bufferInput = Some(VariableValue(unsafeRowBuffer, classOf[InternalRow]))) s""" | // do aggregate @@ -970,58 +950,32 @@ case class HashAggregateExec( | $initLocalRowBuffer | // common sub-expressions | $effectiveCodes - | // process aggregate functions to update aggregation buffer + | // evaluate aggregate functions and update aggregation buffers | ${updateAggValCode.mkString("\n")} """.stripMargin + } else { + ctx.INPUT_ROW = unsafeRowBuffer + val boundUpdateExpr = bindReferences(updateExpr, inputAttr) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + } + s""" + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate function + |${evaluateVariables(unsafeRowBufferEvals)} + |// update unsafe row buffer + |${updateUnsafeRowBuffer.mkString("\n").trim} + """.stripMargin } - val updateRowInHashMap: String = if (!conf.codegenSplitAggregateFunc) { - if (isFastHashMapEnabled) { - if (isVectorizedHashMapEnabled) { - ctx.INPUT_ROW = fastRowBuffer - val boundUpdateExpr = bindReferences(updateExpr, inputAttr) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - CodeGenerator.updateColumn( - fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized = true) - } - - // If vectorized fast hash map is on, we first generate code to update row - // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map. - // Otherwise, update row in regular hash map. - s""" - |if ($fastRowBuffer != null) { - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(fastRowEvals)} - | // update fast row - | ${updateFastRow.mkString("\n").trim} - |} else { - | $updateRowInRegularHashMap - |} - """.stripMargin - } else { - // If row-based hash map is on and the previous loop up hit fast hash map, - // we reuse regular hash buffer to update row of fast hash map. - // Otherwise, update row in regular hash map. - s""" - |// Updates the proper row buffer - |if ($fastRowBuffer != null) { - | $unsafeRowBuffer = $fastRowBuffer; - |} - |$updateRowInRegularHashMap - """.stripMargin - } - } else { - updateRowInRegularHashMap - } - } else { + val updateRowInHashMap: String = if (conf.codegenSplitAggregateFunc) { if (isFastHashMapEnabled) { // We need to copy the aggregation row buffer to a local row first because each aggregate // function directly updates the buffer when it finishes. 
@@ -1049,11 +1003,11 @@ case class HashAggregateExec( } val updateAggValCode = splitAggregateExpressions( - context = ctx, + ctx = ctx, aggregateExpressions = boundUpdateExpr, - codes = evalAndUpdateCodes, + aggEvalCodes = evalAndUpdateCodes, subExprs = subExprs.states, - bufferInput = Some((fastRowBuffer, ObjectType(classOf[InternalRow])))) + bufferInput = Some(VariableValue(fastRowBuffer, classOf[InternalRow]))) // If fast hash map is on, we first generate code to update row in fast hash map, if the // previous loop up hit fast hash map. Otherwise, update row in regular hash map. @@ -1063,7 +1017,7 @@ case class HashAggregateExec( | $initLocalRowBuffer | // common sub-expressions | $effectiveCodes - | // process aggregate functions to update aggregation buffer + | // evaluate aggregate functions and update aggregation buffers | ${updateAggValCode.mkString("\n")} |} else { | $updateRowInRegularHashMap @@ -1072,6 +1026,52 @@ case class HashAggregateExec( } else { updateRowInRegularHashMap } + } else { + if (isFastHashMapEnabled) { + if (isVectorizedHashMapEnabled) { + ctx.INPUT_ROW = fastRowBuffer + val boundUpdateExpr = bindReferences(updateExpr, inputAttr) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + CodeGenerator.updateColumn( + fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized = true) + } + + // If vectorized fast hash map is on, we first generate code to update row + // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map. + // Otherwise, update row in regular hash map. + s""" + |if ($fastRowBuffer != null) { + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate function + | ${evaluateVariables(fastRowEvals)} + | // update fast row + | ${updateFastRow.mkString("\n").trim} + |} else { + | $updateRowInRegularHashMap + |} + """.stripMargin + } else { + // If row-based hash map is on and the previous loop up hit fast hash map, + // we reuse regular hash buffer to update row of fast hash map. + // Otherwise, update row in regular hash map. 
+ s""" + |// Updates the proper row buffer + |if ($fastRowBuffer != null) { + | $unsafeRowBuffer = $fastRowBuffer; + |} + |$updateRowInRegularHashMap + """.stripMargin + } + } else { + updateRowInRegularHashMap + } } val declareRowBuffer: String = if (isFastHashMapEnabled) { From 88ff7d74c3694bc9d92d833dd8507b8af6c7a21a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 22 Aug 2019 16:40:36 +0900 Subject: [PATCH 03/20] Give up splitting aggregate code if a parameter length goes over the JVM limit --- .../expressions/codegen/CodeGenerator.scala | 2 +- .../aggregate/HashAggregateExec.scala | 204 +++++++++++------- .../execution/WholeStageCodegenSuite.scala | 31 +++ 3 files changed, 154 insertions(+), 83 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1897d920a6aa..8c8f78a2c2a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1645,7 +1645,7 @@ object CodeGenerator extends Logging { collectLocalVariable(isNull, ref) case ref: BoundReference => - argMap += JavaCode.variable(ctx.INPUT_ROW, ObjectType(classOf[InternalRow])) -> ref + argMap += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow]) -> ref case e => stack.pushAll(e.children) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 66bc06f8716d..b50ad0314302 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -264,17 +264,29 @@ case class HashAggregateExec( aggregateExpressions: Seq[Expression], aggEvalCodes: Seq[String], subExprs: Map[Expression, SubExprEliminationState], - bufferInput: Option[VariableValue] = None): Seq[String] = { - aggregateExpressions.zip(aggEvalCodes).map { case (aggExpr, aggEvalCode) => + bufferInput: Option[VariableValue] = None): Option[String] = { + val aggExprWithInputVars = aggregateExpressions.map { aggExpr => val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, aggExpr, subExprs) - val args = inputVars.map(_._1) ++ bufferInput.map(_ :: Nil).getOrElse(Nil) val paramLength = CodeGenerator.calculateParamLength(inputVars.map(_._2)) + (if (bufferInput.isDefined) 1 else 0) - // This method gives up splitting the code if the parameter length goes over the limit + // Checks if a parameter length for the `aggExpr` does not go over the JVM limit if (CodeGenerator.isValidParamLength(paramLength)) { + Some((aggExpr, inputVars)) + } else { + None + } + } + + // Checks if all the aggregate code can be split into pieces. + // If the parameter length of at lease one `aggExpr` goes over the limit, + // we totally give up splitting aggregate code. 
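+    // Background for the check below: the JVM caps a method descriptor at 255
+    // parameter slots, counting one slot for `this` and two slots for every
+    // long/double argument, which is exactly what `calculateParamLength` models.
+    // For example, 127 long parameters already exhaust the budget: 1 + 2 * 127 = 255.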
+ if (aggExprWithInputVars.forall(_.isDefined)) { + val splitCodes = aggExprWithInputVars.flatten.zip(aggEvalCodes) + .map { case ((aggExpr, inputVars), aggEvalCode) => + val args = inputVars.map(_._1) ++ bufferInput.map(_ :: Nil).getOrElse(Nil) val doAggVal = ctx.freshName(s"doAggregateVal_${aggExpr.prettyName}") - val argList = args.map(v => s"${v.javaType} ${v.variableName}").mkString(", ") + val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") val doAggValFuncName = ctx.addNewFunction(doAggVal, s""" | private void $doAggVal($argList) throws java.io.IOException { @@ -284,8 +296,17 @@ case class HashAggregateExec( val inputVariables = args.map(_.variableName).mkString(", ") s"$doAggValFuncName($inputVariables);" + } + Some(splitCodes.mkString("\n")) + } else { + val errMsg = "Failed to split aggregate code into small functions because the parameter " + + "length of at least one split function went over the JVM limit: " + + CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + if (Utils.isTesting) { + throw new IllegalStateException(errMsg) } else { - aggEvalCode + logInfo(errMsg) + None } } } @@ -303,6 +324,32 @@ case class HashAggregateExec( } } + lazy val aggregateCodeInSingleFunc = { + ctx.currentVars = bufVars ++ input + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + // aggregate buffer should be updated atomically + val updates = aggVals.zip(bufVars).map { case (ev, bufVar) => + s""" + | ${bufVar.isNull} = ${ev.isNull}; + | ${bufVar.value} = ${ev.value}; + """.stripMargin + } + s""" + | // do aggregate + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate function + | ${evaluateVariables(aggVals)} + | // update aggregation buffer + | ${updates.mkString("\n").trim} + """.stripMargin + } + if (conf.codegenSplitAggregateFunc) { // We need to copy the aggregation buffer to local variables first because each aggregate // function directly updates the buffer when it finishes. 
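      // `aggregateCodeInSingleFunc` above is a lazy val, so the unsplit code is
      // only generated when splitting is disabled or when
      // `splitAggregateExpressions` returns None and the `.getOrElse` fallback
      // below kicks in.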
@@ -337,45 +384,26 @@ case class HashAggregateExec( """.stripMargin } - val updateAggValCode = splitAggregateExpressions( + splitAggregateExpressions( ctx = ctx, aggregateExpressions = boundUpdateExpr, aggEvalCodes = evalAndUpdateCodes, - subExprs = subExprs.states) + subExprs = subExprs.states).map { updateAggValCode => - s""" - | // do aggregate - | // copy aggregation buffer to the local - | $initLocalBufVar - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate functions and update aggregation buffers - | ${updateAggValCode.mkString("\n")} - """.stripMargin - } else { - ctx.currentVars = bufVars ++ input - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - // aggregate buffer should be updated atomically - val updates = aggVals.zip(bufVars).map { case (ev, bufVar) => s""" - | ${bufVar.isNull} = ${ev.isNull}; - | ${bufVar.value} = ${ev.value}; + | // do aggregate + | // copy aggregation buffer to the local + | $initLocalBufVar + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate functions and update aggregation buffers + | $updateAggValCode """.stripMargin + }.getOrElse { + aggregateCodeInSingleFunc } - s""" - | // do aggregate - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(aggVals)} - | // update aggregation buffer - | ${updates.mkString("\n").trim} - """.stripMargin + } else { + aggregateCodeInSingleFunc } } @@ -911,6 +939,28 @@ case class HashAggregateExec( // generating input columns, we use `currentVars`. ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input + lazy val aggregateCodeInSingleFunc = { + ctx.INPUT_ROW = unsafeRowBuffer + val boundUpdateExpr = bindReferences(updateExpr, inputAttr) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + } + s""" + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate function + |${evaluateVariables(unsafeRowBufferEvals)} + |// update unsafe row buffer + |${updateUnsafeRowBuffer.mkString("\n").trim} + """.stripMargin + } + val updateRowInRegularHashMap: String = if (conf.codegenSplitAggregateFunc) { // We need to copy the aggregation row buffer to a local row first because each aggregate // function directly updates the buffer when it finishes. 
@@ -937,42 +987,28 @@ case class HashAggregateExec( """.stripMargin } - val updateAggValCode = splitAggregateExpressions( + splitAggregateExpressions( ctx = ctx, aggregateExpressions = boundUpdateExpr, aggEvalCodes = evalAndUpdateCodes, subExprs = subExprs.states, bufferInput = Some(VariableValue(unsafeRowBuffer, classOf[InternalRow]))) + .map { updateAggValCode => - s""" - | // do aggregate - | // copy aggregation row buffer to the local - | $initLocalRowBuffer - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate functions and update aggregation buffers - | ${updateAggValCode.mkString("\n")} - """.stripMargin - } else { - ctx.INPUT_ROW = unsafeRowBuffer - val boundUpdateExpr = bindReferences(updateExpr, inputAttr) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + s""" + | // do aggregate + | // copy aggregation row buffer to the local + | $initLocalRowBuffer + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate functions and update aggregation buffers + | $updateAggValCode + """.stripMargin + }.getOrElse { + aggregateCodeInSingleFunc } - s""" - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate function - |${evaluateVariables(unsafeRowBufferEvals)} - |// update unsafe row buffer - |${updateUnsafeRowBuffer.mkString("\n").trim} - """.stripMargin + } else { + aggregateCodeInSingleFunc } val updateRowInHashMap: String = if (conf.codegenSplitAggregateFunc) { @@ -1002,27 +1038,31 @@ case class HashAggregateExec( """.stripMargin } - val updateAggValCode = splitAggregateExpressions( + splitAggregateExpressions( ctx = ctx, aggregateExpressions = boundUpdateExpr, aggEvalCodes = evalAndUpdateCodes, subExprs = subExprs.states, bufferInput = Some(VariableValue(fastRowBuffer, classOf[InternalRow]))) + .map { updateAggValCode => - // If fast hash map is on, we first generate code to update row in fast hash map, if the - // previous loop up hit fast hash map. Otherwise, update row in regular hash map. - s""" - |if ($fastRowBuffer != null) { - | // copy aggregation row buffer to the local - | $initLocalRowBuffer - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate functions and update aggregation buffers - | ${updateAggValCode.mkString("\n")} - |} else { - | $updateRowInRegularHashMap - |} - """.stripMargin + // If fast hash map is on, we first generate code to update row in fast hash map, if the + // previous loop up hit fast hash map. Otherwise, update row in regular hash map. 
+ s""" + |if ($fastRowBuffer != null) { + | // copy aggregation row buffer to the local + | $initLocalRowBuffer + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate functions and update aggregation buffers + | $updateAggValCode + |} else { + | $updateRowInRegularHashMap + |} + """.stripMargin + }.getOrElse { + updateRowInRegularHashMap + } } else { updateRowInRegularHashMap } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 0ea16a1a15d6..92ff97aad652 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -398,4 +398,35 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession { }.isDefined, "LocalTableScanExec should be within a WholeStageCodegen domain.") } + + test("Give up splitting aggregate code if a parameter length goes over the JVM limit") { + withSQLConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true") { + withTable("t") { + val numCols = 100 + val colExprs = "id AS key" +: (0 until numCols).map { i => s"id AS _c$i" } + spark.range(3).selectExpr(colExprs: _*).write.saveAsTable("t") + + // Defines too many common subexpressions for a parameter length + // to go over the JVM limit. + val aggExprs = (2 until numCols).map { i => + (0 until i).map(d => s"_c$d") + .mkString("SUM(", " + ", ")") + } + + // Test case without keys + var cause = intercept[Exception] { + sql(s"SELECT ${aggExprs.mkString(", ")} FROM t").collect + }.getCause + assert(cause.isInstanceOf[IllegalStateException]) + assert(cause.getMessage.contains("Failed to split aggregate code into small functions")) + + // Tet case with keys + cause = intercept[Exception] { + sql(s"SELECT key, ${aggExprs.mkString(", ")} FROM t GROUP BY key").collect + }.getCause + assert(cause.isInstanceOf[IllegalStateException]) + assert(cause.getMessage.contains("Failed to split aggregate code into small functions")) + } + } + } } From 860740ba00ba2b1a91bf3dca61574fcafe052fed Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 24 Aug 2019 09:34:47 +0900 Subject: [PATCH 04/20] Defines a split function for each aggregation expression --- .../expressions/codegen/CodeGenerator.scala | 40 +- .../aggregate/HashAggregateExec.scala | 471 +++++++++--------- .../execution/WholeStageCodegenSuite.scala | 7 +- 3 files changed, 269 insertions(+), 249 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 8c8f78a2c2a6..2ff11f7bf6ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1620,11 +1620,15 @@ object CodeGenerator extends Logging { def getLocalInputVariableValues( ctx: CodegenContext, expr: Expression, - subExprs: Map[Expression, SubExprEliminationState]): Seq[(VariableValue, Expression)] = { - val argMap = mutable.Map[VariableValue, Expression]() - // Collects local variables only in `argMap` - val collectLocalVariable = (ev: ExprValue, expr: Expression) => ev match { - case vv: VariableValue => argMap += vv -> expr + subExprs: Map[Expression, SubExprEliminationState]): Set[VariableValue] = { + 
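+    // A sketch of what ends up in `argSet` (names hypothetical): for an input
+    // column whose current ExprCode holds (value = expr_1, isNull = exprIsNull_1),
+    // both local variables are collected, while GlobalValue and LiteralValue
+    // inputs are skipped because a split function can reach those without
+    // receiving them as parameters.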
val argSet = mutable.Set[VariableValue]() + if (ctx.INPUT_ROW != null) { + argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow]) + } + + // Collects local variables from a given `expr` tree + val collectLocalVariable = (ev: ExprValue) => ev match { + case vv: VariableValue => argSet += vv case _ => } @@ -1633,26 +1637,23 @@ object CodeGenerator extends Logging { stack.pop() match { case e if subExprs.contains(e) => val SubExprEliminationState(isNull, value) = subExprs(e) - collectLocalVariable(value, e) - collectLocalVariable(isNull, e) + collectLocalVariable(value) + collectLocalVariable(isNull) // Since the children possibly have common subexprs, we push them here stack.pushAll(e.children) - case ref: BoundReference - if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + case ref: BoundReference if ctx.currentVars != null && + ctx.currentVars(ref.ordinal) != null => val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal) - collectLocalVariable(value, ref) - collectLocalVariable(isNull, ref) - - case ref: BoundReference => - argMap += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow]) -> ref + collectLocalVariable(value) + collectLocalVariable(isNull) case e => stack.pushAll(e.children) } } - argMap.toSeq + argSet.toSet } /** @@ -1762,6 +1763,15 @@ object CodeGenerator extends Logging { 1 + params.map(paramLengthForExpr).sum } + def calculateParamLengthFromExprValues(params: Seq[ExprValue]): Int = { + def paramLengthForExpr(input: ExprValue): Int = input.javaType match { + case java.lang.Long.TYPE | java.lang.Double.TYPE => 2 + case _ => 1 + } + // Initial value is 1 for `this`. + 1 + params.map(paramLengthForExpr).sum + } + /** * In Java, a method descriptor is valid only if it represents method parameters with a total * length less than a pre-defined constant. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b50ad0314302..2da7fbfe73a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.aggregate import java.util.concurrent.TimeUnit._ +import scala.collection.mutable + import org.apache.spark.TaskContext import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rdd.RDD @@ -36,7 +38,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, DecimalType, ObjectType, StringType, StructType} +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils @@ -175,7 +177,7 @@ case class HashAggregateExec( } // The variables used as aggregation buffer. Only used for aggregation without keys. 
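+  // Grouped per aggregate function below (hence the nested Seq), so that each
+  // split function can evaluate and write back only its own buffer slots.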
- private var bufVars: Seq[ExprCode] = _ + private var bufVars: Seq[Seq[ExprCode]] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") @@ -184,27 +186,30 @@ case class HashAggregateExec( // generate variables for aggregation buffer val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - val initExpr = functions.flatMap(f => f.initialValues) - bufVars = initExpr.map { e => - val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") - // The initial expression should not access any column - val ev = e.genCode(ctx) - val initVars = code""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; - """.stripMargin - ExprCode( - ev.code + initVars, - JavaCode.isNullGlobal(isNull), - JavaCode.global(value, e.dataType)) + val initExpr = functions.map(f => f.initialValues) + bufVars = initExpr.map { exprs => + exprs.map { e => + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") + // The initial expression should not access any column + val ev = e.genCode(ctx) + val initVars = code""" + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + """.stripMargin + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) + } } - val initBufVar = evaluateVariables(bufVars) + val flatBufVars = bufVars.flatten + val initBufVar = evaluateVariables(flatBufVars) // generate variables for output val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { // evaluate aggregate results - ctx.currentVars = bufVars + ctx.currentVars = flatBufVars val aggResults = bindReferences( functions.map(_.evaluateExpression), aggregateBufferAttributes).map(_.genCode(ctx)) @@ -218,7 +223,7 @@ case class HashAggregateExec( """.stripMargin) } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { // output the aggregate buffer directly - (bufVars, "") + (flatBufVars, "") } else { // no aggregate function, the result should be literals val resultVars = resultExpressions.map(_.genCode(ctx)) @@ -256,48 +261,50 @@ case class HashAggregateExec( } // Splits aggregate code into small functions because the most of JVM implementations - // can not compile too long functions. Note that different from `CodeGenerator.splitExpressions`, - // we will extract input variables from references and subexpression elimination states - // for each aggregate expression, then pass them to it. + // can not compile too long functions. + // + // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual + // function for each aggregation function (e.g., SUM and AVG). For example, in a query + // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions + // for `SUM(a)` and `AVG(a)`. 
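+  // Continuing the SUM(a)/AVG(a) example, the two split functions would be
+  // declared roughly as follows (hypothetical generated Java; the names come
+  // from `freshName` plus each aggregate's `prettyName`):
+  //   private void doAggregateVal_sum(long agg_expr_1, boolean agg_exprIsNull_1)
+  //       throws java.io.IOException { ... }
+  //   private void doAggregateVal_avg(double agg_expr_2, boolean agg_exprIsNull_2,
+  //       long agg_expr_3, boolean agg_exprIsNull_3) throws java.io.IOException { ... }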
private def splitAggregateExpressions( ctx: CodegenContext, - aggregateExpressions: Seq[Expression], - aggEvalCodes: Seq[String], - subExprs: Map[Expression, SubExprEliminationState], - bufferInput: Option[VariableValue] = None): Option[String] = { - val aggExprWithInputVars = aggregateExpressions.map { aggExpr => - val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, aggExpr, subExprs) - val paramLength = CodeGenerator.calculateParamLength(inputVars.map(_._2)) + - (if (bufferInput.isDefined) 1 else 0) - - // Checks if a parameter length for the `aggExpr` does not go over the JVM limit + aggNames: Seq[String], + aggExprs: Seq[Seq[Expression]], + makeSplitAggFunctions: => Seq[String], + subExprs: Map[Expression, SubExprEliminationState]): Option[String] = { + val inputVars = aggExprs.map { aggExprsInAgg => + val inputVarsInAgg = aggExprsInAgg.map( + CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsInAgg) + + // Checks if a parameter length for the `aggExprsInAgg` does not go over the JVM limit if (CodeGenerator.isValidParamLength(paramLength)) { - Some((aggExpr, inputVars)) + Some(inputVarsInAgg) } else { None } } // Checks if all the aggregate code can be split into pieces. - // If the parameter length of at lease one `aggExpr` goes over the limit, + // If the parameter length of at lease one `aggExprsInAgg` goes over the limit, // we totally give up splitting aggregate code. - if (aggExprWithInputVars.forall(_.isDefined)) { - val splitCodes = aggExprWithInputVars.flatten.zip(aggEvalCodes) - .map { case ((aggExpr, inputVars), aggEvalCode) => - val args = inputVars.map(_._1) ++ bufferInput.map(_ :: Nil).getOrElse(Nil) - val doAggVal = ctx.freshName(s"doAggregateVal_${aggExpr.prettyName}") + if (inputVars.forall(_.isDefined)) { + val splitAggEvalCodes = makeSplitAggFunctions + val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => + val doAggVal = ctx.freshName(s"doAggregateVal_${aggNames(i)}") val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") val doAggValFuncName = ctx.addNewFunction(doAggVal, s""" | private void $doAggVal($argList) throws java.io.IOException { - | $aggEvalCode + | ${splitAggEvalCodes(i)} | } """.stripMargin) val inputVariables = args.map(_.variableName).mkString(", ") s"$doAggValFuncName($inputVariables);" } - Some(splitCodes.mkString("\n")) + Some(splitCodes.mkString("\n").trim) } else { val errMsg = "Failed to split aggregate code into small functions because the parameter " + "length of at least one split function went over the JVM limit: " + @@ -315,7 +322,7 @@ case class HashAggregateExec( // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output - val updateExpr = aggregateExpressions.flatMap { e => + val updateExprs = aggregateExpressions.map { e => e.mode match { case Partial | Complete => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions @@ -323,17 +330,21 @@ case class HashAggregateExec( e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } - - lazy val aggregateCodeInSingleFunc = { - ctx.currentVars = bufVars ++ input - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes 
= subExprs.codes.mkString("\n") - val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) + ctx.currentVars = bufVars.flatten ++ input + val boundUpdateExprs = updateExprs.map { updateExprsInAgg => + updateExprsInAgg.map(BindReferences.bindReference(_, inputAttrs)) + } + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val effectiveCodes = subExprs.codes.mkString("\n") + val aggVals = boundUpdateExprs.map { boundUpdateExprsInAgg => + ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExprsInAgg.map(_.genCode(ctx)) } + } + + lazy val nonSplitAggCode = { // aggregate buffer should be updated atomically - val updates = aggVals.zip(bufVars).map { case (ev, bufVar) => + val updates = aggVals.flatten.zip(bufVars.flatten).map { case (ev, bufVar) => s""" | ${bufVar.isNull} = ${ev.isNull}; | ${bufVar.value} = ${ev.value}; @@ -343,67 +354,53 @@ case class HashAggregateExec( | // do aggregate | // common sub-expressions | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(aggVals)} - | // update aggregation buffer + | // evaluate aggregate functions + | ${evaluateVariables(aggVals.flatten)} + | // update aggregation buffers | ${updates.mkString("\n").trim} """.stripMargin } if (conf.codegenSplitAggregateFunc) { - // We need to copy the aggregation buffer to local variables first because each aggregate - // function directly updates the buffer when it finishes. - val localBufVars = bufVars.zip(updateExpr).map { case (ev, e) => - val isNull = ctx.freshName("localBufIsNull") - val value = ctx.freshName("localBufValue") - val initLocalVars = code""" - | boolean $isNull = ${ev.isNull}; - | ${CodeGenerator.javaType(e.dataType)} $value = ${ev.value}; - """.stripMargin - ExprCode(initLocalVars, JavaCode.isNullVariable(isNull), - JavaCode.variable(value, e.dataType)) - } - - val initLocalBufVar = evaluateVariables(localBufVars) - - ctx.currentVars = localBufVars ++ input - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - - val evalAndUpdateCodes = aggVals.zip(bufVars).map { case (ev, bufVar) => - s""" - | // evaluate aggregate function - | ${ev.code} - | // update aggregation buffer - | ${bufVar.isNull} = ${ev.isNull}; - | ${bufVar.value} = ${ev.value}; - """.stripMargin - } - - splitAggregateExpressions( + val splitAggCode = splitAggregateExpressions( ctx = ctx, - aggregateExpressions = boundUpdateExpr, - aggEvalCodes = evalAndUpdateCodes, - subExprs = subExprs.states).map { updateAggValCode => + aggNames = functions.map(_.prettyName), + aggExprs = boundUpdateExprs, + makeSplitAggFunctions = { + aggVals.zip(bufVars).map { case (aggValsInAgg, bufVarsInAgg) => + // All the update code for aggregation buffers should be placed in the end + // of each aggregation function code. 
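
Note: the easiest way to see what `makeSplitAggFunctions` buys us is the shape of the code it ends up emitting. Below is a self-contained sketch of the wrap-and-call pattern, outside any codegen context; every name in it (the method names, the buffer variables, the single `long` input) is illustrative, not the actual generator output:

    object SplitShapeSketch {
      def main(args: Array[String]): Unit = {
        val aggNames = Seq("sum", "avg")
        // One chunk of buffer-update code per aggregate function, as plain strings here.
        val bodies = Seq(
          "bufIsNull_0 = false; bufValue_0 += inputValue;",
          "bufValue_1 += inputValue; bufCount_1 += 1;")
        // Each chunk is wrapped into its own small private method...
        val funcs = aggNames.zip(bodies).map { case (name, body) =>
          s"""private void doAggregateVal_$name(long inputValue) throws java.io.IOException {
             |  $body
             |}""".stripMargin
        }
        // ...and the main update path becomes a flat run of small calls, each of which
        // stays under the size where JIT compilers start giving up.
        val callSites = aggNames.map(n => s"doAggregateVal_$n(inputValue);").mkString("\n")
        println((funcs :+ callSites).mkString("\n\n"))
      }
    }
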
+ val updates = aggValsInAgg.zip(bufVarsInAgg).map { case (ev, bufVar) => + s""" + | ${bufVar.isNull} = ${ev.isNull}; + | ${bufVar.value} = ${ev.value}; + """.stripMargin + } + s""" + | // do aggregate + | // evaluate aggregate function + | ${evaluateVariables(aggValsInAgg)} + | // update aggregation buffers + | ${updates.mkString("\n").trim} + """.stripMargin + } + }, + subExprs = subExprs.states + ) + splitAggCode.map { updateAggValCode => s""" | // do aggregate - | // copy aggregation buffer to the local - | $initLocalBufVar | // common sub-expressions | $effectiveCodes | // evaluate aggregate functions and update aggregation buffers | $updateAggValCode """.stripMargin }.getOrElse { - aggregateCodeInSingleFunc + nonSplitAggCode } } else { - aggregateCodeInSingleFunc + nonSplitAggCode } } @@ -861,7 +858,7 @@ case class HashAggregateExec( val fastRowBuffer = ctx.freshName("fastAggBuffer") // only have DeclarativeAggregate - val updateExpr = aggregateExpressions.flatMap { e => + val updateExprs = aggregateExpressions.map { e => e.mode match { case Partial | Complete => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions @@ -939,164 +936,178 @@ case class HashAggregateExec( // generating input columns, we use `currentVars`. ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input - lazy val aggregateCodeInSingleFunc = { - ctx.INPUT_ROW = unsafeRowBuffer - val boundUpdateExpr = bindReferences(updateExpr, inputAttr) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) + // Computes buffer offsets for split functions in the underlying buffer row + lazy val bufferOffsets = { + val offsets = mutable.ArrayBuffer[Int]() + var curOffset = 0 + updateExprs.foreach { exprsInAgg => + offsets += curOffset + curOffset += exprsInAgg.length } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) - } - s""" - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate function - |${evaluateVariables(unsafeRowBufferEvals)} - |// update unsafe row buffer - |${updateUnsafeRowBuffer.mkString("\n").trim} - """.stripMargin + offsets.toArray } - val updateRowInRegularHashMap: String = if (conf.codegenSplitAggregateFunc) { - // We need to copy the aggregation row buffer to a local row first because each aggregate - // function directly updates the buffer when it finishes. 
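
Note: the `flatMap` -> `map` change above is the central data-shape change of this patch: update expressions stay grouped per aggregate function instead of being flattened up front. A toy model of why the nesting matters (values illustrative):

    // Nested view: one inner Seq per aggregate function, e.g. avg -> (sum, count).
    val updateExprs: Seq[Seq[String]] = Seq(Seq("sum + v", "count + 1"), Seq("max(m, v)"))
    // The old flat view is still recoverable wherever a single pass over all
    // buffer slots is needed...
    val flatView: Seq[String] = updateExprs.flatten
    // ...but only the nested view remembers which slots belong to which function,
    // which is exactly what per-function splitting needs.
    assert(updateExprs.map(_.length) == Seq(2, 1))
    assert(flatView.length == 3)
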
- val localRowBuffer = ctx.freshName("localUnsafeRowBuffer") - val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();" - - ctx.INPUT_ROW = localRowBuffer - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) + val updateRowInRegularHashMap: String = { + ctx.INPUT_ROW = unsafeRowBuffer + val boundUpdateExprs = updateExprs.map { updateExprsInAgg => + bindReferences(updateExprsInAgg, inputAttr) } - - val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - val updateColumnCode = - CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) - s""" - | // evaluate aggregate function - | ${ev.code} - | // update unsafe row buffer - | $updateColumnCode - """.stripMargin + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) + val effectiveCodes = subExprs.codes.mkString("\n") + val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsInAgg => + ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExprsInAgg.map(_.genCode(ctx)) + } } - splitAggregateExpressions( - ctx = ctx, - aggregateExpressions = boundUpdateExpr, - aggEvalCodes = evalAndUpdateCodes, - subExprs = subExprs.states, - bufferInput = Some(VariableValue(unsafeRowBuffer, classOf[InternalRow]))) - .map { updateAggValCode => - + lazy val nonSplitAggCode = { + // aggregate buffer should be updated atomically + val flatBoundUpdateExprs = boundUpdateExprs.flatten + val updateUnsafeRowBuffer = unsafeRowBufferEvals.flatten.zipWithIndex.map { case (ev, i) => + val dt = flatBoundUpdateExprs(i).dataType + val nullable = flatBoundUpdateExprs(i).nullable + CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, nullable) + } s""" - | // do aggregate - | // copy aggregation row buffer to the local - | $initLocalRowBuffer - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate functions and update aggregation buffers - | $updateAggValCode + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate function + |${evaluateVariables(unsafeRowBufferEvals.flatten)} + |// update unsafe row buffer + |${updateUnsafeRowBuffer.mkString("\n").trim} """.stripMargin - }.getOrElse { - aggregateCodeInSingleFunc } - } else { - aggregateCodeInSingleFunc - } - val updateRowInHashMap: String = if (conf.codegenSplitAggregateFunc) { - if (isFastHashMapEnabled) { - // We need to copy the aggregation row buffer to a local row first because each aggregate - // function directly updates the buffer when it finishes. 
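
Note: the `bufferOffsets` loop above is just a running sum over the per-function buffer widths; the same numbers fall out of a `scanLeft`. A sketch with illustrative widths (say avg taking two slots, count and max one each):

    val widths = Seq(2, 1, 1)
    val offsets = widths.scanLeft(0)(_ + _).init
    assert(offsets == Seq(0, 2, 3))
    // Function i can then write its j-th buffer slot at column offsets(i) + j of the
    // shared underlying row, even though it no longer sees the other functions' slots.
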
- val localRowBuffer = ctx.freshName("localFastRowBuffer") - val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();" - - ctx.INPUT_ROW = localRowBuffer - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - - val evalAndUpdateCodes = fastRowEvals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - val updateColumnCode = CodeGenerator.updateColumn( - fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) - s""" - | // evaluate aggregate function - | ${ev.code} - | // update fast row - | $updateColumnCode - """.stripMargin - } - - splitAggregateExpressions( + if (conf.codegenSplitAggregateFunc) { + val splitAggCode = splitAggregateExpressions( ctx = ctx, - aggregateExpressions = boundUpdateExpr, - aggEvalCodes = evalAndUpdateCodes, - subExprs = subExprs.states, - bufferInput = Some(VariableValue(fastRowBuffer, classOf[InternalRow]))) - .map { updateAggValCode => - - // If fast hash map is on, we first generate code to update row in fast hash map, if the - // previous loop up hit fast hash map. Otherwise, update row in regular hash map. + aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName), + aggExprs = boundUpdateExprs, + makeSplitAggFunctions = { + unsafeRowBufferEvals.zipWithIndex.map { case (rowBufferEvalsInAgg, i) => + val boundUpdateExprsInAgg = boundUpdateExprs(i) + val bufferOffset = bufferOffsets(i) + // All the update code for aggregation buffers should be placed in the end + // of each aggregation function code. 
+ val updateRowBuffers = rowBufferEvalsInAgg.zipWithIndex.map { case (ev, j) => + val updateExpr = boundUpdateExprsInAgg(j) + val dt = updateExpr.dataType + val nullable = updateExpr.nullable + CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable) + } + s""" + |// evaluate aggregate function + |${evaluateVariables(rowBufferEvalsInAgg)} + |// update unsafe row buffer + |${updateRowBuffers.mkString("\n").trim} + """.stripMargin + } + }, + subExprs = subExprs.states) + + splitAggCode.map { updateAggValCode => s""" - |if ($fastRowBuffer != null) { - | // copy aggregation row buffer to the local - | $initLocalRowBuffer - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate functions and update aggregation buffers - | $updateAggValCode - |} else { - | $updateRowInRegularHashMap - |} + | // do aggregate + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate functions and update aggregation buffers + | $updateAggValCode """.stripMargin }.getOrElse { - updateRowInRegularHashMap + nonSplitAggCode } } else { - updateRowInRegularHashMap + nonSplitAggCode } - } else { + } + + val updateRowInHashMap: String = { if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer - val boundUpdateExpr = bindReferences(updateExpr, inputAttr) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val boundUpdateExprs = updateExprs.map { updateExprsInAgg => + updateExprsInAgg.map(BindReferences.bindReference(_, inputAttr)) + } + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") - val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) + val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsInAgg => + ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExprsInAgg.map(_.genCode(ctx)) + } } - val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - CodeGenerator.updateColumn( - fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized = true) + + lazy val nonSplitAggCode = { + val flatBoundUpdateExprs = boundUpdateExprs.flatten + val updateFastRow = fastRowEvals.flatten.zipWithIndex.map { case (ev, i) => + val dt = flatBoundUpdateExprs(i).dataType + CodeGenerator.updateColumn( + fastRowBuffer, dt, i, ev, flatBoundUpdateExprs(i).nullable, isVectorized = true) + } + + // If vectorized fast hash map is on, we first generate code to update row + // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map. + // Otherwise, update row in regular hash map. + s""" + |if ($fastRowBuffer != null) { + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate function + | ${evaluateVariables(fastRowEvals.flatten)} + | // update fast row + | ${updateFastRow.mkString("\n").trim} + |} else { + | $updateRowInRegularHashMap + |} + """.stripMargin } - // If vectorized fast hash map is on, we first generate code to update row - // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map. - // Otherwise, update row in regular hash map. 
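
Note: whichever way the per-function bodies are produced, the emitted update path keeps the same two-level skeleton. A sketch of what the code strings above expand to (illustrative, not the exact generated Java):

    val updateRowSkeleton =
      """
        |if (fastAggBuffer != null) {
        |  // evaluate aggregate functions and update the fast row
        |} else {
        |  // fall back to the regular UnsafeRow-based hash map
        |}
      """.stripMargin
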
- s""" - |if ($fastRowBuffer != null) { - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(fastRowEvals)} - | // update fast row - | ${updateFastRow.mkString("\n").trim} - |} else { - | $updateRowInRegularHashMap - |} - """.stripMargin + if (conf.codegenSplitAggregateFunc) { + val splitAggCode = splitAggregateExpressions( + ctx = ctx, + aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName), + aggExprs = boundUpdateExprs, + makeSplitAggFunctions = { + fastRowEvals.zipWithIndex.map { case (fastRowEvalsInAgg, i) => + val boundUpdateExprsInAgg = boundUpdateExprs(i) + val bufferOffset = bufferOffsets(i) + // All the update code for aggregation buffers should be placed in the end + // of each aggregation function code. + val updateRowBuffer = fastRowEvalsInAgg.zipWithIndex.map { case (ev, j) => + val updateExpr = boundUpdateExprsInAgg(j) + val dt = updateExpr.dataType + val nullable = updateExpr.nullable + CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, ev, nullable, + isVectorized = true) + } + s""" + | // evaluate aggregate function + | ${evaluateVariables(fastRowEvalsInAgg)} + | // update fast row + | ${updateRowBuffer.mkString("\n").trim} + """.stripMargin + } + }, + subExprs = subExprs.states) + + splitAggCode.map { updateAggValCode => + // If fast hash map is on, we first generate code to update row in fast hash map, if + // the previous loop up hit fast hash map. Otherwise, update row in regular hash map. + s""" + |if ($fastRowBuffer != null) { + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate functions and update aggregation buffers + | $updateAggValCode + |} else { + | $updateRowInRegularHashMap + |} + """.stripMargin + }.getOrElse { + nonSplitAggCode + } + } else { + nonSplitAggCode + } } else { // If row-based hash map is on and the previous loop up hit fast hash map, // we reuse regular hash buffer to update row of fast hash map. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 92ff97aad652..372848bb9ce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -402,15 +402,14 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession { test("Give up splitting aggregate code if a parameter length goes over the JVM limit") { withSQLConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true") { withTable("t") { - val numCols = 100 + val numCols = 45 val colExprs = "id AS key" +: (0 until numCols).map { i => s"id AS _c$i" } spark.range(3).selectExpr(colExprs: _*).write.saveAsTable("t") - // Defines too many common subexpressions for a parameter length + // Defines many common subexpressions for a parameter length // to go over the JVM limit. 
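
Note: "parameter length" in this test is counted in JVM parameter slots, not in parameters. A sketch of the counting rule (assumption: this mirrors what `CodeGenerator.calculateParamLength` does; 255 is the JVM's per-method slot limit, with `long`/`double` taking two slots and one slot reserved for `this`):

    def paramSlots(javaTypes: Seq[String]): Int =
      1 + javaTypes.map { case "long" | "double" => 2; case _ => 1 }.sum

    assert(paramSlots(Seq("int", "long", "boolean")) == 5)
    def isValidParamLength(paramLength: Int): Boolean = paramLength <= 255
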
val aggExprs = (2 until numCols).map { i => - (0 until i).map(d => s"_c$d") - .mkString("SUM(", " + ", ")") + (0 until i).map(d => s"_c$d").mkString("SUM(", " + ", ")") } // Test case without keys From 4a42e35e4cc395ec11a9df840015ad60768fb121 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 28 Aug 2019 13:48:45 +0900 Subject: [PATCH 05/20] Makes splitAggregateFunc enabled by default for testing --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 85e8763bdb60..4c425f25f488 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1054,7 +1054,7 @@ object SQLConf { "instead of a single big method. This can be used to avoid oversized function that " + "can miss the opportunity of JIT optimization.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val MAX_NESTED_VIEW_DEPTH = buildConf("spark.sql.view.maxNestedViewDepth") From 3ede09dc6339ea300ccf8568cc2da33729819a4d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 28 Aug 2019 22:57:55 +0900 Subject: [PATCH 06/20] Split aggregate code if the code length goes over the threshold --- .../expressions/codegen/javaCode.scala | 5 +- .../aggregate/HashAggregateExec.scala | 190 ++++++++---------- .../sql-tests/inputs/group-analytics.sql | 3 - .../sql-tests/inputs/group-by-ordinal.sql | 3 - .../resources/sql-tests/inputs/group-by.sql | 3 - .../sql-tests/inputs/grouping_set.sql | 3 - 6 files changed, 83 insertions(+), 124 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 3bb3c602f775..d9393b9df6bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -143,7 +143,10 @@ trait Block extends TreeNode[Block] with JavaCode { case _ => code.trim } - def length: Int = toString.length + def length: Int = { + // Returns a code length without comments + CodeFormatter.stripExtraNewLinesAndComments(toString).length + } def isEmpty: Boolean = toString.isEmpty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 2da7fbfe73a6..c8202ffa399d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -271,7 +271,7 @@ case class HashAggregateExec( ctx: CodegenContext, aggNames: Seq[String], aggExprs: Seq[Seq[Expression]], - makeSplitAggFunctions: => Seq[String], + aggCodeBlocks: Seq[Block], subExprs: Map[Expression, SubExprEliminationState]): Option[String] = { val inputVars = aggExprs.map { aggExprsInAgg => val inputVarsInAgg = aggExprsInAgg.map( @@ -290,14 +290,13 @@ case class HashAggregateExec( // If the parameter length of at lease one `aggExprsInAgg` goes over the limit, // we totally give up splitting aggregate code. 
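
Note: the give-up logic described above is an all-or-nothing fold over `Option`s; a minimal model of the decision:

    def planSplit(paramLengthOk: Seq[Boolean]): Option[Seq[Int]] = {
      val candidates = paramLengthOk.zipWithIndex.map { case (ok, i) => if (ok) Some(i) else None }
      // A single oversized function vetoes the whole split.
      if (candidates.forall(_.isDefined)) Some(candidates.flatten) else None
    }
    assert(planSplit(Seq(true, true)) == Some(Seq(0, 1)))
    assert(planSplit(Seq(true, false)).isEmpty)
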
if (inputVars.forall(_.isDefined)) { - val splitAggEvalCodes = makeSplitAggFunctions val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => val doAggVal = ctx.freshName(s"doAggregateVal_${aggNames(i)}") val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") val doAggValFuncName = ctx.addNewFunction(doAggVal, s""" | private void $doAggVal($argList) throws java.io.IOException { - | ${splitAggEvalCodes(i)} + | ${aggCodeBlocks(i)} | } """.stripMargin) @@ -322,6 +321,8 @@ case class HashAggregateExec( // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + // To individually generate code for each aggregate function, an element in `updateExprs` holds + // all the expressions for the buffer of an aggregation function. val updateExprs = aggregateExpressions.map { e => e.mode match { case Partial | Complete => @@ -342,51 +343,40 @@ case class HashAggregateExec( } } - lazy val nonSplitAggCode = { - // aggregate buffer should be updated atomically - val updates = aggVals.flatten.zip(bufVars.flatten).map { case (ev, bufVar) => + val aggNames = functions.map(_.prettyName) + val aggCodeBlocks = aggVals.zipWithIndex.map { case (aggValsInAgg, i) => + val bufVarsInAgg = bufVars(i) + // All the update code for aggregation buffers should be placed in the end + // of each aggregation function code. + val updates = aggValsInAgg.zip(bufVarsInAgg).map { case (ev, bufVar) => s""" | ${bufVar.isNull} = ${ev.isNull}; | ${bufVar.value} = ${ev.value}; """.stripMargin } - s""" + code""" + | // do aggregate for ${aggNames(i)} + | // evaluate aggregate function + | ${evaluateVariables(aggValsInAgg)} + | // update aggregation buffers + | ${updates.mkString("\n").trim} + """.stripMargin + } + + lazy val nonSplitAggCode = { + s""" | // do aggregate | // common sub-expressions | $effectiveCodes - | // evaluate aggregate functions - | ${evaluateVariables(aggVals.flatten)} - | // update aggregation buffers - | ${updates.mkString("\n").trim} + | // evaluate aggregate functions and update aggregation buffers + | ${aggCodeBlocks.fold(EmptyBlock)(_ + _)} """.stripMargin } - if (conf.codegenSplitAggregateFunc) { + if (conf.codegenSplitAggregateFunc && + aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { val splitAggCode = splitAggregateExpressions( - ctx = ctx, - aggNames = functions.map(_.prettyName), - aggExprs = boundUpdateExprs, - makeSplitAggFunctions = { - aggVals.zip(bufVars).map { case (aggValsInAgg, bufVarsInAgg) => - // All the update code for aggregation buffers should be placed in the end - // of each aggregation function code. - val updates = aggValsInAgg.zip(bufVarsInAgg).map { case (ev, bufVar) => - s""" - | ${bufVar.isNull} = ${ev.isNull}; - | ${bufVar.value} = ${ev.value}; - """.stripMargin - } - s""" - | // do aggregate - | // evaluate aggregate function - | ${evaluateVariables(aggValsInAgg)} - | // update aggregation buffers - | ${updates.mkString("\n").trim} - """.stripMargin - } - }, - subExprs = subExprs.states - ) + ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) splitAggCode.map { updateAggValCode => s""" @@ -936,8 +926,10 @@ case class HashAggregateExec( // generating input columns, we use `currentVars`. 
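
Note on the `currentVars` assignment just below: a null slot in `currentVars` is meaningful — it tells `BoundReference` codegen to read that ordinal from `INPUT_ROW` instead of from a local variable, which is how the aggregation-buffer columns end up routed through the row buffer. A simplified model of the fallback (the `getLong` accessor stands in for the type-appropriate one):

    def resolveInput(currentVars: Seq[Option[String]], ordinal: Int, inputRow: String): String =
      currentVars(ordinal).getOrElse(s"$inputRow.getLong($ordinal)")

    val vars = Seq(None, Some("value_1"))
    assert(resolveInput(vars, 0, "unsafeRowAggBuffer") == "unsafeRowAggBuffer.getLong(0)")
    assert(resolveInput(vars, 1, "unsafeRowAggBuffer") == "value_1")
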
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input - // Computes buffer offsets for split functions in the underlying buffer row - lazy val bufferOffsets = { + val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName) + // Computes buffer offsets for each aggregation function code + // in the underlying buffer row. + val bufferOffsets = { val offsets = mutable.ArrayBuffer[Int]() var curOffset = 0 updateExprs.foreach { exprsInAgg => @@ -960,50 +952,38 @@ case class HashAggregateExec( } } - lazy val nonSplitAggCode = { - // aggregate buffer should be updated atomically - val flatBoundUpdateExprs = boundUpdateExprs.flatten - val updateUnsafeRowBuffer = unsafeRowBufferEvals.flatten.zipWithIndex.map { case (ev, i) => - val dt = flatBoundUpdateExprs(i).dataType - val nullable = flatBoundUpdateExprs(i).nullable - CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, nullable) + val aggCodeBlocks = unsafeRowBufferEvals.zipWithIndex.map { case (rowBufferEvalsInAgg, i) => + val boundUpdateExprsInAgg = boundUpdateExprs(i) + val bufferOffset = bufferOffsets(i) + // All the update code for aggregation buffers should be placed in the end + // of each aggregation function code. + val updateRowBuffers = rowBufferEvalsInAgg.zipWithIndex.map { case (ev, j) => + val updateExpr = boundUpdateExprsInAgg(j) + val dt = updateExpr.dataType + val nullable = updateExpr.nullable + CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable) } + code""" + |// evaluate aggregate function for ${aggNames(i)} + |${evaluateVariables(rowBufferEvalsInAgg)} + |// update unsafe row buffer + |${updateRowBuffers.mkString("\n").trim} + """.stripMargin + } + + lazy val nonSplitAggCode = { s""" |// common sub-expressions |$effectiveCodes - |// evaluate aggregate function - |${evaluateVariables(unsafeRowBufferEvals.flatten)} - |// update unsafe row buffer - |${updateUnsafeRowBuffer.mkString("\n").trim} + | // evaluate aggregate functions and update aggregation buffers + | ${aggCodeBlocks.fold(EmptyBlock)(_ + _)} """.stripMargin } - if (conf.codegenSplitAggregateFunc) { + if (conf.codegenSplitAggregateFunc && + aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { val splitAggCode = splitAggregateExpressions( - ctx = ctx, - aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName), - aggExprs = boundUpdateExprs, - makeSplitAggFunctions = { - unsafeRowBufferEvals.zipWithIndex.map { case (rowBufferEvalsInAgg, i) => - val boundUpdateExprsInAgg = boundUpdateExprs(i) - val bufferOffset = bufferOffsets(i) - // All the update code for aggregation buffers should be placed in the end - // of each aggregation function code. 
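
Note: the split is now doubly gated — the flag must be on and the combined size of the per-function code blocks must cross `spark.sql.codegen.methodSplitThreshold`. The predicate in isolation (sketch; `length` here is the comment-stripped measure introduced in the `javaCode.scala` hunk above):

    def shouldSplit(splitEnabled: Boolean, blockLengths: Seq[Int], threshold: Int): Boolean =
      splitEnabled && blockLengths.sum > threshold

    assert(!shouldSplit(true, Seq(100, 200), 1024))  // small enough: keep a single method
    assert(shouldSplit(true, Seq(800, 400), 1024))   // oversized: try to split
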
- val updateRowBuffers = rowBufferEvalsInAgg.zipWithIndex.map { case (ev, j) => - val updateExpr = boundUpdateExprsInAgg(j) - val dt = updateExpr.dataType - val nullable = updateExpr.nullable - CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable) - } - s""" - |// evaluate aggregate function - |${evaluateVariables(rowBufferEvalsInAgg)} - |// update unsafe row buffer - |${updateRowBuffers.mkString("\n").trim} - """.stripMargin - } - }, - subExprs = subExprs.states) + ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) splitAggCode.map { updateAggValCode => s""" @@ -1036,14 +1016,27 @@ case class HashAggregateExec( } } - lazy val nonSplitAggCode = { - val flatBoundUpdateExprs = boundUpdateExprs.flatten - val updateFastRow = fastRowEvals.flatten.zipWithIndex.map { case (ev, i) => - val dt = flatBoundUpdateExprs(i).dataType - CodeGenerator.updateColumn( - fastRowBuffer, dt, i, ev, flatBoundUpdateExprs(i).nullable, isVectorized = true) - } + val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsInAgg, i) => + val boundUpdateExprsInAgg = boundUpdateExprs(i) + val bufferOffset = bufferOffsets(i) + // All the update code for aggregation buffers should be placed in the end + // of each aggregation function code. + val updateRowBuffer = fastRowEvalsInAgg.zipWithIndex.map { case (ev, j) => + val updateExpr = boundUpdateExprsInAgg(j) + val dt = updateExpr.dataType + val nullable = updateExpr.nullable + CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, ev, nullable, + isVectorized = true) + } + code""" + | // evaluate aggregate function for ${aggNames(i)} + | ${evaluateVariables(fastRowEvalsInAgg)} + | // update fast row + | ${updateRowBuffer.mkString("\n").trim} + """.stripMargin + } + lazy val nonSplitAggCode = { // If vectorized fast hash map is on, we first generate code to update row // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map. // Otherwise, update row in regular hash map. @@ -1051,43 +1044,18 @@ case class HashAggregateExec( |if ($fastRowBuffer != null) { | // common sub-expressions | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(fastRowEvals.flatten)} - | // update fast row - | ${updateFastRow.mkString("\n").trim} + | // evaluate aggregate functions and update aggregation buffers + | ${aggCodeBlocks.fold(EmptyBlock)(_ + _)} |} else { | $updateRowInRegularHashMap |} """.stripMargin } - if (conf.codegenSplitAggregateFunc) { + if (conf.codegenSplitAggregateFunc && + aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { val splitAggCode = splitAggregateExpressions( - ctx = ctx, - aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName), - aggExprs = boundUpdateExprs, - makeSplitAggFunctions = { - fastRowEvals.zipWithIndex.map { case (fastRowEvalsInAgg, i) => - val boundUpdateExprsInAgg = boundUpdateExprs(i) - val bufferOffset = bufferOffsets(i) - // All the update code for aggregation buffers should be placed in the end - // of each aggregation function code. 
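
Note: because the gate sums `aggCodeBlocks.map(_.length)`, and `Block.length` now strips comments and extra newlines, comment-heavy generated code no longer trips the threshold by itself. A rough model of the measurement (assumption: this approximates `CodeFormatter.stripExtraNewLinesAndComments`, handling line comments only):

    def effectiveLength(code: String): Int =
      code.split("\n")
        .map(_.trim)
        .filter(l => l.nonEmpty && !l.startsWith("//"))
        .mkString("\n")
        .length

    assert(effectiveLength("// update unsafe row buffer\nagg_bufValue = 1L;\n\n") == "agg_bufValue = 1L;".length)
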
- val updateRowBuffer = fastRowEvalsInAgg.zipWithIndex.map { case (ev, j) => - val updateExpr = boundUpdateExprsInAgg(j) - val dt = updateExpr.dataType - val nullable = updateExpr.nullable - CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, ev, nullable, - isVectorized = true) - } - s""" - | // evaluate aggregate function - | ${evaluateVariables(fastRowEvalsInAgg)} - | // update fast row - | ${updateRowBuffer.mkString("\n").trim} - """.stripMargin - } - }, - subExprs = subExprs.states) + ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) splitAggCode.map { updateAggValCode => // If fast hash map is on, we first generate code to update row in fast hash map, if diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index d102512646a3..9721f8c60ebc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -1,6 +1,3 @@ ---SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=true ---SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=false - CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES (1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2) AS testData(a, b); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 7321a5b30f7d..3144833b608b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -1,6 +1,3 @@ ---SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=true ---SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=false - -- group by ordinal positions create temporary view data as select * from values diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 5d63825b2b2a..66bc90914e0d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -1,6 +1,3 @@ ---SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=true ---SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=false - -- Test data. 
 CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES
 (1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null)
diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql
index d9f9c9133257..6bbde9f38d65 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql
@@ -1,6 +1,3 @@
----SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=true
----SET spark.sql.codegen.aggregate.splitAggregateFunc.enabled=false
-
 CREATE TEMPORARY VIEW grouping AS SELECT * FROM VALUES
   ("1", "2", "3", 1),
   ("4", "5", "6", 1),

From 5cc844e17babd577bc237a5b7e18920eff668408 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 29 Aug 2019 08:48:03 +0900
Subject: [PATCH 07/20] Address reviews

---
 .../aggregate/HashAggregateExec.scala         | 133 +++++++++---------
 1 file changed, 67 insertions(+), 66 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index c8202ffa399d..f2e8eecb1837 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -194,8 +194,8 @@ case class HashAggregateExec(
         // The initial expression should not access any column
         val ev = e.genCode(ctx)
         val initVars = code"""
-           | $isNull = ${ev.isNull};
-           | $value = ${ev.value};
+           |$isNull = ${ev.isNull};
+           |$value = ${ev.value};
         """.stripMargin
         ExprCode(
           ev.code + initVars,
@@ -270,24 +270,24 @@ case class HashAggregateExec(
   private def splitAggregateExpressions(
       ctx: CodegenContext,
       aggNames: Seq[String],
-      aggExprs: Seq[Seq[Expression]],
+      aggBufferUpdatingExprs: Seq[Seq[Expression]],
       aggCodeBlocks: Seq[Block],
       subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
-    val inputVars = aggExprs.map { aggExprsInAgg =>
-      val inputVarsInAgg = aggExprsInAgg.map(
+    val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
+      val inputVarsForOneFunc = aggExprsForOneFunc.map(
         CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
-      val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsInAgg)
+      val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
 
-      // Checks if a parameter length for the `aggExprsInAgg` does not go over the JVM limit
+      // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
       if (CodeGenerator.isValidParamLength(paramLength)) {
-        Some(inputVarsInAgg)
+        Some(inputVarsForOneFunc)
       } else {
         None
       }
     }
 
     // Checks if all the aggregate code can be split into pieces.
-    // If the parameter length of at lease one `aggExprsInAgg` goes over the limit,
+    // If the parameter length of at least one `aggExprsForOneFunc` goes over the limit,
    // we totally give up splitting aggregate code.
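
Note: when the split is abandoned (the give-up branch continues right below), the failure mode deliberately differs by environment — tests fail fast so a regression surfaces in CI, while production only logs and falls back to the single-function code path. The pattern in isolation (`println` standing in for `logInfo`):

    def onSplitFailure(errMsg: String, isTesting: Boolean): Option[String] =
      if (isTesting) {
        throw new IllegalStateException(errMsg)  // surface the problem in CI
      } else {
        println(s"INFO: $errMsg")                // log and fall back to unsplit code
        None
      }
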
if (inputVars.forall(_.isDefined)) { val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => @@ -295,9 +295,9 @@ case class HashAggregateExec( val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") val doAggValFuncName = ctx.addNewFunction(doAggVal, s""" - | private void $doAggVal($argList) throws java.io.IOException { - | ${aggCodeBlocks(i)} - | } + |private void $doAggVal($argList) throws java.io.IOException { + | ${aggCodeBlocks(i)} + |} """.stripMargin) val inputVariables = args.map(_.variableName).mkString(", ") @@ -332,44 +332,44 @@ case class HashAggregateExec( } } ctx.currentVars = bufVars.flatten ++ input - val boundUpdateExprs = updateExprs.map { updateExprsInAgg => - updateExprsInAgg.map(BindReferences.bindReference(_, inputAttrs)) + val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => + bindReferences(updateExprsForOneFunc, inputAttrs) } val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") - val aggVals = boundUpdateExprs.map { boundUpdateExprsInAgg => + val aggVals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExprsInAgg.map(_.genCode(ctx)) + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } } val aggNames = functions.map(_.prettyName) - val aggCodeBlocks = aggVals.zipWithIndex.map { case (aggValsInAgg, i) => - val bufVarsInAgg = bufVars(i) + val aggCodeBlocks = aggVals.zipWithIndex.map { case (aggValsForOneFunc, i) => + val bufVarsForOneFunc = bufVars(i) // All the update code for aggregation buffers should be placed in the end // of each aggregation function code. - val updates = aggValsInAgg.zip(bufVarsInAgg).map { case (ev, bufVar) => + val updates = aggValsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) => s""" - | ${bufVar.isNull} = ${ev.isNull}; - | ${bufVar.value} = ${ev.value}; + |${bufVar.isNull} = ${ev.isNull}; + |${bufVar.value} = ${ev.value}; """.stripMargin } code""" - | // do aggregate for ${aggNames(i)} - | // evaluate aggregate function - | ${evaluateVariables(aggValsInAgg)} - | // update aggregation buffers - | ${updates.mkString("\n").trim} + |// do aggregate for ${aggNames(i)} + |// evaluate aggregate function + |${evaluateVariables(aggValsForOneFunc)} + |// update aggregation buffers + |${updates.mkString("\n").trim} """.stripMargin } lazy val nonSplitAggCode = { s""" - | // do aggregate - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate functions and update aggregation buffers - | ${aggCodeBlocks.fold(EmptyBlock)(_ + _)} + |// do aggregate + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate functions and update aggregation buffers + |${aggCodeBlocks.fold(EmptyBlock)(_ + _)} """.stripMargin } @@ -380,11 +380,11 @@ case class HashAggregateExec( splitAggCode.map { updateAggValCode => s""" - | // do aggregate - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate functions and update aggregation buffers - | $updateAggValCode + |// do aggregate + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate functions and update aggregation buffers + |$updateAggValCode """.stripMargin }.getOrElse { nonSplitAggCode @@ -932,40 +932,41 @@ case class HashAggregateExec( val bufferOffsets = { val offsets = mutable.ArrayBuffer[Int]() var curOffset = 0 - updateExprs.foreach { exprsInAgg => + updateExprs.foreach { exprsForOneFunc => offsets += curOffset - curOffset += 
exprsInAgg.length + curOffset += exprsForOneFunc.length } offsets.toArray } val updateRowInRegularHashMap: String = { ctx.INPUT_ROW = unsafeRowBuffer - val boundUpdateExprs = updateExprs.map { updateExprsInAgg => - bindReferences(updateExprsInAgg, inputAttr) + val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => + bindReferences(updateExprsForOneFunc, inputAttr) } val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") - val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsInAgg => + val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExprsInAgg.map(_.genCode(ctx)) + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } } - val aggCodeBlocks = unsafeRowBufferEvals.zipWithIndex.map { case (rowBufferEvalsInAgg, i) => - val boundUpdateExprsInAgg = boundUpdateExprs(i) + val aggCodeBlocks = unsafeRowBufferEvals.zipWithIndex + .map { case (rowBufferEvalsForOneFunc, i) => + val boundUpdateExprsForOneFunc = boundUpdateExprs(i) val bufferOffset = bufferOffsets(i) // All the update code for aggregation buffers should be placed in the end // of each aggregation function code. - val updateRowBuffers = rowBufferEvalsInAgg.zipWithIndex.map { case (ev, j) => - val updateExpr = boundUpdateExprsInAgg(j) + val updateRowBuffers = rowBufferEvalsForOneFunc.zipWithIndex.map { case (ev, j) => + val updateExpr = boundUpdateExprsForOneFunc(j) val dt = updateExpr.dataType val nullable = updateExpr.nullable CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable) } code""" |// evaluate aggregate function for ${aggNames(i)} - |${evaluateVariables(rowBufferEvalsInAgg)} + |${evaluateVariables(rowBufferEvalsForOneFunc)} |// update unsafe row buffer |${updateRowBuffers.mkString("\n").trim} """.stripMargin @@ -975,8 +976,8 @@ case class HashAggregateExec( s""" |// common sub-expressions |$effectiveCodes - | // evaluate aggregate functions and update aggregation buffers - | ${aggCodeBlocks.fold(EmptyBlock)(_ + _)} + |// evaluate aggregate functions and update aggregation buffers + |${aggCodeBlocks.fold(EmptyBlock)(_ + _)} """.stripMargin } @@ -987,11 +988,11 @@ case class HashAggregateExec( splitAggCode.map { updateAggValCode => s""" - | // do aggregate - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate functions and update aggregation buffers - | $updateAggValCode + |// do aggregate + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate functions and update aggregation buffers + |$updateAggValCode """.stripMargin }.getOrElse { nonSplitAggCode @@ -1005,34 +1006,34 @@ case class HashAggregateExec( if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer - val boundUpdateExprs = updateExprs.map { updateExprsInAgg => - updateExprsInAgg.map(BindReferences.bindReference(_, inputAttr)) + val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => + bindReferences(updateExprsForOneFunc, inputAttr) } val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") - val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsInAgg => + val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExprsInAgg.map(_.genCode(ctx)) + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } } - val 
aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsInAgg, i) => - val boundUpdateExprsInAgg = boundUpdateExprs(i) + val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsForOneFunc, i) => + val boundUpdateExprsForOneFunc = boundUpdateExprs(i) val bufferOffset = bufferOffsets(i) // All the update code for aggregation buffers should be placed in the end // of each aggregation function code. - val updateRowBuffer = fastRowEvalsInAgg.zipWithIndex.map { case (ev, j) => - val updateExpr = boundUpdateExprsInAgg(j) + val updateRowBuffer = fastRowEvalsForOneFunc.zipWithIndex.map { case (ev, j) => + val updateExpr = boundUpdateExprsForOneFunc(j) val dt = updateExpr.dataType val nullable = updateExpr.nullable CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, ev, nullable, isVectorized = true) } code""" - | // evaluate aggregate function for ${aggNames(i)} - | ${evaluateVariables(fastRowEvalsInAgg)} - | // update fast row - | ${updateRowBuffer.mkString("\n").trim} + |// evaluate aggregate function for ${aggNames(i)} + |${evaluateVariables(fastRowEvalsForOneFunc)} + |// update fast row + |${updateRowBuffer.mkString("\n").trim} """.stripMargin } From 2100a2e2a4cd714c7245ed677a512f2194edbfe1 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 29 Aug 2019 13:46:14 +0900 Subject: [PATCH 08/20] Drop unnecessary ones from input variables --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2ff11f7bf6ac..4c1bfcfdf7f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1639,8 +1639,6 @@ object CodeGenerator extends Logging { val SubExprEliminationState(isNull, value) = subExprs(e) collectLocalVariable(value) collectLocalVariable(isNull) - // Since the children possibly have common subexprs, we push them here - stack.pushAll(e.children) case ref: BoundReference if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => From 511594a70dedbf1c60ec8c0b12b055ba8ac556ad Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 29 Aug 2019 16:22:32 +0900 Subject: [PATCH 09/20] Brush up tests --- .../aggregate/HashAggregateExec.scala | 9 ++++- .../execution/WholeStageCodegenSuite.scala | 39 +++++++------------ 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f2e8eecb1837..a6c527275899 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -260,6 +260,13 @@ case class HashAggregateExec( """.stripMargin } + private def isValidParamLength(paramLength: Int): Boolean = { + sqlContext.getConf("spark.sql.HashAggregateExec.isValidParamLength", null) match { + case null | "" => CodeGenerator.isValidParamLength(paramLength) + case validLength => paramLength <= validLength.toInt + } + } + // Splits aggregate code into small functions because the most of JVM implementations // can not compile too long 
functions.
   //
@@ -279,7 +286,7 @@ case class HashAggregateExec(
       val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
 
       // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
-      if (CodeGenerator.isValidParamLength(paramLength)) {
+      if (isValidParamLength(paramLength)) {
         Some(inputVarsForOneFunc)
       } else {
         None
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 372848bb9ce5..eaad58844c19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -399,32 +399,23 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
       "LocalTableScanExec should be within a WholeStageCodegen domain.")
   }
 
-  test("Give up splitting aggregate code if a parameter length goes over the JVM limit") {
-    withSQLConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true") {
+  test("Give up splitting aggregate code if a parameter length goes over the limit") {
+    withSQLConf(
+      SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true",
+      SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1",
+      "spark.sql.HashAggregateExec.isValidParamLength" -> "0") {
       withTable("t") {
-        val numCols = 45
-        val colExprs = "id AS key" +: (0 until numCols).map { i => s"id AS _c$i" }
-        spark.range(3).selectExpr(colExprs: _*).write.saveAsTable("t")
-
-        // Defines many common subexpressions for a parameter length
-        // to go over the JVM limit.
-        val aggExprs = (2 until numCols).map { i =>
-          (0 until i).map(d => s"_c$d").mkString("SUM(", " + ", ")")
-        }
-
-        // Test case without keys
-        var cause = intercept[Exception] {
-          sql(s"SELECT ${aggExprs.mkString(", ")} FROM t").collect
-        }.getCause
-        assert(cause.isInstanceOf[IllegalStateException])
-        assert(cause.getMessage.contains("Failed to split aggregate code into small functions"))
-
-        // Tet case with keys
-        cause = intercept[Exception] {
-          sql(s"SELECT key, ${aggExprs.mkString(", ")} FROM t GROUP BY key").collect
-        }.getCause
-        assert(cause.isInstanceOf[IllegalStateException])
-        assert(cause.getMessage.contains("Failed to split aggregate code into small functions"))
+        val expectedErrMsg = "Failed to split aggregate code into small functions"
+        Seq(
+          // Test case without keys
+          "SELECT AVG(v) FROM VALUES(1) t(v)",
+          // Test case with keys
+          "SELECT k, AVG(v) FROM VALUES((1, 1)) t(k, v) GROUP BY k").foreach { query =>
+          val errMsg = intercept[IllegalStateException] {
+            sql(query).collect
+          }.getMessage
+          assert(errMsg.contains(expectedErrMsg))
+        }
       }
     }
   }

From 7ced1286815b726a015cc0993b95da375915926c Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Fri, 30 Aug 2019 07:44:26 +0900
Subject: [PATCH 10/20] Fix a test error

---
 .../main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 796043fff665..8ba25a7a4782 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -115,7 +115,8 @@ package object dsl {
     def getField(fieldName: String): UnresolvedExtractValue =
       UnresolvedExtractValue(expr, Literal(fieldName))
 
-    def cast(to:
DataType): Expression = Cast(expr, to) + def cast(to: DataType): Expression = + if (!expr.dataType.sameType(to)) Cast(expr, to) else expr def asc: SortOrder = SortOrder(expr, Ascending) def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty) From ed7cf41249e31c1a4d2a0462cea957e22bc9a356 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 30 Aug 2019 08:32:34 +0900 Subject: [PATCH 11/20] Address reviews --- .../aggregate/HashAggregateExec.scala | 48 ++++++++++--------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a6c527275899..0b11f40cb9fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -176,7 +176,8 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer. Only used for aggregation without keys. + // The variables are used as aggregation buffers and each aggregate function has one more ExprCode + // to initialize its buffer slots. Only used for aggregation without keys. private var bufVars: Seq[Seq[ExprCode]] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { @@ -298,17 +299,17 @@ case class HashAggregateExec( // we totally give up splitting aggregate code. if (inputVars.forall(_.isDefined)) { val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) => - val doAggVal = ctx.freshName(s"doAggregateVal_${aggNames(i)}") + val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}") val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ") - val doAggValFuncName = ctx.addNewFunction(doAggVal, + val doAggFuncName = ctx.addNewFunction(doAggFunc, s""" - |private void $doAggVal($argList) throws java.io.IOException { + |private void $doAggFunc($argList) throws java.io.IOException { | ${aggCodeBlocks(i)} |} """.stripMargin) val inputVariables = args.map(_.variableName).mkString(", ") - s"$doAggValFuncName($inputVariables);" + s"$doAggFuncName($inputVariables);" } Some(splitCodes.mkString("\n").trim) } else { @@ -344,18 +345,18 @@ case class HashAggregateExec( } val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val effectiveCodes = subExprs.codes.mkString("\n") - val aggVals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => + val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } } val aggNames = functions.map(_.prettyName) - val aggCodeBlocks = aggVals.zipWithIndex.map { case (aggValsForOneFunc, i) => + val aggCodeBlocks = bufferEvals.zipWithIndex.map { case (bufferEvalsForOneFunc, i) => val bufVarsForOneFunc = bufVars(i) // All the update code for aggregation buffers should be placed in the end // of each aggregation function code. 
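
Note: each buffer slot mentioned in the comment above is carried as an (isNull, value) pair of mutable locals rather than as a boxed nullable value, which is why every update in the blocks below is the two-assignment sequence `bufVar.isNull = ...; bufVar.value = ...;`. A toy model:

    final case class Slot(var isNull: Boolean, var value: Long)

    val bufVar = Slot(isNull = true, value = 0L)  // initial state: SQL NULL
    val ev = Slot(isNull = false, value = 42L)    // result of one update expression
    bufVar.isNull = ev.isNull
    bufVar.value = ev.value
    assert(!bufVar.isNull && bufVar.value == 42L)
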
- val updates = aggValsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) => + val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) => s""" |${bufVar.isNull} = ${ev.isNull}; |${bufVar.value} = ${ev.value}; @@ -364,7 +365,7 @@ case class HashAggregateExec( code""" |// do aggregate for ${aggNames(i)} |// evaluate aggregate function - |${evaluateVariables(aggValsForOneFunc)} + |${evaluateVariables(bufferEvalsForOneFunc)} |// update aggregation buffers |${updates.mkString("\n").trim} """.stripMargin @@ -385,13 +386,13 @@ case class HashAggregateExec( val splitAggCode = splitAggregateExpressions( ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) - splitAggCode.map { updateAggValCode => + splitAggCode.map { updateAggCode => s""" |// do aggregate |// common sub-expressions |$effectiveCodes |// evaluate aggregate functions and update aggregation buffers - |$updateAggValCode + |$updateAggCode """.stripMargin }.getOrElse { nonSplitAggCode @@ -854,8 +855,10 @@ case class HashAggregateExec( val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer") val fastRowBuffer = ctx.freshName("fastAggBuffer") - // only have DeclarativeAggregate + // To individually generate code for each aggregate function, an element in `updateExprs` holds + // all the expressions for the buffer of an aggregation function. val updateExprs = aggregateExpressions.map { e => + // only have DeclarativeAggregate e.mode match { case Partial | Complete => e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions @@ -934,9 +937,9 @@ case class HashAggregateExec( ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName) - // Computes buffer offsets for each aggregation function code + // Computes start offsets for each aggregation function code // in the underlying buffer row. - val bufferOffsets = { + val bufferStartOffsets = { val offsets = mutable.ArrayBuffer[Int]() var curOffset = 0 updateExprs.foreach { exprsForOneFunc => @@ -959,10 +962,11 @@ case class HashAggregateExec( } } - val aggCodeBlocks = unsafeRowBufferEvals.zipWithIndex - .map { case (rowBufferEvalsForOneFunc, i) => + val aggCodeBlocks = updateExprs.indices.map { i => + val rowBufferEvalsForOneFunc = unsafeRowBufferEvals(i) val boundUpdateExprsForOneFunc = boundUpdateExprs(i) - val bufferOffset = bufferOffsets(i) + val bufferOffset = bufferStartOffsets(i) + // All the update code for aggregation buffers should be placed in the end // of each aggregation function code. val updateRowBuffers = rowBufferEvalsForOneFunc.zipWithIndex.map { case (ev, j) => @@ -993,13 +997,13 @@ case class HashAggregateExec( val splitAggCode = splitAggregateExpressions( ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) - splitAggCode.map { updateAggValCode => + splitAggCode.map { updateAggCode => s""" |// do aggregate |// common sub-expressions |$effectiveCodes |// evaluate aggregate functions and update aggregation buffers - |$updateAggValCode + |$updateAggCode """.stripMargin }.getOrElse { nonSplitAggCode @@ -1026,7 +1030,7 @@ case class HashAggregateExec( val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsForOneFunc, i) => val boundUpdateExprsForOneFunc = boundUpdateExprs(i) - val bufferOffset = bufferOffsets(i) + val bufferOffset = bufferStartOffsets(i) // All the update code for aggregation buffers should be placed in the end // of each aggregation function code. 
val updateRowBuffer = fastRowEvalsForOneFunc.zipWithIndex.map { case (ev, j) => @@ -1065,7 +1069,7 @@ case class HashAggregateExec( val splitAggCode = splitAggregateExpressions( ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) - splitAggCode.map { updateAggValCode => + splitAggCode.map { updateAggCode => // If fast hash map is on, we first generate code to update row in fast hash map, if // the previous loop up hit fast hash map. Otherwise, update row in regular hash map. s""" @@ -1073,7 +1077,7 @@ case class HashAggregateExec( | // common sub-expressions | $effectiveCodes | // evaluate aggregate functions and update aggregation buffers - | $updateAggValCode + | $updateAggCode |} else { | $updateRowInRegularHashMap |} From df20f3e3fb38100d95514248e2fc5fcc073d4ea4 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 30 Aug 2019 10:01:49 +0900 Subject: [PATCH 12/20] Fix errors --- .../org/apache/spark/sql/catalyst/dsl/package.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 8ba25a7a4782..d37d81753f0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -115,8 +115,13 @@ package object dsl { def getField(fieldName: String): UnresolvedExtractValue = UnresolvedExtractValue(expr, Literal(fieldName)) - def cast(to: DataType): Expression = - if (!expr.dataType.sameType(to)) Cast(expr, to) else expr + def cast(to: DataType): Expression = { + if (expr.resolved && expr.dataType.sameType(to)) { + expr + } else { + Cast(expr, to) + } + } def asc: SortOrder = SortOrder(expr, Ascending) def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty) From 649db0a0f1070f38041c5559dbc7f7eeda79e132 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 30 Aug 2019 23:53:02 +0900 Subject: [PATCH 13/20] Fix a bug --- .../sql/catalyst/expressions/nullExpressions.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 293d28e93039..f54d5f167856 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -354,12 +354,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = eval.isNull match { - case TrueLiteral => FalseLiteral - case FalseLiteral => TrueLiteral - case v => JavaCode.isNullExpression(s"!$v") + val (value, newCode) = eval.isNull match { + case TrueLiteral => (FalseLiteral, EmptyBlock) + case FalseLiteral => (TrueLiteral, EmptyBlock) + case v => + val value = ctx.freshName("value") + (JavaCode.variable(value, BooleanType), code"boolean $value = !$v;") } - ExprCode(code = eval.code, isNull = FalseLiteral, value = value) + ExprCode(code = eval.code + newCode, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NOT NULL)" From e7d40e32c5cbfead9e47b709011f5127b16a85cb Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro 
From 649db0a0f1070f38041c5559dbc7f7eeda79e132 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Fri, 30 Aug 2019 23:53:02 +0900
Subject: [PATCH 13/20] Fix a bug

---
 .../sql/catalyst/expressions/nullExpressions.scala | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 293d28e93039..f54d5f167856 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -354,12 +354,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
-    val value = eval.isNull match {
-      case TrueLiteral => FalseLiteral
-      case FalseLiteral => TrueLiteral
-      case v => JavaCode.isNullExpression(s"!$v")
+    val (value, newCode) = eval.isNull match {
+      case TrueLiteral => (FalseLiteral, EmptyBlock)
+      case FalseLiteral => (TrueLiteral, EmptyBlock)
+      case v =>
+        val value = ctx.freshName("value")
+        (JavaCode.variable(value, BooleanType), code"boolean $value = !$v;")
     }
-    ExprCode(code = eval.code, isNull = FalseLiteral, value = value)
+    ExprCode(code = eval.code + newCode, isNull = FalseLiteral, value = value)
   }

   override def sql: String = s"(${child.sql} IS NOT NULL)"
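The bug fixed here is that the non-literal branch used to return `JavaCode.isNullExpression(s"!$v")`, an inline Java expression with no backing local variable; such a value has no name that a split aggregate function could take as a parameter. A hedged, self-contained sketch of the materialization idiom the fix applies, assuming Spark's internal codegen APIs as used elsewhere in this series (the helper name is illustrative):

    import org.apache.spark.sql.catalyst.expressions.codegen._
    import org.apache.spark.sql.catalyst.expressions.codegen.Block._
    import org.apache.spark.sql.types.BooleanType

    // Instead of an unnamed inline expression ("!" + v), emit a named boolean local,
    // so later stages can collect it as an input variable and pass it by name.
    def notAsNamedLocal(ctx: CodegenContext, v: ExprValue): (ExprValue, Block) = {
      val name = ctx.freshName("value")
      (JavaCode.variable(name, BooleanType), code"boolean $name = !$v;")
    }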
From e7d40e32c5cbfead9e47b709011f5127b16a85cb Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Sat, 31 Aug 2019 07:35:30 +0900
Subject: [PATCH 14/20] Address comments

---
 .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +-
 .../spark/sql/execution/aggregate/HashAggregateExec.scala      | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4c425f25f488..09ac23711739 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1050,7 +1050,7 @@ object SQLConf {
   val CODEGEN_SPLIT_AGGREGATE_FUNC =
     buildConf("spark.sql.codegen.aggregate.splitAggregateFunc.enabled")
       .internal()
-      .doc("When true, the code generator would aggregate code into individual methods " +
+      .doc("When true, the code generator would split aggregate code into individual methods " +
         "instead of a single big method. This can be used to avoid oversized function that " +
         "can miss the opportunity of JIT optimization.")
       .booleanConf
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 0b11f40cb9fa..1c5342ce6b07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -262,6 +262,7 @@ case class HashAggregateExec(
   }

   private def isValidParamLength(paramLength: Int): Boolean = {
+    // This config is only for testing
     sqlContext.getConf("spark.sql.HashAggregateExec.isValidParamLength", null) match {
       case null | "" => CodeGenerator.isValidParamLength(paramLength)
       case validLength => paramLength <= validLength.toInt
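Since `spark.sql.HashAggregateExec.isValidParamLength` is a test-only escape hatch, the intended pattern is to force the limit down so the failure path can be exercised deterministically. A hedged sketch modeled on the WholeStageCodegenSuite change later in this series (the table name and query are illustrative):

    withSQLConf(
        SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true",
        SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1",
        "spark.sql.HashAggregateExec.isValidParamLength" -> "0") {
      // With the valid length forced to 0, every split attempt fails the check,
      // so splitAggregateExpressions reports the JVM-limit error in testing mode.
      sql("SELECT k, SUM(v), AVG(v) FROM t GROUP BY k").collect()
    }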
From e47b16fe5e1347b3598934669262fb293ccdaca6 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Sat, 31 Aug 2019 08:41:18 +0900
Subject: [PATCH 15/20] Fix

---
 .../aggregate/HashAggregateExec.scala | 79 ++++++++++---------
 1 file changed, 43 insertions(+), 36 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 1c5342ce6b07..bdb173d4d45a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -282,46 +282,53 @@ case class HashAggregateExec(
       aggBufferUpdatingExprs: Seq[Seq[Expression]],
       aggCodeBlocks: Seq[Block],
       subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
-    val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
-      val inputVarsForOneFunc = aggExprsForOneFunc.map(
-        CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
-      val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
-
-      // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
-      if (isValidParamLength(paramLength)) {
-        Some(inputVarsForOneFunc)
-      } else {
-        None
+    val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
+    if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
+      // `SimpleExprValue`s cannot be used as input variables for split functions, so
+      // we give up splitting functions if any exist in `subExprs`.
+      None
+    } else {
+      val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
+        val inputVarsForOneFunc = aggExprsForOneFunc.map(
+          CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
+        val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
+
+        // Checks whether the parameter length for `aggExprsForOneFunc` stays within the JVM limit
+        if (isValidParamLength(paramLength)) {
+          Some(inputVarsForOneFunc)
+        } else {
+          None
+        }
      }
-    }

-    // Checks if all the aggregate code can be split into pieces.
-    // If the parameter length of at lease one `aggExprsForOneFunc` goes over the limit,
-    // we totally give up splitting aggregate code.
-    if (inputVars.forall(_.isDefined)) {
-      val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
-        val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
-        val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ")
-        val doAggFuncName = ctx.addNewFunction(doAggFunc,
-          s"""
-             |private void $doAggFunc($argList) throws java.io.IOException {
-             |  ${aggCodeBlocks(i)}
-             |}
-           """.stripMargin)
+      // Checks if all the aggregate code can be split into pieces.
+      // If the parameter length of at least one `aggExprsForOneFunc` goes over the limit,
+      // we totally give up splitting aggregate code.
+      if (inputVars.forall(_.isDefined)) {
+        val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
+          val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
+          val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ")
+          val doAggFuncName = ctx.addNewFunction(doAggFunc,
+            s"""
+               |private void $doAggFunc($argList) throws java.io.IOException {
+               |  ${aggCodeBlocks(i)}
+               |}
+             """.stripMargin)

-        val inputVariables = args.map(_.variableName).mkString(", ")
-        s"$doAggFuncName($inputVariables);"
-      }
-      Some(splitCodes.mkString("\n").trim)
-    } else {
-      val errMsg = "Failed to split aggregate code into small functions because the parameter " +
-        "length of at least one split function went over the JVM limit: " +
-        CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
-      if (Utils.isTesting) {
-        throw new IllegalStateException(errMsg)
+          val inputVariables = args.map(_.variableName).mkString(", ")
+          s"$doAggFuncName($inputVariables);"
+        }
+        Some(splitCodes.mkString("\n").trim)
       } else {
-        logInfo(errMsg)
-        None
+        val errMsg = "Failed to split aggregate code into small functions because the parameter " +
+          "length of at least one split function went over the JVM limit: " +
+          CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
+        if (Utils.isTesting) {
+          throw new IllegalStateException(errMsg)
+        } else {
+          logInfo(errMsg)
+          None
+        }
      }
    }
  }
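For reference, a hedged sketch of the shape of one function emitted by the `ctx.addNewFunction` call above for a SUM aggregate. The names `doAggregate_sum`, `agg_exprIsNull_0`, `agg_expr_0`, and `agg_bufValue_0`, as well as the argument order and buffer handling, stand in for whatever fresh names codegen would actually pick:

    // Illustrative generated Java, assembled the same way as `doAggFuncName` above:
    val exampleSplitFunc =
      """
        |private void doAggregate_sum(boolean agg_exprIsNull_0, long agg_expr_0) throws java.io.IOException {
        |  // evaluate aggregate function and update the function's buffer slot
        |  if (!agg_exprIsNull_0) { agg_bufValue_0 = agg_bufValue_0 + agg_expr_0; }
        |}
      """.stripMargin
    // ...and the call site spliced into the consume path would look like:
    //   doAggregate_sum(agg_exprIsNull_0, agg_expr_0);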
s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index eaad58844c19..d8727d5b584f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -403,7 +403,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession { withSQLConf( SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true", SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", - "spark.sql.HashAggregateExec.isValidParamLength" -> "0") { + "spark.sql.HashAggregateExec.validParamLength" -> "0") { withTable("t") { val expectedErrMsg = "Failed to split aggregate code into small functions" Seq( From d1e06a600d262469542ac671bf4ae85a4da32706 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 4 Sep 2019 12:57:03 +0900 Subject: [PATCH 17/20] Address comments --- .../aggregate/HashAggregateExec.scala | 116 +++++++----------- 1 file changed, 41 insertions(+), 75 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 3375c79669bc..5a643c55f298 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -379,35 +379,25 @@ case class HashAggregateExec( """.stripMargin } - lazy val nonSplitAggCode = { - s""" - |// do aggregate - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate functions and update aggregation buffers - |${aggCodeBlocks.fold(EmptyBlock)(_ + _)} - """.stripMargin - } - - if (conf.codegenSplitAggregateFunc && + val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc && aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { val maybeSplitCode = splitAggregateExpressions( ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) - maybeSplitCode.map { updateAggCode => - s""" - |// do aggregate - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate functions and update aggregation buffers - |$updateAggCode - """.stripMargin - }.getOrElse { - nonSplitAggCode + maybeSplitCode.getOrElse { + aggCodeBlocks.fold(EmptyBlock)(_ + _).code } } else { - nonSplitAggCode + aggCodeBlocks.fold(EmptyBlock)(_ + _).code } + + s""" + |// do aggregate + |// common sub-expressions + |$effectiveCodes + |// evaluate aggregate functions and update aggregation buffers + |$codeToEvalAggFunc + """.stripMargin } private val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -991,34 +981,24 @@ case class HashAggregateExec( """.stripMargin } - lazy val nonSplitAggCode = { - s""" - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate functions and update aggregation buffers - |${aggCodeBlocks.fold(EmptyBlock)(_ + _)} - """.stripMargin - } - - if (conf.codegenSplitAggregateFunc && + val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc && aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { val maybeSplitCode = splitAggregateExpressions( ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) - maybeSplitCode.map { updateAggCode => - s""" - |// do aggregate - |// common sub-expressions - |$effectiveCodes - |// evaluate aggregate functions and update aggregation buffers - 
From d1e06a600d262469542ac671bf4ae85a4da32706 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Wed, 4 Sep 2019 12:57:03 +0900
Subject: [PATCH 17/20] Address comments

---
 .../aggregate/HashAggregateExec.scala | 116 +++++++-----------
 1 file changed, 41 insertions(+), 75 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 3375c79669bc..5a643c55f298 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -379,35 +379,25 @@ case class HashAggregateExec(
       """.stripMargin
     }

-    lazy val nonSplitAggCode = {
-      s"""
-         |// do aggregate
-         |// common sub-expressions
-         |$effectiveCodes
-         |// evaluate aggregate functions and update aggregation buffers
-         |${aggCodeBlocks.fold(EmptyBlock)(_ + _)}
-       """.stripMargin
-    }
-
-    if (conf.codegenSplitAggregateFunc &&
+    val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
         aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
       val maybeSplitCode = splitAggregateExpressions(
         ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)

-      maybeSplitCode.map { updateAggCode =>
-        s"""
-           |// do aggregate
-           |// common sub-expressions
-           |$effectiveCodes
-           |// evaluate aggregate functions and update aggregation buffers
-           |$updateAggCode
-         """.stripMargin
-      }.getOrElse {
-        nonSplitAggCode
+      maybeSplitCode.getOrElse {
+        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
       }
     } else {
-      nonSplitAggCode
+      aggCodeBlocks.fold(EmptyBlock)(_ + _).code
     }
+
+    s"""
+       |// do aggregate
+       |// common sub-expressions
+       |$effectiveCodes
+       |// evaluate aggregate functions and update aggregation buffers
+       |$codeToEvalAggFunc
+     """.stripMargin
   }

   private val groupingAttributes = groupingExpressions.map(_.toAttribute)
@@ -991,34 +981,24 @@ case class HashAggregateExec(
       """.stripMargin
     }

-    lazy val nonSplitAggCode = {
-      s"""
-         |// common sub-expressions
-         |$effectiveCodes
-         |// evaluate aggregate functions and update aggregation buffers
-         |${aggCodeBlocks.fold(EmptyBlock)(_ + _)}
-       """.stripMargin
-    }
-
-    if (conf.codegenSplitAggregateFunc &&
+    val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
         aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
       val maybeSplitCode = splitAggregateExpressions(
         ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)

-      maybeSplitCode.map { updateAggCode =>
-        s"""
-           |// do aggregate
-           |// common sub-expressions
-           |$effectiveCodes
-           |// evaluate aggregate functions and update aggregation buffers
-           |$updateAggCode
-         """.stripMargin
-      }.getOrElse {
-        nonSplitAggCode
+      maybeSplitCode.getOrElse {
+        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
       }
     } else {
-      nonSplitAggCode
+      aggCodeBlocks.fold(EmptyBlock)(_ + _).code
     }
+
+    s"""
+       |// common sub-expressions
+       |$effectiveCodes
+       |// evaluate aggregate functions and update aggregation buffers
+       |$codeToEvalAggFunc
+     """.stripMargin
   }

   val updateRowInHashMap: String = {
@@ -1056,46 +1036,32 @@ case class HashAggregateExec(
       """.stripMargin
     }

-      lazy val nonSplitAggCode = {
-        // If vectorized fast hash map is on, we first generate code to update row
-        // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map.
-        // Otherwise, update row in regular hash map.
-        s"""
-           |if ($fastRowBuffer != null) {
-           |  // common sub-expressions
-           |  $effectiveCodes
-           |  // evaluate aggregate functions and update aggregation buffers
-           |  ${aggCodeBlocks.fold(EmptyBlock)(_ + _)}
-           |} else {
-           |  $updateRowInRegularHashMap
-           |}
-         """.stripMargin
-      }

-      if (conf.codegenSplitAggregateFunc &&
+      val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
           aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
         val maybeSplitCode = splitAggregateExpressions(
           ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)

-        maybeSplitCode.map { updateAggCode =>
-          // If fast hash map is on, we first generate code to update row in fast hash map, if
-          // the previous loop up hit fast hash map. Otherwise, update row in regular hash map.
-          s"""
-             |if ($fastRowBuffer != null) {
-             |  // common sub-expressions
-             |  $effectiveCodes
-             |  // evaluate aggregate functions and update aggregation buffers
-             |  $updateAggCode
-             |} else {
-             |  $updateRowInRegularHashMap
-             |}
-           """.stripMargin
-        }.getOrElse {
-          nonSplitAggCode
+        maybeSplitCode.getOrElse {
+          aggCodeBlocks.fold(EmptyBlock)(_ + _).code
        }
      } else {
-        nonSplitAggCode
+        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
      }
+
+      // If vectorized fast hash map is on, we first generate code to update row
+      // in vectorized fast hash map, if the previous lookup hit vectorized fast hash map.
+      // Otherwise, update row in regular hash map.
+      s"""
+         |if ($fastRowBuffer != null) {
+         |  // common sub-expressions
+         |  $effectiveCodes
+         |  // evaluate aggregate functions and update aggregation buffers
+         |  $codeToEvalAggFunc
+         |} else {
+         |  $updateRowInRegularHashMap
+         |}
+       """.stripMargin
     } else {
       // If row-based hash map is on and the previous loop up hit fast hash map,
       // we reuse regular hash buffer to update row of fast hash map.
From 50889391bdac1f18085f33b588af7d8faf2af922 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 5 Sep 2019 21:41:22 +0900
Subject: [PATCH 18/20] Fix style issue

---
 .../aggregate/HashAggregateExec.scala | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 5a643c55f298..9242583d3671 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -1021,19 +1021,19 @@ case class HashAggregateExec(
       val bufferOffset = bufferStartOffsets(i)
       // All the update code for aggregation buffers should be placed in the end
       // of each aggregation function code.
-        val updateRowBuffer = fastRowEvalsForOneFunc.zipWithIndex.map { case (ev, j) =>
-          val updateExpr = boundUpdateExprsForOneFunc(j)
-          val dt = updateExpr.dataType
-          val nullable = updateExpr.nullable
-          CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, ev, nullable,
-            isVectorized = true)
-        }
-        code"""
-          |// evaluate aggregate function for ${aggNames(i)}
-          |${evaluateVariables(fastRowEvalsForOneFunc)}
-          |// update fast row
-          |${updateRowBuffer.mkString("\n").trim}
-         """.stripMargin
+      val updateRowBuffer = fastRowEvalsForOneFunc.zipWithIndex.map { case (ev, j) =>
+        val updateExpr = boundUpdateExprsForOneFunc(j)
+        val dt = updateExpr.dataType
+        val nullable = updateExpr.nullable
+        CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, ev, nullable,
+          isVectorized = true)
+      }
+      code"""
+        |// evaluate aggregate function for ${aggNames(i)}
+        |${evaluateVariables(fastRowEvalsForOneFunc)}
+        |// update fast row
+        |${updateRowBuffer.mkString("\n").trim}
+       """.stripMargin
     }

From 20b310f7dcb4b39c8c0b49dace545f34710c4727 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 5 Sep 2019 21:42:30 +0900
Subject: [PATCH 19/20] Turn off the split mode temporarily for Jenkins tests

---
 .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 09ac23711739..c353faa20a16 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1054,7 +1054,7 @@ object SQLConf {
         "instead of a single big method. This can be used to avoid oversized function that " +
         "can miss the opportunity of JIT optimization.")
       .booleanConf
-      .createWithDefault(true)
+      .createWithDefault(false)

   val MAX_NESTED_VIEW_DEPTH =
     buildConf("spark.sql.view.maxNestedViewDepth")

From ab8de8328599832b2f535c124aa5cb55a006277e Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Fri, 6 Sep 2019 06:10:22 +0900
Subject: [PATCH 20/20] Revert "Turn off the split mode temporarily for Jenkins tests"

This reverts commit 20b310f7dcb4b39c8c0b49dace545f34710c4727.

---
 .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c353faa20a16..09ac23711739 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1054,7 +1054,7 @@ object SQLConf {
         "instead of a single big method. This can be used to avoid oversized function that " +
         "can miss the opportunity of JIT optimization.")
       .booleanConf
-      .createWithDefault(false)
+      .createWithDefault(true)

   val MAX_NESTED_VIEW_DEPTH =
     buildConf("spark.sql.view.maxNestedViewDepth")
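Taken together, the series leaves the split mode enabled by default (`createWithDefault(true)` after the revert). A hedged end-to-end usage sketch; the `spark` session and the query are illustrative, and any query compiling several aggregate functions past the method-split threshold would exercise the split path:

    import org.apache.spark.sql.functions.{avg, max, min, sum}

    // assumes an existing SparkSession named `spark`
    spark.conf.set("spark.sql.codegen.aggregate.splitAggregateFunc.enabled", "true")
    spark.range(0, 1000)
      .selectExpr("id % 10 AS k", "id AS v")
      .groupBy("k")
      .agg(sum("v"), avg("v"), max("v"), min("v")) // several functions => several doAggregate_* methods
      .collect()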