53 changes: 37 additions & 16 deletions python/pyspark/sql/tests.py
@@ -900,22 +900,6 @@ def __call__(self, x):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

def test_stopiteration_in_udf(self):
# test for SPARK-23754
from pyspark.sql.functions import udf
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

with self.assertRaises(Py4JJavaError) as cm:
self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()

self.assertIn(
"Caught StopIteration thrown from user's code; failing the task",
cm.exception.java_exception.toString()
)

def test_validate_column_types(self):
from pyspark.sql.functions import udf, to_json
from pyspark.sql.column import _to_java_column
@@ -4096,6 +4080,43 @@ def foo(df):
def foo(k, v, w):
return k

def test_stopiteration_in_udf(self):
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

def foofoo(x, y):
raise StopIteration()

exc_message = "Caught StopIteration thrown from user's code; failing the task"
df = self.spark.range(0, 100)

# plain udf (test for SPARK-23754)
self.assertRaisesRegexp(Py4JJavaError, exc_message, df.withColumn(
'v', udf(foo)('id')
).collect)
Review comment (Member):
tiny nit: I would do:

self.assertRaisesRegexp(
    Py4JJavaError, exc_message, df.withColumn('v', udf(foo)('id')).collect)

or

self.assertRaisesRegexp(
    Py4JJavaError,
    exc_message,
    df.withColumn('v', udf(foo)('id')).collect)


# pandas scalar udf
self.assertRaisesRegexp(Py4JJavaError, exc_message, df.withColumn(
'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
).collect)

# pandas grouped map
self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').apply(
pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
).collect)

self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').apply(
pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP)
).collect)

# pandas grouped agg
self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').agg(
pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id')
).collect)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
14 changes: 2 additions & 12 deletions python/pyspark/sql/udf.py
@@ -25,7 +25,7 @@
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
to_arrow_type, to_arrow_schema
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark.util import _get_argspec

__all__ = ["UDFRegistration"]

@@ -157,17 +157,7 @@ def _create_judf(self):
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

func = fail_on_stopiteration(self.func)

# for pandas UDFs the worker needs to know if the function takes
# one or two arguments, but the signature is lost when wrapping with
# fail_on_stopiteration, so we store it here
if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
func._argspec = _get_argspec(self.func)

wrapped_func = _wrap_function(sc, func, self.returnType)
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.evalType, self.deterministic)
28 changes: 16 additions & 12 deletions python/pyspark/tests.py
@@ -1291,27 +1291,31 @@ def test_pipe_unicode(self):
result = rdd.pipe('cat').collect()
self.assertEqual(data, result)

def test_stopiteration_in_client_code(self):
def test_stopiteration_in_user_code(self):

def stopit(*x):
raise StopIteration()

seq_rdd = self.sc.parallelize(range(10))
keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
msg = "Caught StopIteration thrown from user's code; failing the task"

self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg,
seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)

# these methods call the user function both on the driver and on the executors;
# depending on where the StopIteration escapes, it surfaces as either a
# RuntimeError (driver side) or a Py4JJavaError (executor side), non-deterministically
Review comment (Member):
What does this mean? The exception is non-deterministic?

Reply (Member):
Yeah, I asked this before. He explained that the exception can be thrown on the driver side or on the executor side, non-deterministically. We should clarify this comment. It's quite a core fix; let's clarify everything.
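
To make this concrete, here is a simplified, pure-Python sketch (an assumption modelled on how rdd.py splits aggregate(), not code from this PR): the per-partition seqOp runs inside executor tasks, while the final combOp reduction over the collected partial results runs on the driver, which is why a StopIteration can surface as either a Py4JJavaError or a plain RuntimeError.

from functools import reduce

def aggregate_sketch(partitions, zeroValue, seqOp, combOp):
    # "executor side": fold each partition with seqOp inside a task
    partials = [reduce(seqOp, part, zeroValue) for part in partitions]
    # "driver side": combine the collected partial results with combOp
    return reduce(combOp, partials, zeroValue)

print(aggregate_sketch([[1, 2], [3, 4]], 0, lambda a, x: a + x, lambda a, b: a + b))  # prints 10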

self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, lambda *x: 1, stopit)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, lambda *x: 1, stopit)


class ProfilerTests(PySparkTestCase):
7 changes: 1 addition & 6 deletions python/pyspark/util.py
@@ -53,12 +53,7 @@ def _get_argspec(f):
"""
Get argspec of a function. Supports both Python 2 and Python 3.
"""

if hasattr(f, '_argspec'):
# only used for pandas UDF: they wrap the user function, losing its signature
# workers need this signature, so UDF saves it here
argspec = f._argspec
elif sys.version_info[0] < 3:
if sys.version_info[0] < 3:
argspec = inspect.getargspec(f)
else:
# `getargspec` is deprecated since python3.0 (incompatible with function annotations).
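Not shown in the hunk above: pyspark/util.py also defines the fail_on_stopiteration helper that worker.py imports below. As a rough sketch of the idea (the exact implementation in util.py may differ), it wraps the user function and re-raises StopIteration as the RuntimeError whose message the tests assert on:

def fail_on_stopiteration(f):
    """Wrap f so that a StopIteration escaping the user's code is not silently
    swallowed by Spark's internal iterators, but re-raised as a RuntimeError."""
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task", exc)
    return wrapper
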
18 changes: 11 additions & 7 deletions python/pyspark/worker.py
@@ -35,7 +35,7 @@
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
from pyspark.util import _get_argspec
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -92,10 +92,9 @@ def verify_result_length(*a):
return lambda *a: (verify_result_length(*a), arrow_return_type)


def wrap_grouped_map_pandas_udf(f, return_type):
def wrap_grouped_map_pandas_udf(f, return_type, argspec):
def wrapped(key_series, value_series):
import pandas as pd
argspec = _get_argspec(f)

if len(argspec.args) == 1:
result = f(pd.concat(value_series, axis=1))
@@ -140,15 +139,20 @@ def read_single_udf(pickleSer, infile, eval_type):
else:
row_func = chain(row_func, f)

# make sure StopIteration's raised in the user code are not
# ignored, but re-raised as RuntimeError's
func = fail_on_stopiteration(row_func)
Review comment (Member): Add a few comments for it?

Reply (Contributor Author): Clearer?

Reply (Member): I think the row_func name was fine. Let's just leave it as it was.

Reply (Contributor Author): I wanted to avoid the overhead of calling _get_argspec even when it's not needed.

Reply (Member): Ah, sure.


# the last returnType will be the return type of UDF
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
argspec = _get_argspec(row_func) # signature was lost when wrapping it
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type)
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
return arg_offsets, wrap_udf(row_func, return_type)
return arg_offsets, wrap_udf(func, return_type)
else:
raise ValueError("Unknown eval type: {}".format(eval_type))

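For context on why read_single_udf now threads an argspec into wrap_grouped_map_pandas_udf (illustrative only, not part of this diff): a GROUPED_MAP pandas UDF may declare either one parameter (the group's pandas DataFrame) or two (the grouping key plus the DataFrame), and the wrapper dispatches on that arity. A hypothetical user-side example of both shapes, assuming a DataFrame df with columns id and v:

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

# one-argument form: receives only the group's DataFrame
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

# two-argument form: also receives the grouping key as a tuple
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean_with_key(key, pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupBy("id").apply(subtract_mean).show()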