54 changes: 38 additions & 16 deletions python/pyspark/sql/tests.py
@@ -853,22 +853,6 @@ def __call__(self, x):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

def test_stopiteration_in_udf(self):
# test for SPARK-23754
from pyspark.sql.functions import udf
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

with self.assertRaises(Py4JJavaError) as cm:
self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()

self.assertIn(
"Caught StopIteration thrown from user's code; failing the task",
cm.exception.java_exception.toString()
)

def test_validate_column_types(self):
from pyspark.sql.functions import udf, to_json
from pyspark.sql.column import _to_java_column
@@ -3917,6 +3901,44 @@ def foo(df):
def foo(k, v):
return k

def test_stopiteration_in_udf(self):
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

def foofoo(x, y):
raise StopIteration()

exc_message = "Caught StopIteration thrown from user's code; failing the task"
df = self.spark.range(0, 100)

# plain udf (test for SPARK-23754)
self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.withColumn('v', udf(foo)('id')).collect
)

# pandas scalar udf
self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.withColumn(
'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
).collect
)

# pandas grouped map
self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.groupBy('id').apply(
pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
).collect
)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
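Note: the relocated test above now covers the plain Python UDF as well as pandas scalar and grouped-map UDFs. The failure mode it guards against is sketched below outside of Spark; evaluate is a hypothetical stand-in for the generator-driven loops PySpark workers run, not Spark code.

    def user_func(x):
        raise StopIteration()          # stands in for buggy user code

    def evaluate(rows):
        # generator-based evaluation loop, similar in spirit to a worker's
        for row in rows:
            yield user_func(row)

    # On Python < 3.7 this prints []: the rows are silently dropped and no error is raised.
    # On Python >= 3.7, PEP 479 turns the escaping StopIteration into a RuntimeError.
    print(list(evaluate(range(5))))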
4 changes: 1 addition & 3 deletions python/pyspark/sql/udf.py
@@ -24,7 +24,6 @@
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string, \
to_arrow_type, to_arrow_schema
from pyspark.util import fail_on_stopiteration

__all__ = ["UDFRegistration"]

@@ -155,8 +154,7 @@ def _create_judf(self):
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

func = fail_on_stopiteration(self.func)
wrapped_func = _wrap_function(sc, func, self.returnType)
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.evalType, self.deterministic)
37 changes: 22 additions & 15 deletions python/pyspark/tests.py
@@ -1270,27 +1270,34 @@ def test_pipe_functions(self):
self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
self.assertEqual([], rdd.pipe('grep 4').collect())

def test_stopiteration_in_client_code(self):
def test_stopiteration_in_user_code(self):

def stopit(*x):
raise StopIteration()

seq_rdd = self.sc.parallelize(range(10))
keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))

self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)

# the exception raised is non-deterministic
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, lambda *x: 1, stopit)
msg = "Caught StopIteration thrown from user's code; failing the task"

self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg,
seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)

# these methods call the user function both in the driver and in the executor
# the exception raised is different according to where the StopIteration happens
# RuntimeError is raised if in the driver
# Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, lambda *x: 1, stopit)


class ProfilerTests(PySparkTestCase):
2 changes: 1 addition & 1 deletion python/pyspark/util.py
@@ -48,7 +48,7 @@ def _exception_message(excp):
def fail_on_stopiteration(f):
"""
Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
prevents silent loss of data when 'f' is used in a for loop
prevents silent loss of data when 'f' is used in a for loop in Spark code
"""
def wrapper(*args, **kwargs):
try:
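Note: this hunk shows only the docstring and the opening of the wrapper. A minimal sketch of what the complete helper plausibly looks like, assuming it re-raises with the message the tests assert on, is:

    def fail_on_stopiteration(f):
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except StopIteration as exc:
                # fail loudly instead of letting a for loop swallow the exception
                raise RuntimeError(
                    "Caught StopIteration thrown from user's code; failing the task",
                    exc
                )
        return wrapper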
11 changes: 8 additions & 3 deletions python/pyspark/worker.py
@@ -35,6 +35,7 @@
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
from pyspark.util import fail_on_stopiteration
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -122,13 +123,17 @@ def read_single_udf(pickleSer, infile, eval_type):
else:
row_func = chain(row_func, f)

# make sure StopIteration's raised in the user code are not ignored
# when they are processed in a for loop, raise them as RuntimeError's instead
func = fail_on_stopiteration(row_func)

# the last returnType will be the return type of UDF
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type)
else:
return arg_offsets, wrap_udf(row_func, return_type)
return arg_offsets, wrap_udf(func, return_type)


def read_udfs(pickleSer, infile, eval_type):
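Note: with this change the guard that udf.py previously applied on the driver is applied in the worker instead, so plain, pandas scalar, and grouped-map UDFs all pass through fail_on_stopiteration before being wrapped. A small usage sketch of the helper on its own (requires only that pyspark is importable):

    from pyspark.util import fail_on_stopiteration

    def buggy(x):
        raise StopIteration()

    safe = fail_on_stopiteration(buggy)
    try:
        safe(1)
    except RuntimeError as e:
        # the message includes "Caught StopIteration thrown from user's code; failing the task"
        print(e)

Inside a running task this RuntimeError fails the Python worker, and the driver observes it wrapped in a Py4JJavaError, which is what the updated tests assert.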