[SPARK-23754][PYTHON][FOLLOWUP] Move UDF stop iteration wrapping from driver to executor #21467
Changes from 8 commits
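Background for the change (an illustration, not part of the PR page): if a `StopIteration` raised by user code escapes into the iterator machinery that drives a task, consumers treat it as a normal end of data and silently truncate the result instead of failing. A minimal plain-Python sketch of the failure mode, using a `stopit` helper like the one in the tests below:

```python
def stopit(*x):
    raise StopIteration()  # simulates a buggy user function

# list() drives the filter iterator with next(); the StopIteration
# raised by the predicate is indistinguishable from normal exhaustion,
# so the "error" yields an empty result instead of an exception.
print(list(filter(stopit, range(10))))  # prints: []
```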
python/pyspark/tests.py

```diff
@@ -1291,27 +1291,31 @@ def test_pipe_unicode(self):
         result = rdd.pipe('cat').collect()
         self.assertEqual(data, result)
 
-    def test_stopiteration_in_client_code(self):
+    def test_stopiteration_in_user_code(self):
 
         def stopit(*x):
             raise StopIteration()
 
         seq_rdd = self.sc.parallelize(range(10))
         keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
+        msg = "Caught StopIteration thrown from user's code; failing the task"
 
-        self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
-        self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
-        self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
-        self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg,
+                                seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
 
         # the exception raised is non-deterministic
-        self.assertRaises((Py4JJavaError, RuntimeError),
-                          seq_rdd.aggregate, 0, stopit, lambda *x: 1)
-        self.assertRaises((Py4JJavaError, RuntimeError),
-                          seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                keyed_rdd.reduceByKeyLocally, stopit)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                seq_rdd.aggregate, 0, lambda *x: 1, stopit)
 
 
 class ProfilerTests(PySparkTestCase):
```
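A note on the widened expectation in the last three assertions: `reduceByKeyLocally` and `aggregate` also invoke the user function on the driver (when merging per-partition results), so the `StopIteration` can surface either as the driver-side `RuntimeError` or, when it happens in an executor, wrapped in a `Py4JJavaError`; that is what the "exception raised is non-deterministic" comment refers to.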
python/pyspark/worker.py

```diff
@@ -35,7 +35,7 @@
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
-from pyspark.util import _get_argspec
+from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
```
```diff
@@ -92,10 +92,9 @@ def verify_result_length(*a):
     return lambda *a: (verify_result_length(*a), arrow_return_type)
 
 
-def wrap_grouped_map_pandas_udf(f, return_type):
+def wrap_grouped_map_pandas_udf(f, return_type, argspec):
     def wrapped(key_series, value_series):
         import pandas as pd
-        argspec = _get_argspec(f)
 
         if len(argspec.args) == 1:
             result = f(pd.concat(value_series, axis=1))
```
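The new `argspec` parameter exists because the caller now wraps the UDF before handing it to `wrap_grouped_map_pandas_udf`, and the wrapper's `(*args, **kwargs)` signature hides whether the original function took one argument or two. A small sketch of the effect (`wrap` and `my_udf` are hypothetical stand-ins):

```python
import inspect

def wrap(f):  # hypothetical stand-in for fail_on_stopiteration
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper

def my_udf(pdf):  # a one-argument grouped-map UDF
    return pdf

print(inspect.getfullargspec(my_udf).args)        # ['pdf']
print(inspect.getfullargspec(wrap(my_udf)).args)  # [] -- the original
# signature is gone, so the wrapper can no longer be inspected to decide
# whether the UDF expects (pdf) or (key, pdf)
```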
```diff
@@ -140,15 +139,20 @@ def read_single_udf(pickleSer, infile, eval_type):
     else:
         row_func = chain(row_func, f)
 
+    # make sure StopIteration's raised in the user code are not
+    # ignored, but re-raised as RuntimeError's
+    func = fail_on_stopiteration(row_func)
+
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
+        argspec = _get_argspec(row_func)  # signature was lost when wrapping it
+        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
-        return arg_offsets, wrap_udf(row_func, return_type)
+        return arg_offsets, wrap_udf(func, return_type)
     else:
         raise ValueError("Unknown eval type: {}".format(eval_type))
```
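For reference, a minimal sketch of what `fail_on_stopiteration` from `pyspark.util` does; the error message matches the one asserted in the tests above, but treat the exact implementation as an approximation:

```python
def fail_on_stopiteration(f):
    """Wrap f so a StopIteration escaping it surfaces as a RuntimeError
    instead of silently ending the iteration that drives the task."""
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task",
                exc)
    return wrapper
```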
Review comment: "tiny nit: I would do: … or …"