diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4529ed067e56..0eee4f31196a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -295,6 +295,11 @@ case class HashAggregateExec(
   private var hashMapTerm: String = _
   private var sorterTerm: String = _
 
+  // Because Dataset.show/take methods end the iteration before reaching the end of all rows,
+  // we may not release resources in time and thus cause a memory leak. So we need to hold a
+  // reference to the hash map if it is created and release its resources after task completion.
+  private var hashMapToRelease: UnsafeFixedWidthAggregationMap = _
+
   /**
    * This is called by generated Java class, should be public.
    */
@@ -302,17 +307,23 @@ case class HashAggregateExec(
     // create initialized aggregate buffer
     val initExpr = declFunctions.flatMap(f => f.initialValues)
     val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)
+    val context = TaskContext.get()
 
     // create hashMap
-    new UnsafeFixedWidthAggregationMap(
+    hashMapToRelease = new UnsafeFixedWidthAggregationMap(
       initialBuffer,
       bufferSchema,
       groupingKeySchema,
-      TaskContext.get().taskMemoryManager(),
+      context.taskMemoryManager(),
       1024 * 16, // initial capacity
-      TaskContext.get().taskMemoryManager().pageSizeBytes,
+      context.taskMemoryManager().pageSizeBytes,
       false // disable tracking of performance metrics
     )
+
+    // Release the resources of the hash map at the end of the task.
+    context.addTaskCompletionListener(_ => hashMapToRelease.free())
+
+    hashMapToRelease
   }
 
   def getTaskMemoryManager(): TaskMemoryManager = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 81fa8cbf2238..37732ff6271f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{DataTypes, IntegerType, StringType, StructField, StructType}
 
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -1051,6 +1051,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(dsDouble, arrayDouble)
     checkDataset(dsString, arrayString)
   }
+
+  test("SPARK-18487: Add completion listener to HashAggregate to avoid memory leak") {
+    val rng = new scala.util.Random(42)
+    val data = sparkContext.parallelize(Seq.tabulate(100) { i =>
+      Row(Array.fill(10)(rng.nextInt(10)))
+    })
+    val schema = StructType(Seq(
+      StructField("arr", DataTypes.createArrayType(DataTypes.IntegerType))
+    ))
+    val df = spark.createDataFrame(data, schema)
+    val exploded = df.select(struct(col("*")).as("star"), explode(col("arr")).as("a"))
+    val joined = exploded.join(exploded, "a").drop("a").distinct()
+    joined.show()
+  }
 }
 
 case class Generic[T](id: T, value: Double)
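
Note on the pattern used by this patch: the fix relies on Spark's task completion listener mechanism, which runs registered callbacks when a task finishes, whether or not the consuming iterator was fully drained (as happens with Dataset.show/take). The sketch below illustrates that pattern in isolation; ManuallyFreedResource and allocateInTask are hypothetical names for illustration only and are not part of the patch or of Spark's API.

import org.apache.spark.TaskContext

// Hypothetical resource with an explicit free(), standing in for the
// UnsafeFixedWidthAggregationMap held by hashMapToRelease in the patch above.
class ManuallyFreedResource {
  def free(): Unit = { /* release off-heap memory, file handles, ... */ }
}

// Must be called from code running inside a task (e.g. within a mapPartitions
// closure), where TaskContext.get() returns the current task's context.
def allocateInTask(): ManuallyFreedResource = {
  val resource = new ManuallyFreedResource
  // Register cleanup right after allocation. Spark invokes every completion
  // listener when the task ends, so the resource is freed even if downstream
  // operators stop pulling rows early.
  TaskContext.get().addTaskCompletionListener(_ => resource.free())
  resource
}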