@@ -537,59 +537,73 @@ def __init__(self, key, value, iterator, groupBy):
537537 self .groupBy = groupBy
538538 self ._file = None
539539 self ._ser = None
540- self ._index = None
541540
542- def __iter__ (self ):
543- return self
544-
545- def next (self ):
546- if self ._index is None :
547- # begin of iterator
548- if self ._file is not None :
549- if self .values :
550- self ._spill ()
551- self ._file .flush ()
552- self ._file .seek (0 )
553- self ._index = 0
554-
555- if self ._index >= len (self .values ) and self ._file is not None :
556- # load next chunk of values from disk
557- self .values = next (self ._ser .load_stream (self ._file ))
558- self ._index = 0
559-
560- if self ._index < len (self .values ):
561- value = self .values [self ._index ]
562- self ._index += 1
563- return value
def __getstate__(self):
    """Pickle support.

    Drains our own iterator first (``sum(1 for _ in self)``) so every
    value for this key is materialized into ``self.values`` and/or the
    spill file before we snapshot the state.

    Returns a ``(key, serialized, values)`` tuple where ``serialized``
    is the raw content of the spill file ('' when nothing was spilled).
    """
    sum(1 for _ in self)  # try to read all the values
    if self._file is not None:
        # Read through a dup'ed fd so the live file object's position
        # is untouched; close the duplicate when done (the original
        # code leaked this file descriptor).
        f = os.fdopen(os.dup(self._file.fileno()))
        try:
            f.seek(0)
            serialized = f.read()
        finally:
            f.close()
    else:
        serialized = ''
    return (self.key, serialized, self.values)
def __setstate__(self, item):
    """Unpickle support: restore from the tuple built by __getstate__.

    The source iterator and groupBy links cannot survive pickling, so
    they are reset to an empty iterator / None.  Any serialized spill
    data is written back into a fresh (anonymous) spill file.
    """
    # NOTE: renamed from `bytes` to avoid shadowing the builtin.
    self.key, serialized, self.values = item
    self.iterator = iter([])
    self.groupBy = None
    if serialized:
        self._open_file()
        self._file.write(serialized)
    else:
        self._file = None
        self._ser = None
564561
565- key , value = next (self .iterator )
566- if key == self .key :
567- return value
def __iter__(self):
    # Yield every value for self.key in three stages:
    #  1) values previously spilled to disk,
    #  2) values still held in memory (self.values),
    #  3) values not yet pulled from the shared sorted iterator.
    if self._file is not None:
        self._file.flush()
        # Read through a dup'ed fd so the write position of self._file
        # is left untouched for future spills.
        with os.fdopen(os.dup(self._file.fileno()), 'r', 65536) as f:
            f.seek(0)
            for values in self._ser.load_stream(f):
                for v in values:
                    yield v

    for v in self.values:
        yield v

    # Only pull from the shared iterator when no pending item was
    # pushed back by a previous group (next_item is a one-slot pushback).
    if self.groupBy and self.groupBy.next_item is None:
        for key, value in self.iterator:
            if key == self.key:
                self.append(value)  # save it for next read
                yield value
            else:
                # First item of the next group: hand it back to groupBy.
                self.groupBy.next_item = (key, value)
                break
568582
569- # push them back into groupBy
570- self .groupBy .next_item = (key , value )
571- raise StopIteration
def __len__(self):
    """Number of values for this key (costs a full pass over self)."""
    count = 0
    for _ in self:
        count += 1
    return count
572585
def append(self, value):
    """Add one value for this key, spilling the in-memory batch to
    disk once it grows past the threshold."""
    vals = self.values
    vals.append(value)
    # dump them into disk if the key is huge
    if len(vals) >= 10240:
        self._spill()
581591
def _open_file(self):
    """Create the anonymous spill file and its serializer.

    The path is unlinked right after opening, so the OS reclaims the
    space automatically once the file object (or the process) goes
    away, and the file never appears in directory listings.
    """
    dirs = _get_local_dirs("objects")
    d = dirs[id(self) % len(dirs)]
    if not os.path.exists(d):
        os.makedirs(d)
    # BUG FIX: the original used str(id) -- stringifying the *builtin
    # function* id, not a per-instance value -- so every instance in the
    # same dir shared one file name and "w+" could truncate another
    # instance's spill file before the unlink below.  Use id(self).
    p = os.path.join(d, str(id(self)))
    self._file = open(p, "w+", 65536)
    self._ser = CompressedSerializer(PickleSerializer())
    os.unlink(p)
582602 def _spill (self ):
583603 """ dump the values into disk """
584604 global MemoryBytesSpilled , DiskBytesSpilled
585605 if self ._file is None :
586- dirs = _get_local_dirs ("objects" )
587- d = dirs [id (self ) % len (dirs )]
588- if not os .path .exists (d ):
589- os .makedirs (d )
590- p = os .path .join (d , str (id ))
591- self ._file = open (p , "w+" , 65536 )
592- self ._ser = CompressedSerializer (PickleSerializer ())
606+ self ._open_file ()
593607
594608 used_memory = get_used_memory ()
595609 pos = self ._file .tell ()
@@ -600,6 +614,19 @@ def _spill(self):
600614 MemoryBytesSpilled += (used_memory - get_used_memory ()) << 20
601615
602616
class ChainedIterable(object):
    """
    A picklable iterable that lazily chains several iterables together.
    """
    def __init__(self, iterators):
        self.iterators = iterators

    def __iter__(self):
        for iterable in self.iterators:
            for item in iterable:
                yield item
629+
603630class GroupByKey (object ):
604631 """
605632 group a sorted iterator into [(k1, it1), (k2, it2), ...]
@@ -719,7 +746,7 @@ def _merged_items(self, index, limit=0):
719746 # if the memory can not hold all the partition,
720747 # then use sort based merge. Because of compression,
721748 # the data on disks will be much smaller than needed memory
722- if (size >> 20 ) > self .memory_limit / 10 :
749+ if (size >> 20 ) >= self .memory_limit / 10 :
723750 return self ._sorted_items (index )
724751
725752 self .data = {}
@@ -750,8 +777,7 @@ def load_partition(j):
750777 sorter = ExternalSorter (self .memory_limit , ser )
751778 sorted_items = sorter .sorted (itertools .chain (* disk_items ),
752779 key = operator .itemgetter (0 ))
753-
754- return ((k , itertools .chain .from_iterable (vs )) for k , vs in GroupByKey (sorted_items ))
780+ return ((k , ChainedIterable (vs )) for k , vs in GroupByKey (sorted_items ))
755781
756782
757783if __name__ == "__main__" :
0 commit comments