
Commit 6540948

Address review comments:

1. make SameKey iterable multiple times
2. make SameKey picklable
3. more tests
4. mapPartitions() with preservesPartitioning=True
1 parent 17f4ec6 commit 6540948
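
For context (not part of the diff below): the grouped values used to be exposed through plain iterators, and a Python generator is both single-pass and unpicklable, which is what points 1 and 2 address. A minimal illustration of those two limitations:

    import pickle

    gen = (x * 10 for x in range(3))
    print(list(gen))      # [0, 10, 20]
    print(list(gen))      # []  -- a generator is exhausted after one pass
    # pickle.dumps(gen)   # would raise TypeError: generator objects cannot be pickled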

5 files changed

Lines changed: 121 additions & 54 deletions


python/pyspark/join.py

Lines changed: 7 additions & 6 deletions
@@ -48,7 +48,7 @@ def dispatch(seq):
                 vbuf.append(v)
             elif n == 2:
                 wbuf.append(v)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)


@@ -62,7 +62,7 @@ def dispatch(seq):
                 wbuf.append(v)
         if not vbuf:
             vbuf.append(None)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)


@@ -76,7 +76,7 @@ def dispatch(seq):
                 wbuf.append(v)
         if not wbuf:
             wbuf.append(None)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)


@@ -88,8 +88,9 @@ def make_mapper(i):
     rdd_len = len(vrdds)

     def dispatch(seq):
-        bufs = [[] for i in range(rdd_len)]
-        for (n, v) in seq:
+        bufs = [[] for _ in range(rdd_len)]
+        for n, v in seq:
             bufs[n].append(v)
-        return tuple(map(ResultIterable, bufs))
+        return tuple(ResultIterable(vs) for vs in bufs)
+
     return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)

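The join.py hunks swap the list comprehensions in each dispatch() for generator expressions, so the per-key cross product of vbuf and wbuf is produced lazily instead of being materialized as one list. A rough sketch of the inner-join dispatch with toy data (the tagged input is hypothetical, not from the commit):

    def dispatch(seq):
        # seq holds values tagged with the side they came from: 1 = left RDD, 2 = right RDD
        vbuf, wbuf = [], []
        for (n, v) in seq:
            if n == 1:
                vbuf.append(v)
            elif n == 2:
                wbuf.append(v)
        return ((v, w) for v in vbuf for w in wbuf)   # lazy cross product

    pairs = dispatch([(1, 'a'), (1, 'b'), (2, 'x')])
    print(list(pairs))   # [('a', 'x'), ('b', 'x')]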
python/pyspark/rdd.py

Lines changed: 1 addition & 1 deletion
@@ -1619,7 +1619,7 @@ def groupByKey(it):
             merger.mergeCombiners(it)
             return merger.iteritems()

-        return shuffled.mapPartitions(groupByKey).mapValues(ResultIterable)
+        return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)

     def flatMapValues(self, f):
         """

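In rdd.py, the second positional argument to mapPartitions() is preservesPartitioning. Passing True declares that the per-partition merge in groupByKey does not move records between keys, so the partitioning established by the preceding shuffle is kept and downstream stages that rely on it need not reshuffle. A small usage sketch (assumes a running SparkContext sc):

    pairs = sc.parallelize([(1, 'a'), (2, 'b'), (1, 'c')]).partitionBy(2)

    # identity pass over each partition; True = the partitioning is declared unchanged
    same = pairs.mapPartitions(lambda it: it, True)
    print(same.glom().collect())   # records stay in the partitions chosen by partitionBy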
python/pyspark/resultiterable.py

Lines changed: 2 additions & 4 deletions
@@ -23,7 +23,8 @@
 class ResultIterable(object):

     """
-    A special result iterable. This is used because the standard iterator can not be pickled
+    A special result iterable. This is used because the standard
+    iterator can not be pickled
     """

     def __init__(self, it):
@@ -37,6 +38,3 @@ def __len__(self):
             return len(self.it)
         except TypeError:
             return sum(1 for _ in self.it)
-
-    def __reduce__(self):
-        return (ResultIterable, (list(self.it),))

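With __reduce__ removed, ResultIterable pickles through the default protocol (its instance dict, i.e. the wrapped self.it) instead of copying every value into a list at serialization time; for a huge group that copy would defeat the external, spill-to-disk grouping this change set is about. It does rely on the wrapped object being picklable itself, which is what the SameKey/ChainedIterable work in shuffle.py provides. A toy contrast of the two strategies (hypothetical classes, not the PySpark ones):

    import pickle

    class Eager(object):                      # old approach: __reduce__ copies everything
        def __init__(self, it):
            self.it = it

        def __reduce__(self):
            return (Eager, (list(self.it),))  # materializes the whole group at pickle time

    class Lazy(object):                       # new approach: default pickling of self.it
        def __init__(self, it):
            self.it = it

    vs = range(5)
    print(pickle.loads(pickle.dumps(Eager(vs))).it)  # always a plain list of the values
    print(pickle.loads(pickle.dumps(Lazy(vs))).it)   # whatever picklable object was wrapped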
python/pyspark/shuffle.py

Lines changed: 68 additions & 42 deletions
@@ -537,59 +537,73 @@ def __init__(self, key, value, iterator, groupBy):
         self.groupBy = groupBy
         self._file = None
         self._ser = None
-        self._index = None

-    def __iter__(self):
-        return self
-
-    def next(self):
-        if self._index is None:
-            # begin of iterator
-            if self._file is not None:
-                if self.values:
-                    self._spill()
-                self._file.flush()
-                self._file.seek(0)
-            self._index = 0
-
-        if self._index >= len(self.values) and self._file is not None:
-            # load next chunk of values from disk
-            self.values = next(self._ser.load_stream(self._file))
-            self._index = 0
-
-        if self._index < len(self.values):
-            value = self.values[self._index]
-            self._index += 1
-            return value
+    def __getstate__(self):
+        sum(1 for _ in self)  # try to read all the values
+        if self._file is not None:
+            f = os.fdopen(os.dup(self._file.fileno()))
+            f.seek(0)
+            bytes = f.read()
+        else:
+            bytes = ''
+        return (self.key, bytes, self.values)
+
+    def __setstate__(self, item):
+        self.key, bytes, self.values = item
+        self.iterator = iter([])
+        self.groupBy = None
+        if bytes:
+            self._open_file()
+            self._file.write(bytes)
+        else:
+            self._file = None
+            self._ser = None

-        key, value = next(self.iterator)
-        if key == self.key:
-            return value
+    def __iter__(self):
+        if self._file is not None:
+            self._file.flush()
+            with os.fdopen(os.dup(self._file.fileno()), 'r', 65536) as f:
+                f.seek(0)
+                for values in self._ser.load_stream(f):
+                    for v in values:
+                        yield v
+
+        for v in self.values:
+            yield v
+
+        if self.groupBy and self.groupBy.next_item is None:
+            for key, value in self.iterator:
+                if key == self.key:
+                    self.append(value)  # save it for next read
+                    yield value
+                else:
+                    self.groupBy.next_item = (key, value)
+                    break

-        # push them back into groupBy
-        self.groupBy.next_item = (key, value)
-        raise StopIteration
+    def __len__(self):
+        return sum(1 for _ in self)

     def append(self, value):
-        if self._index is not None:
-            raise ValueError("Can not append value while iterating")
-
         self.values.append(value)
         # dump them into disk if the key is huge
         if len(self.values) >= 10240:
             self._spill()

+    def _open_file(self):
+        dirs = _get_local_dirs("objects")
+        d = dirs[id(self) % len(dirs)]
+        if not os.path.exists(d):
+            os.makedirs(d)
+        p = os.path.join(d, str(id))
+        self._file = open(p, "w+", 65536)
+        self._ser = CompressedSerializer(PickleSerializer())
+        os.unlink(p)
+
     def _spill(self):
         """ dump the values into disk """
         global MemoryBytesSpilled, DiskBytesSpilled
         if self._file is None:
-            dirs = _get_local_dirs("objects")
-            d = dirs[id(self) % len(dirs)]
-            if not os.path.exists(d):
-                os.makedirs(d)
-            p = os.path.join(d, str(id))
-            self._file = open(p, "w+", 65536)
-            self._ser = CompressedSerializer(PickleSerializer())
+            self._open_file()

         used_memory = get_used_memory()
         pos = self._file.tell()
@@ -600,6 +614,19 @@ def _spill(self):
         MemoryBytesSpilled += (used_memory - get_used_memory()) << 20


+class ChainedIterable(object):
+    """
+    Pickable chained iterator
+    """
+    def __init__(self, iterators):
+        self.iterators = iterators
+
+    def __iter__(self):
+        for vs in self.iterators:
+            for v in vs:
+                yield v
+
+
 class GroupByKey(object):
     """
     group a sorted iterator into [(k1, it1), (k2, it2), ...]
@@ -719,7 +746,7 @@ def _merged_items(self, index, limit=0):
         # if the memory can not hold all the partition,
         # then use sort based merge. Because of compression,
         # the data on disks will be much smaller than needed memory
-        if (size >> 20) > self.memory_limit / 10:
+        if (size >> 20) >= self.memory_limit / 10:
             return self._sorted_items(index)

         self.data = {}
@@ -750,8 +777,7 @@ def load_partition(j):
         sorter = ExternalSorter(self.memory_limit, ser)
         sorted_items = sorter.sorted(itertools.chain(*disk_items),
                                      key=operator.itemgetter(0))
-
-        return ((k, itertools.chain.from_iterable(vs)) for k, vs in GroupByKey(sorted_items))
+        return ((k, ChainedIterable(vs)) for k, vs in GroupByKey(sorted_items))


 if __name__ == "__main__":
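
SameKey's new __getstate__/__setstate__ pair is what makes point 2 of the commit message work: before pickling, the remaining values are drained into the object (and possibly its spill file), the spill file's contents are read back through a duplicated file descriptor, and the raw bytes travel inside the pickle; on load, the bytes are written into a fresh temporary file. A stripped-down sketch of that pattern (toy class, not the SameKey code):

    import pickle
    import tempfile

    class SpilledValues(object):
        """Toy illustration: an object owning an open temp file becomes picklable
        by shipping the file's contents as bytes and rebuilding the file on load."""

        def __init__(self):
            self._file = tempfile.TemporaryFile()

        def append(self, data):
            self._file.write(data)

        def dump(self):
            self._file.flush()
            self._file.seek(0)
            return self._file.read()

        def __getstate__(self):
            return self.dump()          # the pickled state is just the file's bytes

        def __setstate__(self, data):
            self._file = tempfile.TemporaryFile()
            self._file.write(data)

    v = SpilledValues()
    v.append(b"spilled bytes")
    copy = pickle.loads(pickle.dumps(v))
    print(copy.dump())                  # b'spilled bytes'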

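The last shuffle.py hunk replaces itertools.chain.from_iterable(vs), which can only be consumed once, with the ChainedIterable defined above: because __iter__ builds a fresh generator on every call, the grouped values can be walked repeatedly and pickled through the plain instance dict, provided the chunks in self.iterators are themselves re-iterable and picklable. A quick usage sketch with plain lists standing in for the per-chunk values:

    import itertools
    import pickle

    chunks = [[1, 2], [3, 4]]

    once = itertools.chain.from_iterable(chunks)
    print(list(once))                 # [1, 2, 3, 4]
    print(list(once))                 # []  -- drained after the first pass

    # Same shape as the ChainedIterable added in this commit:
    class ChainedIterable(object):
        def __init__(self, iterators):
            self.iterators = iterators

        def __iter__(self):
            for vs in self.iterators:
                for v in vs:
                    yield v

    ci = ChainedIterable(chunks)
    print(list(ci))                              # [1, 2, 3, 4]
    print(list(ci))                              # [1, 2, 3, 4] -- iterable multiple times
    ci2 = pickle.loads(pickle.dumps(ci))         # picklable: state is just the list of chunks
    print(list(ci2))                             # [1, 2, 3, 4]
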
python/pyspark/tests.py

Lines changed: 43 additions & 1 deletion
@@ -31,7 +31,7 @@
 import time
 import zipfile
 import random
-from platform import python_implementation
+import itertools

 if sys.version_info[:2] <= (2, 6):
     import unittest2 as unittest
@@ -122,6 +122,35 @@ def test_huge_dataset(self):
                          self.N * 10)
         m._cleanup()

+    def test_group_by_key(self):
+
+        def gen_data(N, step):
+            for i in range(1, N + 1, step):
+                for j in range(i * 10):
+                    yield (i, j)
+
+        def gen_gs(N, step=1):
+            return shuffle.GroupByKey(gen_data(N, step))
+
+        self.assertEqual(1, len(list(gen_gs(1))))
+        self.assertEqual(2, len(list(gen_gs(2))))
+        self.assertEqual(100, len(list(gen_gs(100))))
+        self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)])
+        self.assertTrue(all(k * 10 == len(list(vs)) for k, vs in gen_gs(100)))
+
+        for k, vs in gen_gs(5002, 100):
+            if k % 1000 == 1:
+                self.assertEqual(range(k), list(itertools.islice(vs, k)))
+                self.assertEqual(k * 10, sum(1 for _ in vs))
+                self.assertEqual(range(k * 9, k * 10), list(itertools.islice(vs, k * 9, k * 10)))
+                self.assertEqual(k * 10, sum(1 for _ in vs))
+
+        ser = PickleSerializer()
+        l = ser.loads(ser.dumps(list(gen_gs(5002, 1000))))
+        for k, vs in l:
+            self.assertEqual(k * 10, len(vs))
+            self.assertEqual(range(k * 10), list(vs))
+

 class TestSorter(unittest.TestCase):
     def test_in_memory_sort(self):
@@ -595,6 +624,19 @@ def test_distinct(self):
         self.assertEquals(result.getNumPartitions(), 5)
         self.assertEquals(result.count(), 3)

+    def test_external_group_by_key(self):
+        self.sc._conf.set("spark.python.worker.memory", "5m")
+        N = 200001
+        kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
+        gkv = kv.groupByKey().cache()
+        self.assertEqual(3, gkv.count())
+        filtered = gkv.filter(lambda (k, vs): k == 1)
+        self.assertEqual(1, filtered.count())
+        self.assertEqual([(1, N/3)], filtered.mapValues(len).collect())
+        result = filtered.collect()[0][1]
+        self.assertEqual(N/3, len(result))
+        self.assertTrue(isinstance(result.it, shuffle.ChainedIterable))
+

 class TestSQL(PySparkTestCase):
