Skip to content

Commit 9730b05

Browse files
liancheng authored and uzadude committed
[SPARK-18403][SQL] Fix unsafe data false sharing issue in ObjectHashAggregateExec
## What changes were proposed in this pull request?

This PR fixes a random OOM issue that occurred while running `ObjectHashAggregateSuite`.

This issue can be steadily reproduced under the following conditions:

1. The aggregation must be evaluated using `ObjectHashAggregateExec`;
2. There must be an input column whose data type involves `ArrayType` (an input column of `MapType` may even cause SIGSEGV);
3. Sort-based aggregation fallback must be triggered during evaluation.

The root cause is that while falling back to sort-based aggregation, we must sort and feed already evaluated partial aggregation buffers living in the hash map to the sort-based aggregator using an external sorter. However, the underlying mutable byte buffer of `UnsafeRow`s produced by the iterator of the external sorter is reused and may get overwritten when the iterator steps forward. After the last entry is consumed, the byte buffer points to a block of uninitialized memory filled by `5a`. Therefore, while reading an `UnsafeArrayData` out of the `UnsafeRow`, `5a5a5a5a` is treated as array size and triggers a memory allocation for a ridiculously large array and immediately blows up the JVM with an OOM.

To fix this issue, we only need to add `.copy()` accordingly.

## How was this patch tested?

New regression test case added in `ObjectHashAggregateSuite`.

Author: Cheng Lian <lian@databricks.com>

Closes apache#15976 from liancheng/investigate-oom.
1 parent 88521ad commit 9730b05

2 files changed

Lines changed: 101 additions & 74 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,9 @@ class SortBasedAggregator(
262262
// Firstly, update the aggregation buffer with input rows.
263263
while (hasNextInput &&
264264
groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
265-
processRow(result.aggregationBuffer, inputIterator.getValue)
265+
// Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
266+
// overwritten when `inputIterator` steps forward, we need to do a deep copy here.
267+
processRow(result.aggregationBuffer, inputIterator.getValue.copy())
266268
hasNextInput = inputIterator.next()
267269
}
268270

@@ -271,7 +273,12 @@ class SortBasedAggregator(
271273
// be called after calling processRow.
272274
while (hasNextAggBuffer &&
273275
groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
274-
mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
276+
mergeAggregationBuffers(
277+
result.aggregationBuffer,
278+
// Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
279+
// overwritten when `inputIterator` steps forward, we need to do a deep copy here.
280+
initialAggBufferIterator.getValue.copy()
281+
)
275282
hasNextAggBuffer = initialAggBufferIterator.next()
276283
}
277284

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala

Lines changed: 92 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -205,23 +205,19 @@ class ObjectHashAggregateSuite
205205
// A TypedImperativeAggregate function
206206
val typed = percentile_approx($"c0", 0.5)
207207

208-
// A Hive UDAF without partial aggregation support
209-
val withoutPartial = function("hive_max", $"c1")
210-
211208
// A Spark SQL native aggregate function with partial aggregation support that can be executed
212209
// by the Tungsten `HashAggregateExec`
213-
val withPartialUnsafe = max($"c2")
210+
val withPartialUnsafe = max($"c1")
214211

215212
// A Spark SQL native aggregate function with partial aggregation support that can only be
216213
// executed by the Tungsten `HashAggregateExec`
217-
val withPartialSafe = max($"c3")
214+
val withPartialSafe = max($"c2")
218215

219216
// A Spark SQL native distinct aggregate function
220-
val withDistinct = countDistinct($"c4")
217+
val withDistinct = countDistinct($"c3")
221218

222219
val allAggs = Seq(
223220
"typed" -> typed,
224-
"without partial" -> withoutPartial,
225221
"with partial + unsafe" -> withPartialUnsafe,
226222
"with partial + safe" -> withPartialSafe,
227223
"with distinct" -> withDistinct
@@ -276,10 +272,9 @@ class ObjectHashAggregateSuite
276272
// Generates a random schema for the randomized data generator
277273
val schema = new StructType()
278274
.add("c0", numericTypes(random.nextInt(numericTypes.length)), nullable = true)
279-
.add("c1", orderedTypes(random.nextInt(orderedTypes.length)), nullable = true)
280-
.add("c2", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
281-
.add("c3", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
282-
.add("c4", allTypes(random.nextInt(allTypes.length)), nullable = true)
275+
.add("c1", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
276+
.add("c2", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
277+
.add("c3", allTypes(random.nextInt(allTypes.length)), nullable = true)
283278

284279
logInfo(
285280
s"""Using the following random schema to generate all the randomized aggregation tests:
@@ -325,70 +320,67 @@ class ObjectHashAggregateSuite
325320

326321
// Currently Spark SQL doesn't support evaluating distinct aggregate function together
327322
// with aggregate functions without partial aggregation support.
328-
if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
329-
// TODO Re-enables them after fixing SPARK-18403
330-
ignore(
331-
s"randomized aggregation test - " +
332-
s"${names.mkString("[", ", ", "]")} - " +
333-
s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
334-
s"with ${if (emptyInput) "empty" else "non-empty"} input"
335-
) {
336-
var expected: Seq[Row] = null
337-
var actual1: Seq[Row] = null
338-
var actual2: Seq[Row] = null
339-
340-
// Disables `ObjectHashAggregateExec` to obtain a standard answer
341-
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
342-
val aggDf = doAggregation(df)
343-
344-
if (aggs.intersect(Seq(withoutPartial, withPartialSafe, typed)).nonEmpty) {
345-
assert(containsSortAggregateExec(aggDf))
346-
assert(!containsObjectHashAggregateExec(aggDf))
347-
assert(!containsHashAggregateExec(aggDf))
348-
} else {
349-
assert(!containsSortAggregateExec(aggDf))
350-
assert(!containsObjectHashAggregateExec(aggDf))
351-
assert(containsHashAggregateExec(aggDf))
352-
}
353-
354-
expected = aggDf.collect().toSeq
323+
test(
324+
s"randomized aggregation test - " +
325+
s"${names.mkString("[", ", ", "]")} - " +
326+
s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
327+
s"with ${if (emptyInput) "empty" else "non-empty"} input"
328+
) {
329+
var expected: Seq[Row] = null
330+
var actual1: Seq[Row] = null
331+
var actual2: Seq[Row] = null
332+
333+
// Disables `ObjectHashAggregateExec` to obtain a standard answer
334+
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
335+
val aggDf = doAggregation(df)
336+
337+
if (aggs.intersect(Seq(withPartialSafe, typed)).nonEmpty) {
338+
assert(containsSortAggregateExec(aggDf))
339+
assert(!containsObjectHashAggregateExec(aggDf))
340+
assert(!containsHashAggregateExec(aggDf))
341+
} else {
342+
assert(!containsSortAggregateExec(aggDf))
343+
assert(!containsObjectHashAggregateExec(aggDf))
344+
assert(containsHashAggregateExec(aggDf))
355345
}
356346

357-
// Enables `ObjectHashAggregateExec`
358-
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
359-
val aggDf = doAggregation(df)
360-
361-
if (aggs.contains(typed) && !aggs.contains(withoutPartial)) {
362-
assert(!containsSortAggregateExec(aggDf))
363-
assert(containsObjectHashAggregateExec(aggDf))
364-
assert(!containsHashAggregateExec(aggDf))
365-
} else if (aggs.intersect(Seq(withoutPartial, withPartialSafe)).nonEmpty) {
366-
assert(containsSortAggregateExec(aggDf))
367-
assert(!containsObjectHashAggregateExec(aggDf))
368-
assert(!containsHashAggregateExec(aggDf))
369-
} else {
370-
assert(!containsSortAggregateExec(aggDf))
371-
assert(!containsObjectHashAggregateExec(aggDf))
372-
assert(containsHashAggregateExec(aggDf))
373-
}
374-
375-
// Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
376-
// big enough) to obtain a result to be checked.
377-
withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
378-
actual1 = aggDf.collect().toSeq
379-
}
380-
381-
// Enables sort-based aggregation fallback to obtain another result to be checked.
382-
withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
383-
// Here we are not reusing `aggDf` because the physical plan in `aggDf` is
384-
// cached and won't be re-planned using the new fallback threshold.
385-
actual2 = doAggregation(df).collect().toSeq
386-
}
347+
expected = aggDf.collect().toSeq
348+
}
349+
350+
// Enables `ObjectHashAggregateExec`
351+
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
352+
val aggDf = doAggregation(df)
353+
354+
if (aggs.contains(typed)) {
355+
assert(!containsSortAggregateExec(aggDf))
356+
assert(containsObjectHashAggregateExec(aggDf))
357+
assert(!containsHashAggregateExec(aggDf))
358+
} else if (aggs.contains(withPartialSafe)) {
359+
assert(containsSortAggregateExec(aggDf))
360+
assert(!containsObjectHashAggregateExec(aggDf))
361+
assert(!containsHashAggregateExec(aggDf))
362+
} else {
363+
assert(!containsSortAggregateExec(aggDf))
364+
assert(!containsObjectHashAggregateExec(aggDf))
365+
assert(containsHashAggregateExec(aggDf))
387366
}
388367

389-
doubleSafeCheckRows(actual1, expected, 1e-4)
390-
doubleSafeCheckRows(actual2, expected, 1e-4)
368+
// Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
369+
// big enough) to obtain a result to be checked.
370+
withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
371+
actual1 = aggDf.collect().toSeq
372+
}
373+
374+
// Enables sort-based aggregation fallback to obtain another result to be checked.
375+
withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
376+
// Here we are not reusing `aggDf` because the physical plan in `aggDf` is
377+
// cached and won't be re-planned using the new fallback threshold.
378+
actual2 = doAggregation(df).collect().toSeq
379+
}
391380
}
381+
382+
doubleSafeCheckRows(actual1, expected, 1e-4)
383+
doubleSafeCheckRows(actual2, expected, 1e-4)
392384
}
393385
}
394386
}
@@ -425,7 +417,35 @@ class ObjectHashAggregateSuite
425417
}
426418
}
427419

428-
private def function(name: String, args: Column*): Column = {
429-
Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false))
420+
test("SPARK-18403 Fix unsafe data false sharing issue in ObjectHashAggregateExec") {
421+
// SPARK-18403: An unsafe data false sharing issue may trigger OOM / SIGSEGV when evaluating
422+
// certain aggregate functions. To reproduce this issue, the following conditions must be
423+
// met:
424+
//
425+
// 1. The aggregation must be evaluated using `ObjectHashAggregateExec`;
426+
// 2. There must be an input column whose data type involves `ArrayType` or `MapType`;
427+
// 3. Sort-based aggregation fallback must be triggered during evaluation.
428+
withSQLConf(
429+
SQLConf.USE_OBJECT_HASH_AGG.key -> "true",
430+
SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"
431+
) {
432+
checkAnswer(
433+
Seq
434+
.fill(2)(Tuple1(Array.empty[Int]))
435+
.toDF("c0")
436+
.groupBy(lit(1))
437+
.agg(typed_count($"c0"), max($"c0")),
438+
Row(1, 2, Array.empty[Int])
439+
)
440+
441+
checkAnswer(
442+
Seq
443+
.fill(2)(Tuple1(Map.empty[Int, Int]))
444+
.toDF("c0")
445+
.groupBy(lit(1))
446+
.agg(typed_count($"c0"), first($"c0")),
447+
Row(1, 2, Map.empty[Int, Int])
448+
)
449+
}
430450
}
431451
}

0 commit comments

Comments (0)