From 90559c03ccbf57bdc99c94001a5a4e9583b16f3c Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Thu, 19 Sep 2019 14:33:05 +0800 Subject: [PATCH 1/7] [SPARK-21045][PYSPARK] Defensive check for exception info thrown by user. --- python/pyspark/testing/utils.py | 16 ++++++++++++++++ python/pyspark/tests/test_worker.py | 25 ++++++++++++++++++++++++- python/pyspark/worker.py | 14 ++++++++++++-- 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 2b42b898f9ed..c3ce23d821ed 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -18,6 +18,7 @@ import os import struct import sys +import threading import unittest from pyspark import SparkContext, SparkConf @@ -127,3 +128,18 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix): raise Exception("Found multiple JARs: %s; please remove all but one" % (", ".join(jars))) else: return jars[0] + + +class ExecThread(threading.Thread): + """ A wrapper thread which stores exception info if any occurred. + """ + def __init__(self, target): + self.target = target + self.exception = None + threading.Thread.__init__(self) + + def run(self): + try: + self.target() + except Exception as e: # captures any exceptions + self.exception = e diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index 18fde17f4a06..b5c204c6ccbc 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -1,3 +1,4 @@ +# -*- encoding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -28,7 +29,7 @@ from py4j.protocol import Py4JJavaError -from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest +from pyspark.testing.utils import ExecThread, ReusedPySparkTestCase, PySparkTestCase, QuietTest if sys.version_info[0] >= 3: xrange = range @@ -150,6 +151,28 @@ def test_with_different_versions_of_python(self): finally: self.sc.pythonVer = version + def test_python_exception_non_hanging(self): + """ + SPARK-21045: exceptions with no ascii encoding shall not hanging PySpark. + """ + def f(): + raise Exception("exception with 中 and \xd6\xd0") + + def run(): + self.sc.parallelize([1]).map(lambda x: f()).count() + + t = ExecThread(target=run) + t.daemon = True + t.start() + t.join(10) + self.assertFalse(t.isAlive(), "Spark should not be blocked") + self.assertIsInstance(t.exception, Py4JJavaError) + if sys.version_info.major < 3: + # we have to use unicode here to avoid UnicodeDecodeError + self.assertRegexpMatches(unicode(t.exception).encode("utf-8"), "exception with 中") + else: + self.assertRegexpMatches(str(t.exception), "exception with 中") + class WorkerReuseTest(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 086202de2c68..0c4728c30f37 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -44,7 +44,7 @@ from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle -if sys.version >= '3': +if sys.version_info.major >= 3: basestring = str else: from itertools import imap as map # use iterator map by default @@ -598,8 +598,18 @@ def process(): process() except Exception: try: + exc_info = traceback.format_exc() + if sys.version_info.major < 3: + if isinstance(exc_info, unicode): + exc_info = exc_info.encode("utf-8") + else: + # exc_info may contains other encoding bytes, replace the invalid byte and + # convert it back to utf-8 again + exc_info = exc_info.decode("utf-8", "replace").encode("utf-8") + else: + exc_info = exc_info.encode("utf-8") write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(traceback.format_exc().encode("utf-8"), outfile) + write_with_length(exc_info, outfile) except IOError: # JVM close the socket pass From fb72447896f756405e9a79d53ff1b0dc3cadf7ea Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Thu, 19 Sep 2019 20:02:38 +0800 Subject: [PATCH 2/7] remove ExecThread and rdd operations happens in main thread. --- python/pyspark/testing/utils.py | 15 --------------- python/pyspark/tests/test_worker.py | 30 +++++++++++------------------ 2 files changed, 11 insertions(+), 34 deletions(-) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index c3ce23d821ed..8e64253ad781 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -128,18 +128,3 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix): raise Exception("Found multiple JARs: %s; please remove all but one" % (", ".join(jars))) else: return jars[0] - - -class ExecThread(threading.Thread): - """ A wrapper thread which stores exception info if any occurred. - """ - def __init__(self, target): - self.target = target - self.exception = None - threading.Thread.__init__(self) - - def run(self): - try: - self.target() - except Exception as e: # captures any exceptions - self.exception = e diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index b5c204c6ccbc..ccbe21f3a6f3 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -29,7 +29,7 @@ from py4j.protocol import Py4JJavaError -from pyspark.testing.utils import ExecThread, ReusedPySparkTestCase, PySparkTestCase, QuietTest +from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest if sys.version_info[0] >= 3: xrange = range @@ -152,26 +152,18 @@ def test_with_different_versions_of_python(self): self.sc.pythonVer = version def test_python_exception_non_hanging(self): - """ - SPARK-21045: exceptions with no ascii encoding shall not hanging PySpark. - """ - def f(): - raise Exception("exception with 中 and \xd6\xd0") + # SPARK-21045: exceptions with no ascii encoding shall not hanging PySpark. + try: + def f(): + raise Exception("exception with 中 and \xd6\xd0") - def run(): self.sc.parallelize([1]).map(lambda x: f()).count() - - t = ExecThread(target=run) - t.daemon = True - t.start() - t.join(10) - self.assertFalse(t.isAlive(), "Spark should not be blocked") - self.assertIsInstance(t.exception, Py4JJavaError) - if sys.version_info.major < 3: - # we have to use unicode here to avoid UnicodeDecodeError - self.assertRegexpMatches(unicode(t.exception).encode("utf-8"), "exception with 中") - else: - self.assertRegexpMatches(str(t.exception), "exception with 中") + except Py4JJavaError as e: + if sys.version_info.major < 3: + # we have to use unicode here to avoid UnicodeDecodeError + self.assertRegexpMatches(unicode(e).encode("utf-8"), "exception with 中") + else: + self.assertRegexpMatches(str(e), "exception with 中") class WorkerReuseTest(PySparkTestCase): From ff7f248a8ae9894f966c4cfad9b334d9d9832190 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Thu, 19 Sep 2019 20:38:18 +0800 Subject: [PATCH 3/7] remove unused import --- python/pyspark/testing/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 8e64253ad781..2b42b898f9ed 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -18,7 +18,6 @@ import os import struct import sys -import threading import unittest from pyspark import SparkContext, SparkConf From 42a9eb0caed3b3b693938dc6fb3480cbd1462998 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Thu, 19 Sep 2019 23:26:27 +0800 Subject: [PATCH 4/7] revert change. --- python/pyspark/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 0c4728c30f37..e9102188d6ab 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -44,7 +44,7 @@ from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle -if sys.version_info.major >= 3: +if sys.version >= '3': basestring = str else: from itertools import imap as map # use iterator map by default From 065296617b22f7d83617933bfe3f2feb3fbae118 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Thu, 19 Sep 2019 23:43:49 +0800 Subject: [PATCH 5/7] revert change and define unicode for python3 --- python/pyspark/worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e9102188d6ab..775d538c17ee 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -46,6 +46,7 @@ if sys.version >= '3': basestring = str + unicode = str else: from itertools import imap as map # use iterator map by default From ffb4d29d44e5c386649185638ad8628d71704cbb Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Fri, 20 Sep 2019 00:07:49 +0800 Subject: [PATCH 6/7] make lint-python happy for python3 --- python/pyspark/sql/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 996b7dd59ce9..83afafdd8b13 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -18,6 +18,9 @@ import py4j import sys +if sys.version_info.major >= 3: + unicode = str + class CapturedException(Exception): def __init__(self, desc, stackTrace, cause=None): From d6ec7aebcb0c1a2bda41a8b6a7b80ec670ff5bdb Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Fri, 20 Sep 2019 23:17:31 +0800 Subject: [PATCH 7/7] Address comments. --- python/pyspark/worker.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 775d538c17ee..698193d6bdd8 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -46,7 +46,6 @@ if sys.version >= '3': basestring = str - unicode = str else: from itertools import imap as map # use iterator map by default @@ -600,13 +599,10 @@ def process(): except Exception: try: exc_info = traceback.format_exc() - if sys.version_info.major < 3: - if isinstance(exc_info, unicode): - exc_info = exc_info.encode("utf-8") - else: - # exc_info may contains other encoding bytes, replace the invalid byte and - # convert it back to utf-8 again - exc_info = exc_info.decode("utf-8", "replace").encode("utf-8") + if isinstance(exc_info, bytes): + # exc_info may contains other encoding bytes, replace the invalid bytes and convert + # it back to utf-8 again + exc_info = exc_info.decode("utf-8", "replace").encode("utf-8") else: exc_info = exc_info.encode("utf-8") write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)