diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 7f038fb6164..cd2a3d57b63 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -682,6 +682,8 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs
                     + str(e)
                 )
 
+        dl_manager.manage_extracted_files()
+
         if verify_infos:
             verify_splits(self.info.splits, split_dict)
 
diff --git a/src/datasets/utils/download_manager.py b/src/datasets/utils/download_manager.py
index 5bd5f544e9b..33b40ebe995 100644
--- a/src/datasets/utils/download_manager.py
+++ b/src/datasets/utils/download_manager.py
@@ -290,3 +290,14 @@ def download_and_extract(self, url_or_urls):
 
     def get_recorded_sizes_checksums(self):
         return self._recorded_sizes_checksums.copy()
+
+    def delete_extracted_files(self):
+        paths_to_delete = set(self.extracted_paths.values()) - set(self.downloaded_paths.values())
+        for key, path in list(self.extracted_paths.items()):
+            if path in paths_to_delete and os.path.isfile(path):
+                os.remove(path)
+                del self.extracted_paths[key]
+
+    def manage_extracted_files(self):
+        if self._download_config.delete_extracted:
+            self.delete_extracted_files()
diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py
index f6ab0392a36..15cc5c7a7e5 100644
--- a/src/datasets/utils/file_utils.py
+++ b/src/datasets/utils/file_utils.py
@@ -217,6 +217,7 @@ class DownloadConfig:
             extract the compressed file in a folder along the archive.
         force_extract (:obj:`bool`, default ``False``): If True when extract_compressed_file is True and the archive
            was already extracted, re-extract the archive and override the folder where it was extracted.
+        delete_extracted (:obj:`bool`, default ``False``): Whether to delete (or keep) the extracted files.
        use_etag (:obj:`bool`, default ``True``):
        num_proc (:obj:`int`, optional):
        max_retries (:obj:`int`, default ``1``): The number of times to retry an HTTP request if it fails.
@@ -232,6 +233,7 @@ class DownloadConfig:
     user_agent: Optional[str] = None
     extract_compressed_file: bool = False
     force_extract: bool = False
+    delete_extracted: bool = False
     use_etag: bool = True
     num_proc: Optional[int] = None
     max_retries: int = 1
diff --git a/src/datasets/utils/mock_download_manager.py b/src/datasets/utils/mock_download_manager.py
index cbb0688eeb9..a3220146a04 100644
--- a/src/datasets/utils/mock_download_manager.py
+++ b/src/datasets/utils/mock_download_manager.py
@@ -201,3 +201,9 @@ def create_dummy_data_single(self, path_to_dummy_data, data_url):
         # while now we expected the dummy_data.zip file to be a directory containing
         # the downloaded file.
         return path_to_dummy_data
+
+    def delete_extracted_files(self):
+        pass
+
+    def manage_extracted_files(self):
+        pass
diff --git a/tests/test_load.py b/tests/test_load.py
index 803a7134d9b..68c6c87f27c 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -19,6 +19,7 @@
 from datasets.features import Features, Value
 from datasets.iterable_dataset import IterableDataset
 from datasets.load import prepare_module
+from datasets.utils.file_utils import DownloadConfig
 
 from .utils import (
     OfflineSimulationMode,
@@ -345,3 +346,18 @@ def test_remote_data_files():
     assert isinstance(ds, IterableDataset)
     ds_item = next(iter(ds))
     assert ds_item.keys() == {"langs", "ner_tags", "spans", "tokens"}
+
+
+@pytest.mark.parametrize("deleted", [False, True])
+def test_load_dataset_deletes_extracted_files(deleted, jsonl_gz_path, tmp_path):
+    data_files = jsonl_gz_path
+    cache_dir = tmp_path / "cache"
+    if deleted:
+        download_config = DownloadConfig(delete_extracted=True, cache_dir=cache_dir / "downloads")
+        ds = load_dataset(
+            "json", split="train", data_files=data_files, cache_dir=cache_dir, download_config=download_config
+        )
+    else:  # default
+        ds = load_dataset("json", split="train", data_files=data_files, cache_dir=cache_dir)
+    assert ds[0] == {"col_1": "0", "col_2": 0, "col_3": 0.0}
+    assert (sorted((cache_dir / "downloads" / "extracted").iterdir()) == []) is deleted
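
Not part of the patch: a minimal usage sketch of the new delete_extracted option,
mirroring test_load_dataset_deletes_extracted_files above. The data file path is a
placeholder, and only names introduced or used in this diff (DownloadConfig,
load_dataset) appear.

    from datasets import load_dataset
    from datasets.utils.file_utils import DownloadConfig

    # Opt in to removing extracted archives once the dataset has been prepared;
    # the downloaded archives themselves stay in the cache, since
    # delete_extracted_files() only removes paths that are extracted but not downloaded.
    download_config = DownloadConfig(delete_extracted=True)
    ds = load_dataset(
        "json",
        split="train",
        data_files="path/to/train.jsonl.gz",  # placeholder path
        download_config=download_config,
    )
    # After preparation, the downloads/extracted directory of the cache no longer
    # holds the uncompressed copy (see the test above for the exact layout checked).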