Skip to content

Commit c1bacd9

Browse files
committed
fix seed distribution and add some tests for rdd.sample
1 parent 51ce997 commit c1bacd9

2 files changed

Lines changed: 22 additions & 8 deletions

File tree

python/pyspark/rdd.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,23 @@ def sample(self, withReplacement, fraction, seed=None):
317317
Return a sampled subset of this RDD (relies on numpy and falls back
318318
on default random generator if numpy is unavailable).
319319
320-
>>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
321-
[2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
320+
>>> rdd = sc.parallelize(range(0, 100), 4)
321+
>>> wo = rdd.sample(False, 0.1, 2).collect()
322+
>>> wo_dup = rdd.sample(False, 0.1, 2).collect()
323+
>>> set(wo) == set(wo_dup)
324+
True
325+
>>> wr = rdd.sample(True, 0.2, 5).collect()
326+
>>> wr_dup = rdd.sample(True, 0.2, 5).collect()
327+
>>> set(wr) == set(wr_dup)
328+
True
329+
>>> wo_s10 = rdd.sample(False, 0.3, 10).collect()
330+
>>> wo_s20 = rdd.sample(False, 0.3, 20).collect()
331+
>>> set(wo_s10) != set(wo_s20)
332+
True
333+
>>> wr_s11 = rdd.sample(True, 0.4, 11).collect()
334+
>>> wr_s21 = rdd.sample(True, 0.4, 21).collect()
335+
>>> set(wr_s11) != set(wr_s21)
336+
True
322337
"""
323338
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
324339
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

python/pyspark/rddsampler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,13 @@ def __init__(self, withReplacement, seed=None):
4040
def initRandomGenerator(self, split):
4141
if self._use_numpy:
4242
import numpy
43-
self._random = numpy.random.RandomState(self._seed)
43+
self._random = numpy.random.RandomState(self._seed ^ split)
4444
else:
45-
self._random = random.Random(self._seed)
45+
self._random = random.Random(self._seed ^ split)
4646

47-
for _ in range(0, split):
48-
# discard the next few values in the sequence to have a
49-
# different seed for the different splits
50-
self._random.randint(0, 2 ** 32 - 1)
47+
# mixing because the initial seeds are close to each other
48+
for _ in xrange(10):
49+
self._random.randint(0, 1)
5150

5251
self._split = split
5352
self._rand_initialized = True

0 commit comments

Comments
 (0)