
Commit 434bea1

davies authored and mateiz committed
[SPARK-2983] [PySpark] improve performance of sortByKey()
1. Skip partitionBy() when numPartitions is 1.
2. Use bisect_left (O(lg(N))) instead of a loop (O(N)) in rangePartitionFunc.

Author: Davies Liu <[email protected]>

Closes #1898 from davies/sort and squashes the following commits:

0a9608b [Davies Liu] Merge branch 'master' into sort
1cf9565 [Davies Liu] improve performance of sortByKey()
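The second change is easy to check in isolation. Below is a minimal, Spark-independent sketch (the helper names linear_partition and bisect_partition are invented for illustration) showing that bisect.bisect_left over the sorted bounds list picks the same bucket as the linear scan it replaces, in O(lg(N)) comparisons instead of O(N):

import bisect

bounds = ['c', 'f', 'k']            # sorted split points for 4 partitions

def linear_partition(k):            # old approach: O(N) scan over bounds
    p = 0
    while p < len(bounds) and k > bounds[p]:
        p += 1
    return p

def bisect_partition(k):            # new approach: O(lg N) binary search
    return bisect.bisect_left(bounds, k)

for key in ['a', 'c', 'd', 'x', 'z']:
    assert linear_partition(key) == bisect_partition(key)
    print(key, '->', bisect_partition(key))   # a->0, c->0, d->1, x->3, z->3

Both functions return the number of split points strictly less than the key, which is exactly the partition index the range partitioner needs.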
1 parent c974a71 · commit 434bea1

1 file changed: python/pyspark/rdd.py (24 additions, 23 deletions)
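For the first change: when a single output partition is requested, there is no need for a range shuffle at all, since coalescing to one partition and sorting it locally produces the same result. A hedged usage sketch, assuming a live SparkContext bound to sc as in the module's doctests:

rdd = sc.parallelize([('b', 2), ('c', 3), ('a', 1)], 4)  # 4 input partitions
# with numPartitions=1, the patched sortByKey() coalesces to one partition
# and sorts it in place, skipping the partitionBy() shuffle entirely
print(rdd.sortByKey(True, 1).collect())  # [('a', 1), ('b', 2), ('c', 3)]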

@@ -30,6 +30,7 @@
 from threading import Thread
 import warnings
 import heapq
+import bisect
 from random import Random
 from math import sqrt, log

@@ -574,6 +575,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         # noqa
 
         >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 1).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
         [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]

@@ -584,42 +587,40 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        bounds = list()
+        if numPartitions == 1:
+            if self.getNumPartitions() > 1:
+                self = self.coalesce(1)
+
+            def sort(iterator):
+                return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+            return self.mapPartitions(sort)
 
         # first compute the boundary of each part via sampling: we want to partition
         # the key-space into bins such that the bins have roughly the same
         # number of (key, value) pairs falling into them
-        if numPartitions > 1:
-            rddSize = self.count()
-            # constant from Spark's RangePartitioner
-            maxSampleSize = numPartitions * 20.0
-            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
-
-            samples = self.sample(False, fraction, 1).map(
-                lambda (k, v): k).collect()
-            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
-
-            # we have numPartitions many parts but one of the them has
-            # an implicit boundary
-            for i in range(0, numPartitions - 1):
-                index = (len(samples) - 1) * (i + 1) / numPartitions
-                bounds.append(samples[index])
+        rddSize = self.count()
+        maxSampleSize = numPartitions * 20.0  # constant from Spark's RangePartitioner
+        fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+        samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+        samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+        # we have numPartitions many parts but one of the them has
+        # an implicit boundary
+        bounds = [samples[len(samples) * (i + 1) / numPartitions]
+                  for i in range(0, numPartitions - 1)]
 
         def rangePartitionFunc(k):
-            p = 0
-            while p < len(bounds) and keyfunc(k) > bounds[p]:
-                p += 1
+            p = bisect.bisect_left(bounds, keyfunc(k))
             if ascending:
                 return p
             else:
                 return numPartitions - 1 - p
 
         def mapFunc(iterator):
-            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+            return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
 
-        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
-                .mapPartitions(mapFunc, preservesPartitioning=True)
-                .flatMap(lambda x: x, preservesPartitioning=True))
+        return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)
 
     def sortBy(self, keyfunc, ascending=True, numPartitions=None):
         """

0 commit comments
