5353from pyspark .shuffle import Aggregator , ExternalMerger , \
5454 get_used_memory , ExternalSorter , ExternalGroupBy
5555from pyspark .traceback_utils import SCCallSiteSync
56+ from pyspark .util import fail_on_stopiteration
5657
5758
5859__all__ = ["RDD" ]
@@ -339,7 +340,7 @@ def map(self, f, preservesPartitioning=False):
339340 [('a', 1), ('b', 1), ('c', 1)]
340341 """
341342 def func (_ , iterator ):
342- return map (f , iterator )
343+ return map (fail_on_stopiteration ( f ) , iterator )
343344 return self .mapPartitionsWithIndex (func , preservesPartitioning )
344345
345346 def flatMap (self , f , preservesPartitioning = False ):
@@ -354,7 +355,7 @@ def flatMap(self, f, preservesPartitioning=False):
354355 [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
355356 """
356357 def func (s , iterator ):
357- return chain .from_iterable (map (f , iterator ))
358+ return chain .from_iterable (map (fail_on_stopiteration ( f ) , iterator ))
358359 return self .mapPartitionsWithIndex (func , preservesPartitioning )
359360
360361 def mapPartitions (self , f , preservesPartitioning = False ):
@@ -417,7 +418,7 @@ def filter(self, f):
417418 [2, 4]
418419 """
419420 def func (iterator ):
420- return filter (f , iterator )
421+ return filter (fail_on_stopiteration ( f ) , iterator )
421422 return self .mapPartitions (func , True )
422423
423424 def distinct (self , numPartitions = None ):
@@ -847,6 +848,8 @@ def reduce(self, f):
847848 ...
848849 ValueError: Can not reduce() empty RDD
849850 """
851+ f = fail_on_stopiteration (f )
852+
850853 def func (iterator ):
851854 iterator = iter (iterator )
852855 try :
@@ -918,6 +921,8 @@ def fold(self, zeroValue, op):
918921 >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
919922 15
920923 """
924+ op = fail_on_stopiteration (op )
925+
921926 def func (iterator ):
922927 acc = zeroValue
923928 for obj in iterator :
@@ -950,6 +955,9 @@ def aggregate(self, zeroValue, seqOp, combOp):
950955 >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
951956 (0, 0)
952957 """
958+ seqOp = fail_on_stopiteration (seqOp )
959+ combOp = fail_on_stopiteration (combOp )
960+
953961 def func (iterator ):
954962 acc = zeroValue
955963 for obj in iterator :
@@ -1628,6 +1636,7 @@ def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash):
16281636 >>> sorted(rdd.reduceByKey(add).collect())
16291637 [('a', 2), ('b', 1)]
16301638 """
1639+ func = fail_on_stopiteration (func )
16311640 return self .combineByKey (lambda x : x , func , func , numPartitions , partitionFunc )
16321641
16331642 def reduceByKeyLocally (self , func ):
0 commit comments