
Commit f755a48

re-raising StopIteration in user code
1 parent 8bfdb68 commit f755a48

6 files changed

Lines changed: 136 additions & 13 deletions

File tree

python/pyspark/rdd.py
python/pyspark/shuffle.py
python/pyspark/sql/tests.py
python/pyspark/tests.py
python/pyspark/util.py
python/pyspark/worker.py

python/pyspark/rdd.py

Lines changed: 12 additions & 3 deletions
@@ -53,6 +53,7 @@
 from pyspark.shuffle import Aggregator, ExternalMerger, \
     get_used_memory, ExternalSorter, ExternalGroupBy
 from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.util import fail_on_stopiteration
 
 
 __all__ = ["RDD"]
@@ -339,7 +340,7 @@ def map(self, f, preservesPartitioning=False):
         [('a', 1), ('b', 1), ('c', 1)]
         """
         def func(_, iterator):
-            return map(f, iterator)
+            return map(fail_on_stopiteration(f), iterator)
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
     def flatMap(self, f, preservesPartitioning=False):
@@ -354,7 +355,7 @@ def flatMap(self, f, preservesPartitioning=False):
         [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
         """
         def func(s, iterator):
-            return chain.from_iterable(map(f, iterator))
+            return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
     def mapPartitions(self, f, preservesPartitioning=False):
@@ -417,7 +418,7 @@ def filter(self, f):
         [2, 4]
         """
         def func(iterator):
-            return filter(f, iterator)
+            return filter(fail_on_stopiteration(f), iterator)
         return self.mapPartitions(func, True)
 
     def distinct(self, numPartitions=None):
@@ -847,6 +848,8 @@ def reduce(self, f):
         ...
         ValueError: Can not reduce() empty RDD
         """
+        f = fail_on_stopiteration(f)
+
         def func(iterator):
             iterator = iter(iterator)
             try:
@@ -918,6 +921,8 @@ def fold(self, zeroValue, op):
         >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
         15
         """
+        op = fail_on_stopiteration(op)
+
         def func(iterator):
             acc = zeroValue
             for obj in iterator:
@@ -950,6 +955,9 @@ def aggregate(self, zeroValue, seqOp, combOp):
         >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
         (0, 0)
         """
+        seqOp = fail_on_stopiteration(seqOp)
+        combOp = fail_on_stopiteration(combOp)
+
         def func(iterator):
             acc = zeroValue
             for obj in iterator:
@@ -1628,6 +1636,7 @@ def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash):
         >>> sorted(rdd.reduceByKey(add).collect())
         [('a', 2), ('b', 1)]
         """
+        func = fail_on_stopiteration(func)
         return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc)
 
     def reduceByKeyLocally(self, func):
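
For context (an illustrative sketch, not part of the commit): map and filter evaluate the user function lazily inside an iterator, and in CPython a StopIteration leaking out of that function is indistinguishable from normal iterator exhaustion, so downstream consumers simply stop early and rows are silently dropped:

    def f(x):
        if x == 3:
            raise StopIteration()  # user bug leaking StopIteration
        return x * 2

    # list() treats the leaked StopIteration as "iterator exhausted":
    list(map(f, range(10)))  # [0, 2, 4] -- elements 3..9 are silently lost

    # with the wrapper added by this commit, the bug surfaces as an error:
    from pyspark.util import fail_on_stopiteration
    list(map(fail_on_stopiteration(f), range(10)))  # raises RuntimeError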

python/pyspark/shuffle.py

Lines changed: 4 additions & 3 deletions
@@ -28,6 +28,7 @@
 import pyspark.heapq3 as heapq
 from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
     CompressedSerializer, AutoBatchedSerializer
+from pyspark.util import fail_on_stopiteration
 
 
 try:
@@ -94,9 +95,9 @@ class Aggregator(object):
     """
 
     def __init__(self, createCombiner, mergeValue, mergeCombiners):
-        self.createCombiner = createCombiner
-        self.mergeValue = mergeValue
-        self.mergeCombiners = mergeCombiners
+        self.createCombiner = fail_on_stopiteration(createCombiner)
+        self.mergeValue = fail_on_stopiteration(mergeValue)
+        self.mergeCombiners = fail_on_stopiteration(mergeCombiners)
 
 
 class SimpleAggregator(Aggregator):
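
Wrapping the three callbacks once in Aggregator.__init__ protects every merge path that later invokes them, in-memory merging and external spills alike. A minimal sketch of the effect, assuming a user callback that leaks StopIteration:

    from pyspark.shuffle import Aggregator

    def bad_merge_value(acc, v):
        raise StopIteration()  # user bug

    agg = Aggregator(lambda v: [v], bad_merge_value, lambda a, b: a + b)
    # agg.mergeValue is now the wrapped callable, so the bug surfaces as
    # a RuntimeError instead of prematurely ending a merge loop:
    agg.mergeValue([1], 2)  # RuntimeError: Caught StopIteration thrown from user's code...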

python/pyspark/sql/tests.py

Lines changed: 39 additions & 0 deletions
@@ -900,6 +900,45 @@ def __call__(self, x):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
+    def test_stopiteration_in_udf(self):
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+        from py4j.protocol import Py4JJavaError
+
+        def do_test(action, *args, **kwargs):
+            exc_message = "Caught StopIteration thrown from user's code; failing the task"
+            with self.assertRaisesRegexp(Py4JJavaError, exc_message):
+                action(*args, **kwargs)
+
+        def foo(x):
+            raise StopIteration()
+
+        def foofoo(x, y):
+            raise StopIteration()
+
+        df = self.spark.range(0, 100)
+
+        # plain udf (test for SPARK-23754)
+        do_test(df.withColumn('v', udf(foo)('id')).show)
+
+        # pandas scalar udf
+        do_test(df.withColumn(
+            'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
+        ).show)
+
+        # pandas grouped map
+        do_test(df.groupBy('id').apply(
+            pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
+        ).show)
+
+        do_test(df.groupBy('id').apply(
+            pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP)
+        ).show)
+
+        # pandas grouped agg
+        do_test(df.groupBy('id').agg(
+            pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id')
+        ).show)
+
     def test_validate_column_types(self):
         from pyspark.sql.functions import udf, to_json
         from pyspark.sql.column import _to_java_column

python/pyspark/tests.py

Lines changed: 53 additions & 0 deletions
@@ -161,6 +161,37 @@ def gen_gs(N, step=1):
         self.assertEqual(k, len(vs))
         self.assertEqual(list(range(k)), list(vs))
 
+    def test_stopiteration_is_raised(self):
+
+        def stopit(*args, **kwargs):
+            raise StopIteration()
+
+        def legit_create_combiner(x):
+            return [x]
+
+        def legit_merge_value(x, y):
+            return x.append(y) or x
+
+        def legit_merge_combiners(x, y):
+            return x.extend(y) or x
+
+        data = [(x % 2, x) for x in range(100)]
+
+        # wrong create combiner
+        m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)):
+            m.mergeValues(data)
+
+        # wrong merge value
+        m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)):
+            m.mergeValues(data)
+
+        # wrong merge combiners
+        m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)):
+            m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
+
 
 class SorterTests(unittest.TestCase):
     def test_in_memory_sort(self):
@@ -1260,6 +1291,28 @@ def test_pipe_unicode(self):
         result = rdd.pipe('cat').collect()
         self.assertEqual(data, result)
 
+    def test_stopiteration_in_user_code(self):
+
+        def stopit(*x):
+            raise StopIteration()
+
+        seq_rdd = self.sc.parallelize(range(10))
+        keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
+
+        self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
+        self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
+        self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
+        self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
+        self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
+        self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
+        self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)
+
+        # the exception raised is non-deterministic
+        self.assertRaises((Py4JJavaError, RuntimeError),
+                          seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+        self.assertRaises((Py4JJavaError, RuntimeError),
+                          seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+
 
 class ProfilerTests(PySparkTestCase):

python/pyspark/util.py

Lines changed: 17 additions & 0 deletions
@@ -89,6 +89,23 @@ def majorMinorVersion(sparkVersion):
                          " version numbers.")
 
 
+def fail_on_stopiteration(f):
+    """
+    Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError',
+    which prevents silent loss of data when 'f' is used in a for loop.
+    """
+    def wrapper(*args, **kwargs):
+        try:
+            return f(*args, **kwargs)
+        except StopIteration as exc:
+            raise RuntimeError(
+                "Caught StopIteration thrown from user's code; failing the task",
+                exc
+            )
+
+    return wrapper
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()
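
The two-argument RuntimeError keeps the original exception attached, so the root cause stays visible when the task failure is reported. A quick sketch of the wrapper in isolation (it also works as a decorator):

    from pyspark.util import fail_on_stopiteration

    @fail_on_stopiteration
    def buggy(key):
        raise StopIteration()

    try:
        buggy('a')
    except RuntimeError as e:
        print(e.args[0])        # Caught StopIteration thrown from user's code; failing the task
        print(repr(e.args[1]))  # StopIteration()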

python/pyspark/worker.py

Lines changed: 11 additions & 7 deletions
@@ -35,7 +35,7 @@
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
-from pyspark.util import _get_argspec
+from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -92,10 +92,9 @@ def verify_result_length(*a):
     return lambda *a: (verify_result_length(*a), arrow_return_type)
 
 
-def wrap_grouped_map_pandas_udf(f, return_type):
+def wrap_grouped_map_pandas_udf(f, return_type, argspec):
     def wrapped(key_series, value_series):
         import pandas as pd
-        argspec = _get_argspec(f)
 
         if len(argspec.args) == 1:
             result = f(pd.concat(value_series, axis=1))
@@ -140,15 +139,20 @@ def read_single_udf(pickleSer, infile, eval_type):
     else:
         row_func = chain(row_func, f)
 
+    # make sure StopIteration exceptions raised in the user code are not
+    # ignored, but are re-raised as RuntimeErrors
+    func = fail_on_stopiteration(row_func)
+
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
+        argspec = _get_argspec(row_func)  # signature was lost when wrapping it
+        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
-        return arg_offsets, wrap_udf(row_func, return_type)
+        return arg_offsets, wrap_udf(func, return_type)
     else:
         raise ValueError("Unknown eval type: {}".format(eval_type))
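
The argspec detour exists because wrapping erases the user function's signature: wrap_grouped_map_pandas_udf dispatches on whether the UDF takes one argument (pdf) or two (key, pdf), and inspecting the wrapper would no longer reveal that. An illustrative sketch, using inspect.getfullargspec as a stand-in for the internal _get_argspec helper:

    import inspect
    from pyspark.util import fail_on_stopiteration

    def grouped_udf(key, pdf):  # a two-argument grouped map UDF
        return pdf

    wrapped = fail_on_stopiteration(grouped_udf)
    inspect.getfullargspec(grouped_udf).args  # ['key', 'pdf']
    inspect.getfullargspec(wrapped).args      # [] -- the wrapper is (*args, **kwargs)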
