@@ -18,7 +18,7 @@
 import sys
 import time
 
-from pyspark import SparkContext, TaskContext, BarrierTaskContext
+from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
 from pyspark.testing.utils import PySparkTestCase
 
 
@@ -118,21 +118,6 @@ def context_barrier(x):
         times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
         self.assertTrue(max(times) - min(times) < 1)
 
-    def test_barrier_with_python_worker_reuse(self):
-        """
-        Verify that BarrierTaskContext.barrier() with reused python worker.
-        """
-        self.sc._conf.set("spark.python.work.reuse", "true")
-        rdd = self.sc.parallelize(range(4), 4)
-        # start a normal job first to start all worker
-        result = rdd.map(lambda x: x ** 2).collect()
-        self.assertEqual([0, 1, 4, 9], result)
-        # make sure `spark.python.work.reuse=true`
-        self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true")
-
-        # worker will be reused in this barrier job
-        self.test_barrier()
-
     def test_barrier_infos(self):
         """
         Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
@@ -149,6 +134,43 @@ def f(iterator):
         self.assertTrue(len(taskInfos[0]) == 4)
 
 
+class TaskContextTestsWithWorkerReuse(PySparkTestCase):
+
+    def setUp(self):
+        self._old_sys_path = list(sys.path)
+        class_name = self.__class__.__name__
+        conf = SparkConf().set("spark.python.worker.reuse", "true")
+        self.sc = SparkContext('local[2]', class_name, conf=conf)
+
+    def test_barrier_with_python_worker_reuse(self):
+        """
+        Regression test for SPARK-25921: verify that BarrierTaskContext.barrier()
+        works with a reused Python worker.
+        """
+        import os
+        # start a normal job first to start all workers and collect their pids
+        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
+        # the workers will be reused in this barrier job
+        rdd = self.sc.parallelize(range(10), 2)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        def context_barrier(x):
+            tc = BarrierTaskContext.get()
+            time.sleep(random.randint(1, 10))
+            tc.barrier()
+            return (time.time(), os.getpid())
+
+        result = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
+        times = [x[0] for x in result]
+        pids = [x[1] for x in result]
+        # check both the barrier sync and the worker reuse effect
+        self.assertTrue(max(times) - min(times) < 1)
+        for pid in pids:
+            self.assertTrue(pid in worker_pids)
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_taskcontext import *
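
For reference, a minimal standalone sketch of the behavior this test pins down, assuming a local PySpark 2.4+ install (the app name and the final print check are illustrative, not part of the patch): with spark.python.worker.reuse enabled, the worker processes started by a normal job should also serve a following barrier stage, so the pids observed inside the barrier tasks match the earlier ones.

import os

from pyspark import SparkConf, SparkContext, BarrierTaskContext

# Enable Python worker reuse so later stages run in the same worker processes.
conf = SparkConf().set("spark.python.worker.reuse", "true")
sc = SparkContext("local[2]", "barrier-reuse-demo", conf=conf)

# A normal job starts the workers; record their pids.
pids_before = sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()

def barrier_stage(iterator):
    tc = BarrierTaskContext.get()
    tc.barrier()  # both tasks block here until all have arrived
    yield os.getpid()

# With reuse enabled, the barrier stage should run in the same worker processes.
pids_after = sc.parallelize(range(2), 2).barrier().mapPartitions(barrier_stage).collect()
print(sorted(pids_before) == sorted(pids_after))  # expected: True
sc.stop()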