Skip to content

Commit 69c9104

Browse files
committed
Add unit test
1 parent 435ccff commit 69c9104

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

python/pyspark/sql/tests.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4931,6 +4931,30 @@ def foo3(key, pdf):
49314931
expected4 = udf3.func((), pdf)
49324932
self.assertPandasEqual(expected4, result4)
49334933

4934+
# Regression test for SPARK-24334
4935+
def test_memory_leak(self):
4936+
from pyspark.sql.functions import pandas_udf, col, PandasUDFType, array, lit, explode
4937+
4938+
# Have all data in a single executor thread so it can trigger the race condition easier
4939+
with self.sql_conf({'spark.sql.shuffle.partitions': 1}):
4940+
df = self.spark.range(0, 1000)
4941+
df = df.withColumn('id', array([lit(i) for i in range(0, 300)])) \
4942+
.withColumn('id', explode(col('id'))) \
4943+
.withColumn('v', array([lit(i) for i in range(0, 1000)]))
4944+
4945+
@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
4946+
def foo(pdf):
4947+
# Throw exception in the UDF
4948+
xxx
4949+
return pdf
4950+
4951+
result = df.groupby('id').apply(foo)
4952+
4953+
with QuietTest(self.sc):
4954+
with self.assertRaises(py4j.protocol.Py4JJavaError) as context:
4955+
result.count()
4956+
self.assertTrue('Memory leaked' not in str(context.exception))
4957+
49344958

49354959
@unittest.skipIf(
49364960
not _have_pandas or not _have_pyarrow,

0 commit comments

Comments
 (0)