16 changes: 16 additions & 0 deletions python/pyspark/testing/utils.py
@@ -18,6 +18,7 @@
 import os
 import struct
 import sys
+import threading
 import unittest
 
 from pyspark import SparkContext, SparkConf
@@ -127,3 +128,18 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
         raise Exception("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
     else:
         return jars[0]
+
+
+class ExecThread(threading.Thread):
+    """A wrapper thread that stores the exception raised in its target, if any."""
+
+    def __init__(self, target):
+        self.target = target
+        self.exception = None
+        threading.Thread.__init__(self)
+
+    def run(self):
+        try:
+            self.target()
+        except Exception as e:  # capture any exception so the caller can inspect it
+            self.exception = e
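For context, a minimal sketch (not part of the diff) of the pattern ExecThread enables in tests: run a possibly hanging call on a daemon thread, join with a timeout, then inspect liveness and the captured exception. The might_hang function is a hypothetical stand-in for a Spark action.

    def might_hang():
        raise Exception("boom")  # hypothetical target; could also block forever

    t = ExecThread(target=might_hang)
    t.daemon = True   # a hung thread will not keep the interpreter alive
    t.start()
    t.join(10)        # wait at most 10 seconds instead of blocking forever
    assert not t.isAlive(), "target is still running; likely hung"
    assert t.exception is not None  # the exception captured by ExecThread.run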
25 changes: 24 additions & 1 deletion python/pyspark/tests/test_worker.py
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
 # contributor license agreements. See the NOTICE file distributed with
@@ -28,7 +29,7 @@

 from py4j.protocol import Py4JJavaError
 
-from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest
+from pyspark.testing.utils import ExecThread, ReusedPySparkTestCase, PySparkTestCase, QuietTest
 
 if sys.version_info[0] >= 3:
     xrange = range
@@ -150,6 +151,28 @@ def test_with_different_versions_of_python(self):
         finally:
             self.sc.pythonVer = version
 
+    def test_python_exception_non_hanging(self):
+        """
+        SPARK-21045: exceptions containing non-ASCII characters should not hang PySpark.
+        """
+        def f():
+            raise Exception("exception with 中 and \xd6\xd0")
+
+        def run():
+            self.sc.parallelize([1]).map(lambda x: f()).count()
+
+        t = ExecThread(target=run)
+        t.daemon = True
+        t.start()
+        t.join(10)
+        self.assertFalse(t.isAlive(), "Spark should not be blocked")
+        self.assertIsInstance(t.exception, Py4JJavaError)
+        if sys.version_info.major < 3:
+            # we have to use unicode here to avoid UnicodeDecodeError
+            self.assertRegexpMatches(unicode(t.exception).encode("utf-8"), "exception with 中")

Reviewer comment (Member): Yes, str against a Py4J exception doesn't properly handle non-ASCII characters (py4j/py4j#308).

+        else:
+            self.assertRegexpMatches(str(t.exception), "exception with 中")
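As a side note on the unicode() call above, a small Python 2 sketch (assumed semantics, not part of the diff): calling str() on an exception whose message is a unicode object implicitly encodes it with the ASCII codec and fails on non-ASCII text, while unicode() plus an explicit .encode("utf-8") succeeds. With Py4JJavaError the analogous mismatch surfaces as a UnicodeDecodeError (py4j/py4j#308).

    # -*- encoding: utf-8 -*-
    # Python 2 only
    e = Exception(u"exception with 中")
    try:
        str(e)                    # implicit ASCII encoding of the unicode message
    except UnicodeEncodeError:
        print("str() failed on the non-ASCII message")
    print(unicode(e).encode("utf-8"))  # explicit UTF-8 encoding works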


class WorkerReuseTest(PySparkTestCase):

14 changes: 12 additions & 2 deletions python/pyspark/worker.py
@@ -44,7 +44,7 @@
 from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
 
-if sys.version >= '3':
+if sys.version_info.major >= 3:
     basestring = str
 else:
     from itertools import imap as map  # use iterator map by default
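A brief note on why this comparison changed (my reading; the diff itself does not explain it): sys.version is a free-form string such as "2.7.16 (default, ...)", so comparing it lexically against '3' is fragile, whereas sys.version_info is a structured value that compares safely.

    import sys

    print(sys.version)             # e.g. "3.7.3 (default, Mar 27 2019, 22:11:17) ..."
    print(sys.version_info.major)  # e.g. 3 -- an integer, safe for numeric comparison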
@@ -598,8 +598,18 @@ def process():
             process()
         except Exception:
             try:
+                exc_info = traceback.format_exc()
+                if sys.version_info.major < 3:

Reviewer comment (Member): Likewise, let's drop this right after we drop Python 2, which I will do right after Spark 3.

+                    if isinstance(exc_info, unicode):
+                        exc_info = exc_info.encode("utf-8")
+                    else:
+                        # exc_info may contain bytes in other encodings; replace any
+                        # invalid bytes and re-encode the result as UTF-8
+                        exc_info = exc_info.decode("utf-8", "replace").encode("utf-8")
+                else:
+                    exc_info = exc_info.encode("utf-8")
                 write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
-                write_with_length(traceback.format_exc().encode("utf-8"), outfile)
+                write_with_length(exc_info, outfile)
             except IOError:
                 # the JVM closed the socket
                 pass
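For intuition, a minimal sketch (Python 2 semantics assumed, not part of the diff) of what the decode/encode round-trip in the Python 2 branch does to a traceback that mixes UTF-8 and non-UTF-8 bytes: invalid byte sequences become U+FFFD replacement characters, so the result is always valid UTF-8 for the JVM to read. Here "\xe4\xb8\xad" is the UTF-8 encoding of 中 and "\xd6\xd0" is the same character in GBK, which is invalid as UTF-8.

    raw = "exception with \xe4\xb8\xad and \xd6\xd0"   # mixed-encoding byte string
    safe = raw.decode("utf-8", "replace").encode("utf-8")
    print(safe)  # the UTF-8 bytes survive; the GBK bytes become U+FFFD replacements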