
Commit 0cf822f (1 parent: 79b0548)

follow up pr for SPARK-25921

2 files changed: 44 additions & 25 deletions

python/pyspark/taskcontext.py (5 additions & 9 deletions)
@@ -48,10 +48,6 @@ def __new__(cls):
         cls._taskContext = taskContext = object.__new__(cls)
         return taskContext

-    def __init__(self):
-        """Construct a TaskContext, use get instead"""
-        pass
-
     @classmethod
     def _getOrCreate(cls):
         """Internal function to get or create global TaskContext."""
@@ -140,13 +136,13 @@ class BarrierTaskContext(TaskContext):
     _port = None
     _secret = None

-    def __init__(self):
-        """Construct a BarrierTaskContext, use get instead"""
-        pass
-
     @classmethod
     def _getOrCreate(cls):
-        """Internal function to get or create global BarrierTaskContext."""
+        """
+        Internal function to get or create global BarrierTaskContext. We need to make sure
+        a BarrierTaskContext is returned here because it is needed in the python worker
+        reuse scenario, see SPARK-25921 for more details.
+        """
         if not isinstance(cls._taskContext, BarrierTaskContext):
             cls._taskContext = object.__new__(cls)
         return cls._taskContext
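
The change above relies on the process-wide context singleton being type-checked rather than blindly reused. As a rough illustration of that pattern (a minimal standalone sketch, simplified from pyspark/taskcontext.py and not the actual PySpark source), a reused worker can be left holding a plain TaskContext from an earlier non-barrier task, and the barrier variant has to swap it out:

    # Minimal sketch of the singleton swap (illustrative names, simplified logic).
    class TaskContext(object):
        _taskContext = None  # process-wide singleton; survives python worker reuse

        @classmethod
        def _getOrCreate(cls):
            if cls._taskContext is None:
                cls._taskContext = object.__new__(cls)
            return cls._taskContext

    class BarrierTaskContext(TaskContext):
        @classmethod
        def _getOrCreate(cls):
            # A reused worker may still cache a plain TaskContext from a previous
            # non-barrier task, so replace it instead of returning the stale object.
            if not isinstance(cls._taskContext, BarrierTaskContext):
                cls._taskContext = object.__new__(cls)
            return cls._taskContext

    # Simulate a worker that first runs a normal task, then a barrier task:
    TaskContext._getOrCreate()
    assert isinstance(BarrierTaskContext._getOrCreate(), BarrierTaskContext)

Without the isinstance check, the second call would hand the barrier task a plain TaskContext, which is the failure mode SPARK-25921 describes.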

python/pyspark/tests/test_taskcontext.py (39 additions & 16 deletions)
@@ -18,7 +18,7 @@
 import sys
 import time

-from pyspark import SparkContext, TaskContext, BarrierTaskContext
+from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
 from pyspark.testing.utils import PySparkTestCase


@@ -118,21 +118,6 @@ def context_barrier(x):
         times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
         self.assertTrue(max(times) - min(times) < 1)

-    def test_barrier_with_python_worker_reuse(self):
-        """
-        Verify that BarrierTaskContext.barrier() with reused python worker.
-        """
-        self.sc._conf.set("spark.python.work.reuse", "true")
-        rdd = self.sc.parallelize(range(4), 4)
-        # start a normal job first to start all worker
-        result = rdd.map(lambda x: x ** 2).collect()
-        self.assertEqual([0, 1, 4, 9], result)
-        # make sure `spark.python.work.reuse=true`
-        self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true")
-
-        # worker will be reused in this barrier job
-        self.test_barrier()
-
     def test_barrier_infos(self):
         """
         Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
@@ -149,6 +134,44 @@ def f(iterator):
         self.assertTrue(len(taskInfos[0]) == 4)


+class TaskContextTestsWithWorkerReuse(PySparkTestCase):
+
+    def setUp(self):
+        self._old_sys_path = list(sys.path)
+        class_name = self.__class__.__name__
+        conf = SparkConf().set("spark.python.worker.reuse", "true")
+        self.sc = SparkContext('local[2]', class_name, conf=conf)
+
+    def test_barrier_with_python_worker_reuse(self):
+        """
+        Regression test for SPARK-25921: verify that BarrierTaskContext.barrier() works
+        with reused python workers.
+        """
+        import os
+        self.sc._conf.set("spark.python.worker.reuse", "true")
+        # start a normal job first to start all workers and get all worker pids
+        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
+        # the workers will be reused in this barrier job
+        rdd = self.sc.parallelize(range(10), 2)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        def context_barrier(x):
+            tc = BarrierTaskContext.get()
+            time.sleep(random.randint(1, 10))
+            tc.barrier()
+            return (time.time(), os.getpid())
+
+        result = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
+        times = list(map(lambda x: x[0], result))
+        pids = list(map(lambda x: x[1], result))
+        # check both the barrier effect and the worker reuse effect
+        self.assertTrue(max(times) - min(times) < 1)
+        for pid in pids:
+            self.assertTrue(pid in worker_pids)
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_taskcontext import *
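
As a quick way to exercise the fixed code path outside the unittest harness, a small standalone script along the same lines as the new test (a sketch; the app name and partition counts are arbitrary) warms up reusable python workers with a normal job and then runs a barrier stage in the same processes:

    # Standalone sketch mirroring the new regression test (illustrative only).
    import os

    from pyspark import SparkConf, SparkContext, BarrierTaskContext

    conf = SparkConf().set("spark.python.worker.reuse", "true")
    sc = SparkContext("local[2]", "barrier-worker-reuse-demo", conf=conf)

    # Warm up the python workers with a normal (non-barrier) job.
    warm_pids = set(sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect())

    def stage(iterator):
        # Before this fix, a reused worker could still hold a plain TaskContext
        # here, making the barrier call fail; see SPARK-25921.
        tc = BarrierTaskContext.get()
        tc.barrier()
        yield os.getpid()

    barrier_pids = set(sc.parallelize(range(2), 2).barrier().mapPartitions(stage).collect())
    print("barrier workers were reused:", barrier_pids <= warm_pids)
    sc.stop()

The new TaskContextTestsWithWorkerReuse suite checks the same thing under unittest, with spark.python.worker.reuse enabled once in setUp rather than per test.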
