[SPARK-31973][SQL] Skip partial aggregates if grouping keys have high cardinality #28804
@@ -2196,6 +2196,13 @@ object SQLConf {
    .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].")
    .createWithDefault(16)

  val SKIP_PARTIAL_AGGREGATE_ENABLED =
    buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
      .internal()
      .doc("Avoid sort/spill to disk during partial aggregation")
      .booleanConf
      .createWithDefault(true)
Member
Could we use a threshold + column stats instead of this boolean config?

Contributor (Author)
I didn't get the threshold part. Can you please elaborate?

Member
I meant the ratio of the distinct row count to the total row count in the group-by key's column stats. For example, if the number of distinct key values is close to the total row count, partial aggregation is unlikely to reduce the data.

Contributor (Author)
Thanks @maropu for explaining, I will make this change.

@maropu This is a very useful suggestion. One issue is that column stats are rarely computed. We came across this work in Hive: https://issues.apache.org/jira/browse/HIVE-291. Hive turns off map-side aggregation (i.e., the partial aggregate becomes a pass-through) in the physical Group-By operator if map-side aggregation does not reduce the entries by at least half, and it looks at the first 100000 rows to decide (ref: patch https://issues.apache.org/jira/secure/attachment/12400257/291.1.txt). Should we do something similar in HashAggregateExec here? Any thoughts on this?

Member
I think whether that approach improves performance depends on IO performance, but the idea looks interesting to me. WDYT? @cloud-fan
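For what it's worth, the two ideas discussed above (a planning-time ratio derived from column stats, and Hive's runtime check after a fixed number of input rows) could look roughly like the sketch below. This is illustration only; the helper names and thresholds (`distinctRatioThreshold`, `checkInterval`, `minReduction`) are hypothetical and not part of this PR.

```scala
// Illustrative sketch only -- not code from this PR. All thresholds and helper
// names (distinctRatioThreshold, checkInterval, minReduction) are hypothetical.

// (1) Planning-time heuristic from column stats: skip partial aggregation when the
// group-by keys are estimated to be nearly unique.
def skipByColumnStats(
    distinctCount: Option[BigInt],
    rowCount: Option[BigInt],
    distinctRatioThreshold: Double = 0.8): Boolean = {
  (distinctCount, rowCount) match {
    case (Some(ndv), Some(rows)) if rows > 0 =>
      ndv.toDouble / rows.toDouble > distinctRatioThreshold
    case _ => false // stats unavailable: keep partial aggregation
  }
}

// (2) Hive-style runtime heuristic (HIVE-291): after observing `checkInterval` input
// rows, turn the partial aggregate into a pass-through if the hash map did not
// shrink the data by at least `minReduction`.
def shouldFallBackToPassThrough(
    rowsSeen: Long,
    hashMapEntries: Long,
    checkInterval: Long = 100000L,
    minReduction: Double = 0.5): Boolean = {
  rowsSeen >= checkInterval && hashMapEntries.toDouble / rowsSeen.toDouble > minReduction
}
```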
  val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
    .doc("Compression codec used in writing of AVRO files. Supported codecs: " +
      "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.")

@@ -2922,6 +2929,8 @@ class SQLConf extends Serializable with Logging {
  def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)

  def skipPartialAggregate: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED)

  def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)

  def uiExplainMode: String = getConf(UI_EXPLAIN_MODE)
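As a usage note (not part of the diff): since the accessor above reads a regular SQL conf, the flag could be toggled per session. A minimal sketch, assuming a running `SparkSession` named `spark`:

```scala
// Minimal sketch (assumes a SparkSession `spark`); not code from this PR.
spark.conf.set("spark.sql.aggregate.partialaggregate.skip.enabled", "true")

// Equivalent SQL form:
spark.sql("SET spark.sql.aggregate.partialaggregate.skip.enabled=true")
```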
@@ -353,4 +353,8 @@ object AggUtils {
    finalAndCompleteAggregate :: Nil
  }

  def areAggExpressionsPartial(modes: Seq[AggregateMode]): Boolean = {
    modes.nonEmpty && modes.forall(_ == Partial)
  }
}
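For intuition, the helper above is expected to behave as in the sketch below (illustration only, not a test from this PR): it returns true only when every aggregate expression is still in Partial mode, the only phase where skipping partial aggregation is safe.

```scala
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial, PartialMerge}
import org.apache.spark.sql.execution.aggregate.AggUtils

// Illustration of the expected contract; not a test from this PR.
AggUtils.areAggExpressionsPartial(Seq(Partial, Partial))      // true
AggUtils.areAggExpressionsPartial(Seq(Partial, PartialMerge)) // false: merge side must still aggregate
AggUtils.areAggExpressionsPartial(Seq(Final))                 // false: final aggregation must not be skipped
AggUtils.areAggExpressionsPartial(Seq.empty)                  // false: nothing to decide
```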
@@ -63,6 +63,8 @@ case class HashAggregateExec(
  require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

  override def needStopCheck: Boolean = skipPartialAggregate

  override lazy val allAttributes: AttributeSeq =
    child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
      aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)

@@ -72,6 +74,8 @@ case class HashAggregateExec(
    "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
    "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
    "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"),
    "partialAggSkipped" -> SQLMetrics.createMetric(sparkContext,
      "number of skipped records for partial aggregates"),
    "avgHashProbe" ->
      SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"))

@@ -409,6 +413,11 @@ case class HashAggregateExec(
  private var fastHashMapTerm: String = _
  private var isFastHashMapEnabled: Boolean = false

  private var avoidSpillInPartialAggregateTerm: String = _
  private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate &&
    AggUtils.areAggExpressionsPartial(modes) && find(_.isInstanceOf[ExpandExec]).isEmpty

  private var outputFunc: String = _

  // whether a vectorized hashmap is used instead
  // we have decided to always use the row-based hashmap,
  // but the vectorized hashmap can still be switched on for testing and benchmarking purposes.

@@ -628,6 +637,8 @@ case class HashAggregateExec(
        |${consume(ctx, resultVars)}
       """.stripMargin
    }

    ctx.addNewFunction(funcName,
      s"""
         |private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
@@ -680,6 +691,10 @@ case class HashAggregateExec(
  private def doProduceWithKeys(ctx: CodegenContext): String = {
Member
Why did you apply this optimization only for the with-key case?

Contributor (Author)
There will be only one key for the map in the without-key case, and the optimization will not apply.

Member
Ah, I see.
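To make the with-key/without-key distinction concrete, here is an illustration (assumes a `SparkSession` named `spark`; not code from this PR): only grouped aggregation builds a per-key hash map whose size grows with key cardinality, so only that path can benefit from skipping partial aggregation.

```scala
// Illustration only (assumes a SparkSession `spark`); not code from this PR.

// With-key: partial aggregation keeps a hash map keyed by `k`; with ~10^6 distinct
// keys the map barely reduces the data, which is the case this PR targets.
val withKeys = spark.range(1000000L)
  .selectExpr("id % 999983 AS k", "id AS v")
  .groupBy("k")
  .sum("v")

// Without-key: there is a single aggregation buffer and no per-key map,
// so there is nothing to spill and the optimization does not apply.
val withoutKeys = spark.range(1000000L)
  .selectExpr("id AS v")
  .agg(Map("v" -> "sum"))
```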
    val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
    avoidSpillInPartialAggregateTerm = ctx.
      addMutableState(CodeGenerator.JAVA_BOOLEAN, "avoidPartialAggregate")
    val childrenConsumed = ctx.
      addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed")
    if (sqlContext.conf.enableTwoLevelAggMap) {
      enableTwoLevelHashMap(ctx)
    } else if (sqlContext.conf.enableVectorizedHashMap) {

@@ -750,18 +765,19 @@ case class HashAggregateExec(
      finishRegularHashMap
    }

    outputFunc = generateResultFunction(ctx)
    val doAggFuncName = ctx.addNewFunction(doAgg,
      s"""
         |private void $doAgg() throws java.io.IOException {
         |  ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
         |  $childrenConsumed = true;
         |  $finishHashMap
         |}
       """.stripMargin)

    // generate code for output
    val keyTerm = ctx.freshName("aggKey")
    val bufferTerm = ctx.freshName("aggBuffer")
    val outputFunc = generateResultFunction(ctx)
Member
Why did you move this line into line 771?
    def outputFromFastHashMap: String = {
      if (isFastHashMapEnabled) {

@@ -833,11 +849,18 @@ case class HashAggregateExec(
    s"""
       |if (!$initAgg) {
       |  $initAgg = true;
       |  $avoidSpillInPartialAggregateTerm =
       |    ${Utils.isTesting} && $skipPartialAggregate;
       |  $createFastHashMap
       |  $hashMapTerm = $thisPlan.createHashMap();
       |  long $beforeAgg = System.nanoTime();
       |  $doAggFuncName();
       |  $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
       |  $shouldStopCheckCode;
       |}
       |if (!$childrenConsumed) {
       |  $doAggFuncName();
       |  $shouldStopCheckCode;
       |}
       |// output the result
       |$outputFromFastHashMap
@@ -878,43 +901,51 @@ case class HashAggregateExec(
    }

    val oomeClassName = classOf[SparkOutOfMemoryError].getName

    val findOrInsertRegularHashMap: String =
      s"""
         |// generate grouping key
         |${unsafeRowKeyCode.code}
         |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
         |if ($checkFallbackForBytesToBytesMap) {
         |  // try to get the buffer from hash map
         |  $unsafeRowBuffer =
         |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash);
         |}
         |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
         |// aggregation after processing all input rows.
         |if ($unsafeRowBuffer == null) {
         |  if ($sorterTerm == null) {
         |    $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
         |  } else {
         |    $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
         |if (!$avoidSpillInPartialAggregateTerm) {
         |  // generate grouping key
         |  ${unsafeRowKeyCode.code}
         |  int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
         |  if ($checkFallbackForBytesToBytesMap) {
         |    // try to get the buffer from hash map
         |    $unsafeRowBuffer =
         |      $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash);
         |  }
         |  $resetCounter
         |  // the hash map had be spilled, it should have enough memory now,
         |  // try to allocate buffer again.
         |  $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
         |    $unsafeRowKeys, $unsafeRowKeyHash);
         |  if ($unsafeRowBuffer == null) {
         |    // failed to allocate the first page
         |    throw new $oomeClassName("No enough memory for aggregation");
         |  // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
         |  // aggregation after processing all input rows.
         |  if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) {
         |    // If sort/spill to disk is disabled, nothing is done.
         |    // Aggregation buffer is created later
         |    if ($skipPartialAggregate) {
         |      $avoidSpillInPartialAggregateTerm = true;
         |    } else {
         |      if ($sorterTerm == null) {
         |        $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
         |      } else {
         |        $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
         |      }
         |      $resetCounter
         |      // the hash map had be spilled, it should have enough memory now,
         |      // try to allocate buffer again.
         |      $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
         |        $unsafeRowKeys, $unsafeRowKeyHash);
         |      if ($unsafeRowBuffer == null) {
         |        // failed to allocate the first page
         |        throw new $oomeClassName("No enough memory for aggregation");
         |      }
         |    }
         |  }
         |}
       """.stripMargin
    val partTerm = metricTerm(ctx, "partialAggSkipped")
    val findOrInsertHashMap: String = {
      if (isFastHashMapEnabled) {
      val insertCode = if (isFastHashMapEnabled) {
        // If fast hash map is on, we first generate code to probe and update the fast hash map.
        // If the probe is successful the corresponding fast row buffer will hold the mutable row.
        s"""
           |if ($checkFallbackForGeneratedHashMap) {
           |if ($checkFallbackForGeneratedHashMap && !$avoidSpillInPartialAggregateTerm) {
           |  ${fastRowKeys.map(_.code).mkString("\n")}
           |  if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
           |    $fastRowBuffer = $fastHashMapTerm.findOrInsert(

@@ -929,6 +960,18 @@ case class HashAggregateExec(
      } else {
        findOrInsertRegularHashMap
      }
      val initExpr = declFunctions.flatMap(f => f.initialValues)
      val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr)
      s"""
         |$insertCode
         |// Create an empty aggregation buffer
         |if ($avoidSpillInPartialAggregateTerm) {
         |  ${unsafeRowKeyCode.code}
         |  ${emptyBufferKeyCode.code}
         |  $unsafeRowBuffer = ${emptyBufferKeyCode.value};
         |  $partTerm.add(1);
         |}
         |""".stripMargin
    }

    val inputAttr = aggregateBufferAttributes ++ inputAttributes

@@ -1005,7 +1048,7 @@ case class HashAggregateExec(
    }

    val updateRowInHashMap: String = {
      if (isFastHashMapEnabled) {
      val updateRowinMap = if (isFastHashMapEnabled) {
        if (isVectorizedHashMapEnabled) {
          ctx.INPUT_ROW = fastRowBuffer
          val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>

@@ -1080,6 +1123,12 @@ case class HashAggregateExec(
      } else {
        updateRowInRegularHashMap
      }
      s"""
         |$updateRowinMap
         |if ($avoidSpillInPartialAggregateTerm) {
         |  $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer);
         |}
         |""".stripMargin
    }

    val declareRowBuffer: String = if (isFastHashMapEnabled) {
So this only works for hash aggregate but not the sort-based aggregate?

I believe this heuristic can be applied to sort-based aggregation as well. I started with hash-based aggregation; I will create a new PR for sort-based aggregation.