From 235565eefcd457b1aa6d3f488d2bbd0a3615e670 Mon Sep 17 00:00:00 2001 From: Mustapha AJEGHRIR Date: Thu, 6 Oct 2022 13:07:27 +0200 Subject: [PATCH 1/4] adding keep in memory --- src/datasets/arrow_dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index eda0b318128..1c94dbeec9f 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3571,6 +3571,9 @@ def shuffle( if len(self) == 0: return self + if keep_in_memory and indices_cache_file_name is not None: + raise ValueError("Please use either `keep_in_memory` or `indices_cache_file_name` but not both.") + if seed is not None and generator is not None: raise ValueError("Both `seed` and `generator` were provided. Please specify just one of them.") @@ -3585,7 +3588,7 @@ def shuffle( generator = np.random.default_rng(seed) # Check if we've already cached this computation (indexed by a hash) - if self.cache_files: + if self.cache_files and not keep_in_memory: if indices_cache_file_name is None: # we create a unique hash from the function, current dataset file and the mapping args indices_cache_file_name = self._get_cache_file_path(new_fingerprint) From 2d37e1f4c4bfcfd8523dea6dca86682d1b268d33 Mon Sep 17 00:00:00 2001 From: mustapha ajeghrir <66799406+Mustapha-AJEGHRIR@users.noreply.github.com> Date: Thu, 6 Oct 2022 14:58:40 +0200 Subject: [PATCH 2/4] Update src/datasets/arrow_dataset.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mario Šaško --- src/datasets/arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 1c94dbeec9f..598949bf2e3 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3588,7 +3588,7 @@ def shuffle( generator = np.random.default_rng(seed) # Check if we've already cached this computation (indexed by a hash) - if self.cache_files and not keep_in_memory: + if self.cache_files: if indices_cache_file_name is None: # we create a unique hash from the function, current dataset file and the mapping args indices_cache_file_name = self._get_cache_file_path(new_fingerprint) From fb79fc017286badb2d4096c3c6b8ae1b6d0135de Mon Sep 17 00:00:00 2001 From: Mustapha AJEGHRIR Date: Thu, 6 Oct 2022 15:15:33 +0200 Subject: [PATCH 3/4] adding keep in memory test + remove 'compatible with temp_seed' inside the dset_shuffled scope --- src/datasets/arrow_dataset.py | 2 +- tests/test_arrow_dataset.py | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 598949bf2e3..64125bca1be 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3603,7 +3603,7 @@ def shuffle( return self.select( indices=permutation, keep_in_memory=keep_in_memory, - indices_cache_file_name=indices_cache_file_name, + indices_cache_file_name=indices_cache_file_name if not keep_in_memory else None, writer_batch_size=writer_batch_size, new_fingerprint=new_fingerprint, ) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 4fe9ea1ea2b..007be22361f 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -1800,6 +1800,15 @@ def test_shuffle(self, in_memory): with self._create_dummy_dataset(in_memory, tmp_dir) as dset: tmp_file = os.path.join(tmp_dir, "test.arrow") fingerprint = dset._fingerprint + + with dset.shuffle(seed=1234, keep_in_memory=True) as dset_shuffled: + self.assertEqual(len(dset_shuffled), 30) + self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28") + self.assertEqual(dset_shuffled[2]["filename"], "my_name-train_10") + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_shuffled.features, Features({"filename": Value("string")})) + self.assertNotEqual(dset_shuffled._fingerprint, fingerprint) + with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset_shuffled: self.assertEqual(len(dset_shuffled), 30) self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28") @@ -1813,13 +1822,13 @@ def test_shuffle(self, in_memory): with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset_shuffled_2: self.assertListEqual(dset_shuffled["filename"], dset_shuffled_2["filename"]) - # Compatible with temp_seed - with temp_seed(42), dset.shuffle() as d1: - with temp_seed(42), dset.shuffle() as d2, dset.shuffle() as d3: - self.assertListEqual(d1["filename"], d2["filename"]) - self.assertEqual(d1._fingerprint, d2._fingerprint) - self.assertNotEqual(d3["filename"], d2["filename"]) - self.assertNotEqual(d3._fingerprint, d2._fingerprint) + # Compatible with temp_seed + with temp_seed(42), dset.shuffle() as d1: + with temp_seed(42), dset.shuffle() as d2, dset.shuffle() as d3: + self.assertListEqual(d1["filename"], d2["filename"]) + self.assertEqual(d1._fingerprint, d2._fingerprint) + self.assertNotEqual(d3["filename"], d2["filename"]) + self.assertNotEqual(d3._fingerprint, d2._fingerprint) def test_sort(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: From c96d72dffea32152ead73dcad674f519f10d5bf7 Mon Sep 17 00:00:00 2001 From: Mustapha AJEGHRIR Date: Thu, 6 Oct 2022 15:16:37 +0200 Subject: [PATCH 4/4] make style --- tests/test_arrow_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 007be22361f..25ae52e4da1 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -1800,7 +1800,7 @@ def test_shuffle(self, in_memory): with self._create_dummy_dataset(in_memory, tmp_dir) as dset: tmp_file = os.path.join(tmp_dir, "test.arrow") fingerprint = dset._fingerprint - + with dset.shuffle(seed=1234, keep_in_memory=True) as dset_shuffled: self.assertEqual(len(dset_shuffled), 30) self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28") @@ -1808,7 +1808,7 @@ def test_shuffle(self, in_memory): self.assertDictEqual(dset.features, Features({"filename": Value("string")})) self.assertDictEqual(dset_shuffled.features, Features({"filename": Value("string")})) self.assertNotEqual(dset_shuffled._fingerprint, fingerprint) - + with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset_shuffled: self.assertEqual(len(dset_shuffled), 30) self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28")