Commit b5c51c8

davies authored and JoshRosen committed
[SPARK-3074] [PySpark] support groupByKey() with single huge key
This patch changes groupByKey() to use an external-sort-based approach, so it can support a single huge key. For example, it can group a dataset that includes one hot key with 40 million values (strings), using 500 MB of memory for the Python worker, and finish in about 2 minutes (the hash-based approach would need about 6 GB of memory).

During groupByKey(), it first does an in-memory groupBy. If the dataset cannot fit in memory, the data is partitioned by hash. If one partition still cannot fit in memory, it switches to a sort-based groupBy().

Author: Davies Liu <[email protected]>
Author: Davies Liu <[email protected]>

Closes #1977 from davies/groupby and squashes the following commits:

af3713a [Davies Liu] make sure it's iterator
67772dd [Davies Liu] fix tests
e78c15c [Davies Liu] address comments
0b0fde8 [Davies Liu] address comments
0dcf320 [Davies Liu] address comments, rollback changes in ResultIterable
e3b8eab [Davies Liu] fix narrow dependency
2a1857a [Davies Liu] typo
d2f053b [Davies Liu] add repr for FlattedValuesSerializer
c6a2f8d [Davies Liu] address comments
9e2df24 [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
2b9c261 [Davies Liu] fix typo in comments
70aadcd [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
a14b4bd [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
ab5515b [Davies Liu] Merge branch 'master' into groupby
651f891 [Davies Liu] simplify GroupByKey
1578f2e [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
1f69f93 [Davies Liu] fix tests
0d3395f [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
341f1e0 [Davies Liu] add comments, refactor
47918b8 [Davies Liu] remove unused code
6540948 [Davies Liu] address comments:
17f4ec6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
4d4bc86 [Davies Liu] bugfix
8ef965e [Davies Liu] Merge branch 'master' into groupby
fbc504a [Davies Liu] Merge branch 'master' into groupby
779ed03 [Davies Liu] fix merge conflict
2c1d05b [Davies Liu] refactor, minor turning
b48cda5 [Davies Liu] Merge branch 'master' into groupby
85138e6 [Davies Liu] Merge branch 'master' into groupby
acd8e1b [Davies Liu] fix memory when groupByKey().count()
905b233 [Davies Liu] Merge branch 'sort' into groupby
1f075ed [Davies Liu] Merge branch 'master' into sort
4b07d39 [Davies Liu] compress the data while spilling
0a081c6 [Davies Liu] Merge branch 'master' into groupby
f157fe7 [Davies Liu] Merge branch 'sort' into groupby
eb53ca6 [Davies Liu] Merge branch 'master' into sort
b2dc3bf [Davies Liu] Merge branch 'sort' into groupby
644abaf [Davies Liu] add license in LICENSE
19f7873 [Davies Liu] improve tests
11ba318 [Davies Liu] typo
085aef8 [Davies Liu] Merge branch 'master' into groupby
3ee58e5 [Davies Liu] switch to sort based groupBy, based on size of data
1ea0669 [Davies Liu] choose sort based groupByKey() automatically
b40bae7 [Davies Liu] bugfix
efa23df [Davies Liu] refactor, add spark.shuffle.sort=False
250be4e [Davies Liu] flatten the combined values when dumping into disks
d05060d [Davies Liu] group the same key before shuffle, reduce the comparison during sorting
083d842 [Davies Liu] sorted based groupByKey()
55602ee [Davies Liu] use external sort in sortBy() and sortByKey()
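A minimal, hypothetical usage sketch of the behaviour described above (the SparkConf values, dataset, and key skew are illustrative, not part of this patch; the config keys are the ones the patch reads):

# Hypothetical sketch: grouping a skewed dataset with the external groupByKey().
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .set("spark.python.worker.memory", "512m")  # memory limit per Python worker
        .set("spark.shuffle.spill", "true"))        # allow spilling to disk
sc = SparkContext(conf=conf)

# One hot key with many values, plus many normal keys.
pairs = sc.parallelize(range(100000)).map(lambda i: ("hot" if i % 2 == 0 else i, i))

# The grouped values come back as an iterable (ResultIterable), so consume them
# lazily instead of materializing a list when a key may be huge.
counts = pairs.groupByKey().mapValues(lambda vs: sum(1 for _ in vs))
print(counts.take(3))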
1 parent 9c67049 commit b5c51c8

File tree

6 files changed: +531 -143 lines

python/pyspark/join.py

Lines changed: 7 additions & 6 deletions

@@ -48,7 +48,7 @@ def dispatch(seq):
                 vbuf.append(v)
             elif n == 2:
                 wbuf.append(v)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)


@@ -62,7 +62,7 @@ def dispatch(seq):
                 wbuf.append(v)
         if not vbuf:
             vbuf.append(None)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)


@@ -76,7 +76,7 @@ def dispatch(seq):
                 wbuf.append(v)
         if not wbuf:
             wbuf.append(None)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)


@@ -104,8 +104,9 @@ def make_mapper(i):
     rdd_len = len(vrdds)

     def dispatch(seq):
-        bufs = [[] for i in range(rdd_len)]
-        for (n, v) in seq:
+        bufs = [[] for _ in range(rdd_len)]
+        for n, v in seq:
             bufs[n].append(v)
-        return tuple(map(ResultIterable, bufs))
+        return tuple(ResultIterable(vs) for vs in bufs)
+
     return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)
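The join.py changes above replace list comprehensions with generator expressions, so the cross product of matched values for a key is produced lazily instead of being materialized all at once. A small standalone illustration of the difference (plain Python, not PySpark code):

# Values collected for one join key from the two sides.
vbuf = [1, 2, 3]
wbuf = ["x", "y"]

pairs_list = [(v, w) for v in vbuf for w in wbuf]  # old: builds the whole list in memory
pairs_gen = ((v, w) for v in vbuf for w in wbuf)   # new: yields one pair at a time

print(pairs_list[:2])   # [(1, 'x'), (1, 'y')]
print(next(pairs_gen))  # (1, 'x')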

python/pyspark/rdd.py

Lines changed: 36 additions & 12 deletions

@@ -41,7 +41,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
-    get_used_memory, ExternalSorter
+    get_used_memory, ExternalSorter, ExternalGroupBy
 from pyspark.traceback_utils import SCCallSiteSync

 from py4j.java_collections import ListConverter, MapConverter
@@ -573,8 +573,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()

-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
-        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         serializer = self._jrdd_deserializer

         def sortPartition(iterator):
@@ -1699,10 +1699,8 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
             numPartitions = self._defaultReducePartitions()

         serializer = self.ctx.serializer
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
-                 == 'true')
-        memory = _parse_memory(self.ctx._conf.get(
-            "spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)

         def combineLocally(iterator):
@@ -1755,21 +1753,28 @@ def createZero():

         return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)

+    def _can_spill(self):
+        return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
+
+    def _memory_limit(self):
+        return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+
     # TODO: support variant with custom partitioner
     def groupByKey(self, numPartitions=None):
         """
         Group the values for each key in the RDD into a single sequence.
-        Hash-partitions the resulting RDD with into numPartitions partitions.
+        Hash-partitions the resulting RDD with numPartitions partitions.

         Note: If you are grouping in order to perform an aggregation (such as a
         sum or average) over each key, using reduceByKey or aggregateByKey will
         provide much better performance.

         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
-        >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
+        >>> sorted(x.groupByKey().mapValues(len).collect())
+        [('a', 2), ('b', 1)]
+        >>> sorted(x.groupByKey().mapValues(list).collect())
         [('a', [1, 1]), ('b', [1])]
         """
-
         def createCombiner(x):
             return [x]

@@ -1781,8 +1786,27 @@ def mergeCombiners(a, b):
             a.extend(b)
             return a

-        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
-                                 numPartitions).mapValues(lambda x: ResultIterable(x))
+        spill = self._can_spill()
+        memory = self._memory_limit()
+        serializer = self._jrdd_deserializer
+        agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
+        def combine(iterator):
+            merger = ExternalMerger(agg, memory * 0.9, serializer) \
+                if spill else InMemoryMerger(agg)
+            merger.mergeValues(iterator)
+            return merger.iteritems()
+
+        locally_combined = self.mapPartitions(combine, preservesPartitioning=True)
+        shuffled = locally_combined.partitionBy(numPartitions)
+
+        def groupByKey(it):
+            merger = ExternalGroupBy(agg, memory, serializer)\
+                if spill else InMemoryMerger(agg)
+            merger.mergeCombiners(it)
+            return merger.iteritems()
+
+        return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)

     def flatMapValues(self, f):
         """

python/pyspark/resultiterable.py

Lines changed: 4 additions & 3 deletions

@@ -15,15 +15,16 @@
 # limitations under the License.
 #

-__all__ = ["ResultIterable"]
-
 import collections

+__all__ = ["ResultIterable"]
+

 class ResultIterable(collections.Iterable):

     """
-    A special result iterable. This is used because the standard iterator can not be pickled
+    A special result iterable. This is used because the standard
+    iterator can not be pickled
     """

     def __init__(self, data):
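For context, ResultIterable wraps the already-grouped values so the result of groupByKey() can be pickled and iterated more than once, which a plain iterator cannot be. A rough sketch of that idea (ReplayableIterable is a made-up name, not the PySpark class):

class ReplayableIterable(object):
    """Wraps concrete data so it can be pickled and re-iterated."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)

    def __len__(self):
        return len(self.data)

vals = ReplayableIterable([1, 1])
print(list(vals), list(vals))  # can be consumed twice: [1, 1] [1, 1]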

python/pyspark/serializers.py

Lines changed: 24 additions & 1 deletion

@@ -220,6 +220,29 @@ def __repr__(self):
         return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)


+class FlattenedValuesSerializer(BatchedSerializer):
+
+    """
+    Serializes a stream of list of pairs, split the list of values
+    which contain more than a certain number of objects to make them
+    have similar sizes.
+    """
+    def __init__(self, serializer, batchSize=10):
+        BatchedSerializer.__init__(self, serializer, batchSize)
+
+    def _batched(self, iterator):
+        n = self.batchSize
+        for key, values in iterator:
+            for i in xrange(0, len(values), n):
+                yield key, values[i:i + n]
+
+    def load_stream(self, stream):
+        return self.serializer.load_stream(stream)
+
+    def __repr__(self):
+        return "FlattenedValuesSerializer(%d)" % self.batchSize
+
+
 class AutoBatchedSerializer(BatchedSerializer):
     """
     Choose the size of batch automatically based on the size of object
@@ -251,7 +274,7 @@ def __eq__(self, other):
         return (isinstance(other, AutoBatchedSerializer) and
                 other.serializer == self.serializer and other.bestSize == self.bestSize)

-    def __str__(self):
+    def __repr__(self):
         return "AutoBatchedSerializer(%s)" % str(self.serializer)

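FlattenedValuesSerializer above splits each key's value list into fixed-size chunks before writing, so one huge value list is never serialized as a single object. A standalone sketch of that batching step (flatten_values is a made-up helper, using Python 3 range in place of xrange):

def flatten_values(iterator, batch_size=10):
    # Split (key, values) pairs into (key, chunk) pairs of at most batch_size values.
    for key, values in iterator:
        for i in range(0, len(values), batch_size):
            yield key, values[i:i + batch_size]

stream = [("hot", list(range(25))), ("cold", [42])]
for key, chunk in flatten_values(stream):
    print(key, len(chunk))
# hot 10 / hot 10 / hot 5 / cold 1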
