Skip to content

Commit 1206ffb

Browse files
Set configurable extracted datasets path (#2487)
* Test extracted datasets path
* Set configurable extracted datasets path
* Fix style
* Use DEFAULT_DOWNLOADED_DATASETS_PATH to define DEFAULT_EXTRACTED_DATASETS_PATH

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
1 parent cfb19fd commit 1206ffb

4 files changed

Lines changed: 37 additions & 2 deletions

File tree

src/datasets/config.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,10 @@
132132
DEFAULT_DOWNLOADED_DATASETS_PATH = os.path.join(HF_DATASETS_CACHE, DOWNLOADED_DATASETS_DIR)
133133
DOWNLOADED_DATASETS_PATH = Path(os.getenv("HF_DATASETS_DOWNLOADED_DATASETS_PATH", DEFAULT_DOWNLOADED_DATASETS_PATH))
134134

135+
EXTRACTED_DATASETS_DIR = "extracted"
136+
DEFAULT_EXTRACTED_DATASETS_PATH = os.path.join(DEFAULT_DOWNLOADED_DATASETS_PATH, EXTRACTED_DATASETS_DIR)
137+
EXTRACTED_DATASETS_PATH = Path(os.getenv("HF_DATASETS_EXTRACTED_DATASETS_PATH", DEFAULT_EXTRACTED_DATASETS_PATH))
138+
135139
# Batch size constants. For more info, see:
136140
# https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations)
137141
DEFAULT_MAX_BATCH_SIZE = 10_000

src/datasets/utils/file_utils.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -312,9 +312,15 @@ def cached_path(
312312
return output_path
313313

314314
# Path where we extract compressed archives
315-
# We extract in the cache dir, and get the extracted path name by hashing the original path"
315+
# We extract in the cache dir, and get the extracted path name by hashing the original path
316316
abs_output_path = os.path.abspath(output_path)
317-
output_path_extracted = os.path.join(cache_dir, "extracted", hash_url_to_filename(abs_output_path))
317+
output_path_extracted = (
318+
os.path.join(
319+
download_config.cache_dir, config.EXTRACTED_DATASETS_DIR, hash_url_to_filename(abs_output_path)
320+
)
321+
if download_config.cache_dir
322+
else os.path.join(config.EXTRACTED_DATASETS_PATH, hash_url_to_filename(abs_output_path))
323+
)
318324

319325
if (
320326
os.path.isdir(output_path_extracted)

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,8 @@ def set_test_cache_config(tmp_path_factory, monkeypatch):
2323
monkeypatch.setattr("datasets.config.HF_MODULES_CACHE", str(test_hf_modules_cache))
2424
test_downloaded_datasets_path = test_hf_datasets_cache / "downloads"
2525
monkeypatch.setattr("datasets.config.DOWNLOADED_DATASETS_PATH", str(test_downloaded_datasets_path))
26+
test_extracted_datasets_path = test_hf_datasets_cache / "downloads" / "extracted"
27+
monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_PATH", str(test_extracted_datasets_path))
2628

2729

2830
FILE_CONTENT = """\

tests/test_file_utils.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,29 @@ def test_cached_path_extract(xz_file, tmp_path, text_file):
8484
assert extracted_file_content == expected_file_content
8585

8686

87+
@pytest.mark.parametrize("default_extracted", [True, False])
88+
@pytest.mark.parametrize("default_cache_dir", [True, False])
89+
def test_extracted_datasets_path(default_extracted, default_cache_dir, xz_file, tmp_path, monkeypatch):
90+
custom_cache_dir = "custom_cache"
91+
custom_extracted_dir = "custom_extracted_dir"
92+
custom_extracted_path = tmp_path / "custom_extracted_path"
93+
if default_extracted:
94+
expected = ("downloads" if default_cache_dir else custom_cache_dir, "extracted")
95+
else:
96+
monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_DIR", custom_extracted_dir)
97+
monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_PATH", str(custom_extracted_path))
98+
expected = custom_extracted_path.parts[-2:] if default_cache_dir else (custom_cache_dir, custom_extracted_dir)
99+
100+
filename = xz_file
101+
download_config = (
102+
DownloadConfig(extract_compressed_file=True)
103+
if default_cache_dir
104+
else DownloadConfig(cache_dir=tmp_path / custom_cache_dir, extract_compressed_file=True)
105+
)
106+
extracted_file_path = cached_path(filename, download_config=download_config)
107+
assert Path(extracted_file_path).parent.parts[-2:] == expected
108+
109+
87110
def test_cached_path_local(text_file):
88111
# absolute path
89112
text_file = str(Path(text_file).resolve())

0 commit comments

Comments
 (0)