@@ -30,6 +30,7 @@
 from threading import Thread
 import warnings
 import heapq
+import bisect
 from random import Random
 from math import sqrt, log

@@ -574,6 +575,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         # noqa

         >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 1).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
         [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
@@ -584,42 +587,40 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()

-        bounds = list()
+        if numPartitions == 1:
+            if self.getNumPartitions() > 1:
+                self = self.coalesce(1)
+
+            def sort(iterator):
+                return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+            return self.mapPartitions(sort)

         # first compute the boundary of each part via sampling: we want to partition
         # the key-space into bins such that the bins have roughly the same
         # number of (key, value) pairs falling into them
-        if numPartitions > 1:
-            rddSize = self.count()
-            # constant from Spark's RangePartitioner
-            maxSampleSize = numPartitions * 20.0
-            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
-
-            samples = self.sample(False, fraction, 1).map(
-                lambda (k, v): k).collect()
-            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
-
-            # we have numPartitions many parts but one of them has
-            # an implicit boundary
-            for i in range(0, numPartitions - 1):
-                index = (len(samples) - 1) * (i + 1) / numPartitions
-                bounds.append(samples[index])
+        rddSize = self.count()
+        maxSampleSize = numPartitions * 20.0  # constant from Spark's RangePartitioner
+        fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+        samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+        samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+        # we have numPartitions many parts but one of them has
+        # an implicit boundary
+        bounds = [samples[len(samples) * (i + 1) / numPartitions]
+                  for i in range(0, numPartitions - 1)]

         def rangePartitionFunc(k):
-            p = 0
-            while p < len(bounds) and keyfunc(k) > bounds[p]:
-                p += 1
+            p = bisect.bisect_left(bounds, keyfunc(k))
             if ascending:
                 return p
             else:
                 return numPartitions - 1 - p

         def mapFunc(iterator):
-            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+            return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))

-        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
-                .mapPartitions(mapFunc, preservesPartitioning=True)
-                .flatMap(lambda x: x, preservesPartitioning=True))
+        return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)

     def sortBy(self, keyfunc, ascending=True, numPartitions=None):
         """
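
For reference, here is a minimal, self-contained sketch of the range-partitioning scheme this patch switches to, written in plain Python 3 against toy in-memory data. The pairs list, numPartitions, ascending, and keyfunc below are stand-ins for this sketch only and are not part of the patch (which targets Python 2 and operates on RDDs). It shows how the sampled boundaries are chosen and why bisect.bisect_left returns the same partition index as the removed while-loop for an ascending sort.

import bisect
import random

# Toy stand-ins for the RDD and sortByKey's arguments (assumptions for this
# sketch only): 200 (key, value) pairs, 4 target partitions, ascending order.
pairs = [(random.randrange(1000), None) for _ in range(200)]
numPartitions = 4
ascending = True

def keyfunc(k):
    return k

# Sample some keys and pick numPartitions - 1 boundary values; the last
# partition's boundary stays implicit, as in the patch.
samples = sorted((k for k, _ in random.sample(pairs, 80)),
                 reverse=(not ascending), key=keyfunc)
bounds = [samples[len(samples) * (i + 1) // numPartitions]
          for i in range(numPartitions - 1)]

def rangePartitionFunc(k):
    # For ascending bounds, bisect_left counts the boundaries strictly
    # smaller than keyfunc(k), exactly what the removed linear scan did.
    p = bisect.bisect_left(bounds, keyfunc(k))
    return p if ascending else numPartitions - 1 - p

# Every key maps to a valid partition index in [0, numPartitions).
assert all(0 <= rangePartitionFunc(k) < numPartitions for k, _ in pairs)

Replacing the linear scan with bisect makes the per-record partitioning cost O(log numPartitions) instead of O(numPartitions), which matters because the partition function is called once per record.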