@@ -205,23 +205,19 @@ class ObjectHashAggregateSuite
205205 // A TypedImperativeAggregate function
206206 val typed = percentile_approx($" c0" , 0.5 )
207207
208- // A Hive UDAF without partial aggregation support
209- val withoutPartial = function(" hive_max" , $" c1" )
210-
211208 // A Spark SQL native aggregate function with partial aggregation support that can be executed
212209 // by the Tungsten `HashAggregateExec`
213- val withPartialUnsafe = max($" c2 " )
210+ val withPartialUnsafe = max($" c1 " )
214211
215212 // A Spark SQL native aggregate function with partial aggregation support that can only be
216213 // executed by the Tungsten `HashAggregateExec`
217- val withPartialSafe = max($" c3 " )
214+ val withPartialSafe = max($" c2 " )
218215
219216 // A Spark SQL native distinct aggregate function
220- val withDistinct = countDistinct($" c4 " )
217+ val withDistinct = countDistinct($" c3 " )
221218
222219 val allAggs = Seq (
223220 " typed" -> typed,
224- " without partial" -> withoutPartial,
225221 " with partial + unsafe" -> withPartialUnsafe,
226222 " with partial + safe" -> withPartialSafe,
227223 " with distinct" -> withDistinct
@@ -276,10 +272,9 @@ class ObjectHashAggregateSuite
276272 // Generates a random schema for the randomized data generator
277273 val schema = new StructType ()
278274 .add(" c0" , numericTypes(random.nextInt(numericTypes.length)), nullable = true )
279- .add(" c1" , orderedTypes(random.nextInt(orderedTypes.length)), nullable = true )
280- .add(" c2" , fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true )
281- .add(" c3" , varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true )
282- .add(" c4" , allTypes(random.nextInt(allTypes.length)), nullable = true )
275+ .add(" c1" , fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true )
276+ .add(" c2" , varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true )
277+ .add(" c3" , allTypes(random.nextInt(allTypes.length)), nullable = true )
283278
284279 logInfo(
285280 s """ Using the following random schema to generate all the randomized aggregation tests:
@@ -325,70 +320,67 @@ class ObjectHashAggregateSuite
325320
326321 // Currently Spark SQL doesn't support evaluating distinct aggregate function together
327322 // with aggregate functions without partial aggregation support.
328- if (! (aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
329- // TODO Re-enables them after fixing SPARK-18403
330- ignore(
331- s " randomized aggregation test - " +
332- s " ${names.mkString(" [" , " , " , " ]" )} - " +
333- s " ${if (withGroupingKeys) " with" else " without" } grouping keys - " +
334- s " with ${if (emptyInput) " empty" else " non-empty" } input "
335- ) {
336- var expected : Seq [Row ] = null
337- var actual1 : Seq [Row ] = null
338- var actual2 : Seq [Row ] = null
339-
340- // Disables `ObjectHashAggregateExec` to obtain a standard answer
341- withSQLConf(SQLConf .USE_OBJECT_HASH_AGG .key -> " false" ) {
342- val aggDf = doAggregation(df)
343-
344- if (aggs.intersect(Seq (withoutPartial, withPartialSafe, typed)).nonEmpty) {
345- assert(containsSortAggregateExec(aggDf))
346- assert(! containsObjectHashAggregateExec(aggDf))
347- assert(! containsHashAggregateExec(aggDf))
348- } else {
349- assert(! containsSortAggregateExec(aggDf))
350- assert(! containsObjectHashAggregateExec(aggDf))
351- assert(containsHashAggregateExec(aggDf))
352- }
353-
354- expected = aggDf.collect().toSeq
323+ test(
324+ s " randomized aggregation test - " +
325+ s " ${names.mkString(" [" , " , " , " ]" )} - " +
326+ s " ${if (withGroupingKeys) " with" else " without" } grouping keys - " +
327+ s " with ${if (emptyInput) " empty" else " non-empty" } input "
328+ ) {
329+ var expected : Seq [Row ] = null
330+ var actual1 : Seq [Row ] = null
331+ var actual2 : Seq [Row ] = null
332+
333+ // Disables `ObjectHashAggregateExec` to obtain a standard answer
334+ withSQLConf(SQLConf .USE_OBJECT_HASH_AGG .key -> " false" ) {
335+ val aggDf = doAggregation(df)
336+
337+ if (aggs.intersect(Seq (withPartialSafe, typed)).nonEmpty) {
338+ assert(containsSortAggregateExec(aggDf))
339+ assert(! containsObjectHashAggregateExec(aggDf))
340+ assert(! containsHashAggregateExec(aggDf))
341+ } else {
342+ assert(! containsSortAggregateExec(aggDf))
343+ assert(! containsObjectHashAggregateExec(aggDf))
344+ assert(containsHashAggregateExec(aggDf))
355345 }
356346
357- // Enables `ObjectHashAggregateExec`
358- withSQLConf(SQLConf .USE_OBJECT_HASH_AGG .key -> " true" ) {
359- val aggDf = doAggregation(df)
360-
361- if (aggs.contains(typed) && ! aggs.contains(withoutPartial)) {
362- assert(! containsSortAggregateExec(aggDf))
363- assert(containsObjectHashAggregateExec(aggDf))
364- assert(! containsHashAggregateExec(aggDf))
365- } else if (aggs.intersect(Seq (withoutPartial, withPartialSafe)).nonEmpty) {
366- assert(containsSortAggregateExec(aggDf))
367- assert(! containsObjectHashAggregateExec(aggDf))
368- assert(! containsHashAggregateExec(aggDf))
369- } else {
370- assert(! containsSortAggregateExec(aggDf))
371- assert(! containsObjectHashAggregateExec(aggDf))
372- assert(containsHashAggregateExec(aggDf))
373- }
374-
375- // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
376- // big enough) to obtain a result to be checked.
377- withSQLConf(SQLConf .OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD .key -> " 100" ) {
378- actual1 = aggDf.collect().toSeq
379- }
380-
381- // Enables sort-based aggregation fallback to obtain another result to be checked.
382- withSQLConf(SQLConf .OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD .key -> " 3" ) {
383- // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
384- // cached and won't be re-planned using the new fallback threshold.
385- actual2 = doAggregation(df).collect().toSeq
386- }
347+ expected = aggDf.collect().toSeq
348+ }
349+
350+ // Enables `ObjectHashAggregateExec`
351+ withSQLConf(SQLConf .USE_OBJECT_HASH_AGG .key -> " true" ) {
352+ val aggDf = doAggregation(df)
353+
354+ if (aggs.contains(typed)) {
355+ assert(! containsSortAggregateExec(aggDf))
356+ assert(containsObjectHashAggregateExec(aggDf))
357+ assert(! containsHashAggregateExec(aggDf))
358+ } else if (aggs.contains(withPartialSafe)) {
359+ assert(containsSortAggregateExec(aggDf))
360+ assert(! containsObjectHashAggregateExec(aggDf))
361+ assert(! containsHashAggregateExec(aggDf))
362+ } else {
363+ assert(! containsSortAggregateExec(aggDf))
364+ assert(! containsObjectHashAggregateExec(aggDf))
365+ assert(containsHashAggregateExec(aggDf))
387366 }
388367
389- doubleSafeCheckRows(actual1, expected, 1e-4 )
390- doubleSafeCheckRows(actual2, expected, 1e-4 )
368+ // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
369+ // big enough) to obtain a result to be checked.
370+ withSQLConf(SQLConf .OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD .key -> " 100" ) {
371+ actual1 = aggDf.collect().toSeq
372+ }
373+
374+ // Enables sort-based aggregation fallback to obtain another result to be checked.
375+ withSQLConf(SQLConf .OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD .key -> " 3" ) {
376+ // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
377+ // cached and won't be re-planned using the new fallback threshold.
378+ actual2 = doAggregation(df).collect().toSeq
379+ }
391380 }
381+
382+ doubleSafeCheckRows(actual1, expected, 1e-4 )
383+ doubleSafeCheckRows(actual2, expected, 1e-4 )
392384 }
393385 }
394386 }
@@ -425,7 +417,35 @@ class ObjectHashAggregateSuite
425417 }
426418 }
427419
428- private def function (name : String , args : Column * ): Column = {
429- Column (UnresolvedFunction (FunctionIdentifier (name), args.map(_.expr), isDistinct = false ))
420+ test(" SPARK-18403 Fix unsafe data false sharing issue in ObjectHashAggregateExec" ) {
421+ // SPARK-18403: An unsafe data false sharing issue may trigger OOM / SIGSEGV when evaluating
422+ // certain aggregate functions. To reproduce this issue, the following conditions must be
423+ // met:
424+ //
425+ // 1. The aggregation must be evaluated using `ObjectHashAggregateExec`;
426+ // 2. There must be an input column whose data type involves `ArrayType` or `MapType`;
427+ // 3. Sort-based aggregation fallback must be triggered during evaluation.
428+ withSQLConf(
429+ SQLConf .USE_OBJECT_HASH_AGG .key -> " true" ,
430+ SQLConf .OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD .key -> " 1"
431+ ) {
432+ checkAnswer(
433+ Seq
434+ .fill(2 )(Tuple1 (Array .empty[Int ]))
435+ .toDF(" c0" )
436+ .groupBy(lit(1 ))
437+ .agg(typed_count($" c0" ), max($" c0" )),
438+ Row (1 , 2 , Array .empty[Int ])
439+ )
440+
441+ checkAnswer(
442+ Seq
443+ .fill(2 )(Tuple1 (Map .empty[Int , Int ]))
444+ .toDF(" c0" )
445+ .groupBy(lit(1 ))
446+ .agg(typed_count($" c0" ), first($" c0" )),
447+ Row (1 , 2 , Map .empty[Int , Int ])
448+ )
449+ }
430450 }
431451}
0 commit comments