Commit f988d9e

Merge pull request #7 from justinuang/feature/fix_python_hang
[SPARK-6294] fix hang when call take() in JVM on PythonRDD
2 parents: d4389e2 + b06616d

3 files changed: 15 additions & 4 deletions
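
For context, the hang this commit fixes can be reproduced in a few lines of PySpark, mirroring the regression test added to python/pyspark/tests.py below. A minimal sketch, assuming a local PySpark installation (the app name and "local" master URL are illustrative, not part of the commit):

    from pyspark import SparkContext

    sc = SparkContext("local", "spark-6294-repro")
    rdd = sc.parallelize(range(1 << 20)).map(lambda x: str(x))
    rdd._jrdd.first()  # hung before this fix; returns promptly after it
    sc.stop()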

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 6 additions & 3 deletions
@@ -75,7 +75,6 @@ private[spark] class PythonRDD(
 
     context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
-      writerThread.join()
       if (!reuse_worker || !released) {
         try {
           worker.close()

@@ -247,13 +246,17 @@ private[spark] class PythonRDD(
       } catch {
         case e: Exception if context.isCompleted || context.isInterrupted =>
           logDebug("Exception thrown after task completion (likely due to cleanup)", e)
-          Utils.tryLog(worker.shutdownOutput())
+          if (!worker.isClosed) {
+            Utils.tryLog(worker.shutdownOutput())
+          }
 
         case e: Exception =>
           // We must avoid throwing exceptions here, because the thread uncaught exception handler
           // will kill the whole executor (see org.apache.spark.executor.Executor).
           _exception = e
-          Utils.tryLog(worker.shutdownOutput())
+          if (!worker.isClosed) {
+            Utils.tryLog(worker.shutdownOutput())
+          }
       } finally {
         // Release memory used by this thread for shuffles
         env.shuffleMemoryManager.releaseMemoryForThisThread()
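
Why this fixes the hang: when the JVM calls take() on a PythonRDD it stops reading the worker's output early, so the writer thread can still be blocked writing to the worker when the task completes; the completion listener's writerThread.join() then never returned. Removing the join() lets cleanup proceed, and the new worker.isClosed guards avoid calling shutdownOutput() on a socket that has already been closed.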

python/pyspark/daemon.py

Lines changed: 4 additions & 1 deletion
@@ -61,7 +61,10 @@ def worker(sock):
     except SystemExit as exc:
         exit_code = compute_real_exit_code(exc.code)
     finally:
-        outfile.flush()
+        try:
+            outfile.flush()
+        except Exception:
+            pass
     return exit_code
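The bare flush() needed guarding because outfile wraps the worker's socket back to the JVM; once the JVM stops reading early (as with take()), flushing can raise and prevent the worker's real exit code from being returned. A minimal standalone sketch of the same pattern; the socket setup and names here are illustrative, not daemon.py's actual code:

    import socket

    def flush_guard_demo(sock: socket.socket) -> int:
        # Buffered writer over the socket, standing in for daemon.py's outfile.
        outfile = sock.makefile("wb", 65536)
        exit_code = 0
        try:
            outfile.write(b"partial result")
        finally:
            try:
                outfile.flush()  # may raise (e.g. BrokenPipeError) if the peer closed
            except Exception:
                pass  # swallow it so the real exit_code is still returned
        return exit_code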

python/pyspark/tests.py

Lines changed: 5 additions & 0 deletions
@@ -782,6 +782,11 @@ def test_narrow_dependency_in_join(self):
         jobId = tracker.getJobIdsForGroup("test4")[0]
         self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
 
+    # Regression test for SPARK-6294
+    def test_take_on_jrdd(self):
+        rdd = self.sc.parallelize(range(1 << 20)).map(lambda x: str(x))
+        rdd._jrdd.first()
+
 
 class ProfilerTests(PySparkTestCase):
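
Note that the test drives the action through rdd._jrdd.first() rather than rdd.first(): invoking first()/take() on the wrapped Java RDD is exactly the JVM-side code path named in the commit title that previously hung.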
