Skip to content

Commit 6a3037c

Browse files
lhoestq and thomwolf authored
Fix dataset_dict.shuffle with single seed (#1626)
* fix dataset_dict.shuffle with single seed
* add seed alias
* missing test
* Update src/datasets/dataset_dict.py

Co-authored-by: Thomas Wolf <[email protected]>
Co-authored-by: Thomas Wolf <[email protected]>
1 parent 364ba14 commit 6a3037c

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/datasets/dataset_dict.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,8 @@ def sort(
420420

421421
def shuffle(
422422
self,
423-
seeds: Optional[Dict[str, int]] = None,
423+
seeds: Optional[Union[int, Dict[str, int]]] = None,
424+
seed: Optional[int] = None,
424425
generators: Optional[Dict[str, np.random.Generator]] = None,
425426
keep_in_memory: bool = False,
426427
load_from_cache_file: bool = True,
@@ -434,10 +435,11 @@ def shuffle(
434435
You can either supply a NumPy BitGenerator to use, or a seed to initiate NumPy's default random generator (PCG64).
435436
436437
Args:
437-
seeds (Optional `Dict[str, int]`): A seed to initialize the default BitGenerator if ``generator=None``.
438+
seeds (Optional `Dict[str, int]` or `int`): A seed to initialize the default BitGenerator if ``generator=None``.
438439
If None, then fresh, unpredictable entropy will be pulled from the OS.
439440
If an int or array_like[ints] is passed, then it will be passed to SeedSequence to derive the initial BitGenerator state.
440-
You have to provide one :obj:`seed` per dataset in the dataset dictionary.
441+
You can provide one :obj:`seed` per dataset in the dataset dictionary.
442+
seed (Optional `int`): A seed to initialize the default BitGenerator if ``generator=None``. Alias for seeds (a ValueError is raised if both the seed and seeds arguments are provided).
441443
generators (Optional `Dict[str, np.random.Generator]`): Numpy random Generator to use to compute the permutation of the dataset rows.
442444
If ``generator=None`` (default), uses np.random.default_rng (the default BitGenerator (PCG64) of NumPy).
443445
You have to provide one :obj:`generator` per dataset in the dataset dictionary.
@@ -451,8 +453,13 @@ def shuffle(
451453
Higher values give smaller cache files, lower values consume less temporary memory while running `.map()`.
452454
"""
453455
self._check_values_type()
456+
if seed is not None and seeds is not None:
457+
raise ValueError("Please specify seed or seeds, but not both")
458+
seeds = seed if seed is not None else seeds
454459
if seeds is None:
455460
seeds = {k: None for k in self}
461+
elif not isinstance(seeds, dict):
462+
seeds = {k: seeds for k in self}
456463
if generators is None:
457464
generators = {k: None for k in self}
458465
if indices_cache_file_names is None:

tests/test_dataset_dict.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,17 @@ def test_shuffle(self):
283283
seeds=seeds, indices_cache_file_names=indices_cache_file_names_3, load_from_cache_file=False
284284
)
285285
self.assertNotEqual(dsets_shuffled_3["train"]["filename"], dsets_shuffled_3["test"]["filename"])
286+
287+
# other input types
288+
dsets_shuffled_int = dsets.shuffle(42)
289+
dsets_shuffled_alias = dsets.shuffle(seed=42)
290+
dsets_shuffled_none = dsets.shuffle()
291+
self.assertEqual(len(dsets_shuffled_int["train"]), 30)
292+
self.assertEqual(len(dsets_shuffled_alias["train"]), 30)
293+
self.assertEqual(len(dsets_shuffled_none["train"]), 30)
294+
286295
del dsets, dsets_shuffled, dsets_shuffled_2, dsets_shuffled_3
296+
del dsets_shuffled_int, dsets_shuffled_alias, dsets_shuffled_none
287297

288298
def test_check_values_type(self):
289299
dsets = self._create_dummy_dataset_dict()

0 commit comments

Comments
 (0)