diff --git a/src/datasets/config.py b/src/datasets/config.py index 0b80b1a2c87..9ed22ea3760 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -132,6 +132,10 @@ DEFAULT_DOWNLOADED_DATASETS_PATH = os.path.join(HF_DATASETS_CACHE, DOWNLOADED_DATASETS_DIR) DOWNLOADED_DATASETS_PATH = Path(os.getenv("HF_DATASETS_DOWNLOADED_DATASETS_PATH", DEFAULT_DOWNLOADED_DATASETS_PATH)) +EXTRACTED_DATASETS_DIR = "extracted" +DEFAULT_EXTRACTED_DATASETS_PATH = os.path.join(DEFAULT_DOWNLOADED_DATASETS_PATH, EXTRACTED_DATASETS_DIR) +EXTRACTED_DATASETS_PATH = Path(os.getenv("HF_DATASETS_EXTRACTED_DATASETS_PATH", DEFAULT_EXTRACTED_DATASETS_PATH)) + # Batch size constants. For more info, see: # https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations) DEFAULT_MAX_BATCH_SIZE = 10_000 diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 6e2e907f15d..9aabbd49378 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -312,9 +312,15 @@ def cached_path( return output_path # Path where we extract compressed archives - # We extract in the cache dir, and get the extracted path name by hashing the original path" + # We extract in the cache dir, and get the extracted path name by hashing the original path abs_output_path = os.path.abspath(output_path) - output_path_extracted = os.path.join(cache_dir, "extracted", hash_url_to_filename(abs_output_path)) + output_path_extracted = ( + os.path.join( + download_config.cache_dir, config.EXTRACTED_DATASETS_DIR, hash_url_to_filename(abs_output_path) + ) + if download_config.cache_dir + else os.path.join(config.EXTRACTED_DATASETS_PATH, hash_url_to_filename(abs_output_path)) + ) if ( os.path.isdir(output_path_extracted) diff --git a/tests/conftest.py b/tests/conftest.py index c1cc0a1eef9..034cae1ed8c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,8 @@ def set_test_cache_config(tmp_path_factory, monkeypatch): monkeypatch.setattr("datasets.config.HF_MODULES_CACHE", str(test_hf_modules_cache)) test_downloaded_datasets_path = test_hf_datasets_cache / "downloads" monkeypatch.setattr("datasets.config.DOWNLOADED_DATASETS_PATH", str(test_downloaded_datasets_path)) + test_extracted_datasets_path = test_hf_datasets_cache / "downloads" / "extracted" + monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_PATH", str(test_extracted_datasets_path)) FILE_CONTENT = """\ diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 978d80c7742..38052eea26b 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -84,6 +84,29 @@ def test_cached_path_extract(xz_file, tmp_path, text_file): assert extracted_file_content == expected_file_content +@pytest.mark.parametrize("default_extracted", [True, False]) +@pytest.mark.parametrize("default_cache_dir", [True, False]) +def test_extracted_datasets_path(default_extracted, default_cache_dir, xz_file, tmp_path, monkeypatch): + custom_cache_dir = "custom_cache" + custom_extracted_dir = "custom_extracted_dir" + custom_extracted_path = tmp_path / "custom_extracted_path" + if default_extracted: + expected = ("downloads" if default_cache_dir else custom_cache_dir, "extracted") + else: + monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_DIR", custom_extracted_dir) + monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_PATH", str(custom_extracted_path)) + expected = custom_extracted_path.parts[-2:] if default_cache_dir else (custom_cache_dir, custom_extracted_dir) + + filename = xz_file + download_config = ( + DownloadConfig(extract_compressed_file=True) + if default_cache_dir + else DownloadConfig(cache_dir=tmp_path / custom_cache_dir, extract_compressed_file=True) + ) + extracted_file_path = cached_path(filename, download_config=download_config) + assert Path(extracted_file_path).parent.parts[-2:] == expected + + def test_cached_path_local(text_file): # absolute path text_file = str(Path(text_file).resolve())