diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index eda0b318128..64125bca1be 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3571,6 +3571,9 @@ def shuffle( if len(self) == 0: return self + if keep_in_memory and indices_cache_file_name is not None: + raise ValueError("Please use either `keep_in_memory` or `indices_cache_file_name` but not both.") + if seed is not None and generator is not None: raise ValueError("Both `seed` and `generator` were provided. Please specify just one of them.") @@ -3600,7 +3603,7 @@ def shuffle( return self.select( indices=permutation, keep_in_memory=keep_in_memory, - indices_cache_file_name=indices_cache_file_name, + indices_cache_file_name=indices_cache_file_name if not keep_in_memory else None, writer_batch_size=writer_batch_size, new_fingerprint=new_fingerprint, ) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 4fe9ea1ea2b..25ae52e4da1 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -1800,6 +1800,15 @@ def test_shuffle(self, in_memory): with self._create_dummy_dataset(in_memory, tmp_dir) as dset: tmp_file = os.path.join(tmp_dir, "test.arrow") fingerprint = dset._fingerprint + + with dset.shuffle(seed=1234, keep_in_memory=True) as dset_shuffled: + self.assertEqual(len(dset_shuffled), 30) + self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28") + self.assertEqual(dset_shuffled[2]["filename"], "my_name-train_10") + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_shuffled.features, Features({"filename": Value("string")})) + self.assertNotEqual(dset_shuffled._fingerprint, fingerprint) + with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset_shuffled: self.assertEqual(len(dset_shuffled), 30) self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28") @@ -1813,13 +1822,13 @@ def test_shuffle(self, in_memory): with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset_shuffled_2: self.assertListEqual(dset_shuffled["filename"], dset_shuffled_2["filename"]) - # Compatible with temp_seed - with temp_seed(42), dset.shuffle() as d1: - with temp_seed(42), dset.shuffle() as d2, dset.shuffle() as d3: - self.assertListEqual(d1["filename"], d2["filename"]) - self.assertEqual(d1._fingerprint, d2._fingerprint) - self.assertNotEqual(d3["filename"], d2["filename"]) - self.assertNotEqual(d3._fingerprint, d2._fingerprint) + # Compatible with temp_seed + with temp_seed(42), dset.shuffle() as d1: + with temp_seed(42), dset.shuffle() as d2, dset.shuffle() as d3: + self.assertListEqual(d1["filename"], d2["filename"]) + self.assertEqual(d1._fingerprint, d2._fingerprint) + self.assertNotEqual(d3["filename"], d2["filename"]) + self.assertNotEqual(d3._fingerprint, d2._fingerprint) def test_sort(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: