@@ -160,8 +160,7 @@ class RandomSampler(Sampler):
160160 object which implemented :code:`__len__` to get indices as the range of :code:`dataset` length. Default None.
161161 replacement(bool, optional): If False, sample the whole dataset, If True,
162162 set :attr:`num_samples` for how many samples to draw. Default False.
163- num_samples(int, optional): set sample number to draw if :attr:`replacement`
164- is True, then it will take samples according to the number you set. Default None, disabled.
163+ num_samples(int, optional): set sample number to draw. Default None, which is set to the length of `data_source`.
165164 generator(Generator, optional): specify a generator to sample the :code:`data_source`. Default None, disabled.
166165
167166 Returns:
@@ -212,9 +211,10 @@ def __init__(
212211 f"replacement={ self .replacement } "
213212 )
214213
215- if self ._num_samples is not None and not replacement :
214+ if not self .replacement and self . num_samples > len ( self . data_source ) :
216215 raise ValueError (
217- "num_samples should not be specified while replacement is False"
216+ "num_samples should be smaller than or equal to length of data_source when replacement is False, "
217+ f"but got num_samples: { self .num_samples } > data_source: { len (self .data_source )} "
218218 )
219219
220220 if not isinstance (self .num_samples , int ) or self .num_samples <= 0 :
@@ -246,7 +246,7 @@ def __iter__(self):
246246 yield index
247247 else :
248248 for index in np .random .choice (
249- np .arange (n ), n , replace = False
249+ np .arange (n ), self . num_samples , replace = False
250250 ).tolist ():
251251 yield index
252252
0 commit comments