diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6137ed25a0dd..180a3e882dab 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -493,6 +493,14 @@ def getStart(split):
                 return start0 + int((split * size / numSlices)) * step
 
             def f(split, iterator):
+                # It's an empty iterator here, but we need this line to trigger the
+                # signal-handling logic in FramedSerializer.load_stream, for instance
+                # SpecialLengths.END_OF_DATA_SECTION in _read_with_length. Since
+                # FramedSerializer.load_stream produces a generator, control must enter
+                # that function at least once. Here we do it by explicitly converting
+                # the empty iterator to a list, thus making sure worker reuse takes
+                # effect. See SPARK-26549 for more details.
+                assert len(list(iterator)) == 0
                 return xrange(getStart(split), getStart(split + 1), step)
 
             return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index a33b77d98341..a4f108f18e17 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -22,7 +22,7 @@
 
 from py4j.protocol import Py4JJavaError
 
-from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest
+from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest
 
 if sys.version_info[0] >= 3:
     xrange = range
@@ -145,6 +145,16 @@ def test_with_different_versions_of_python(self):
             self.sc.pythonVer = version
 
 
+class WorkerReuseTest(PySparkTestCase):
+
+    def test_reuse_worker_of_parallelize_xrange(self):
+        rdd = self.sc.parallelize(xrange(20), 8)
+        previous_pids = rdd.map(lambda x: os.getpid()).collect()
+        current_pids = rdd.map(lambda x: os.getpid()).collect()
+        for pid in current_pids:
+            self.assertTrue(pid in previous_pids)
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_worker import *
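
Note (not part of the patch): the fix relies on the fact that a Python generator
function's body does not execute until the generator is first advanced. The sketch
below is plain Python, not Spark internals; load_stream here is a hypothetical
stand-in for FramedSerializer.load_stream, used only to show why an iterator that
nobody consumes never reaches the code inside the generator body.

    def load_stream():
        # Stand-in for FramedSerializer.load_stream: in the real serializer,
        # per-stream handling such as SpecialLengths.END_OF_DATA_SECTION lives
        # inside the generator body, so it only runs once the generator is
        # advanced at least one step.
        print("entered load_stream")
        return
        yield  # unreachable; its presence makes this function a generator

    it = load_stream()   # nothing printed: the body has not run yet
    list(it)             # prints "entered load_stream", even though it is empty

This is why `assert len(list(iterator)) == 0` suffices: converting the empty
iterator to a list forces control into load_stream once, so the end-of-stream
handling runs and the worker can be reused, which the new WorkerReuseTest checks
by comparing worker PIDs across two collects.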