Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@
# Modules cache location: "<HF_CACHE_HOME>/modules" unless the HF_MODULES_CACHE
# environment variable is set.
DEFAULT_HF_MODULES_CACHE = os.path.join(HF_CACHE_HOME, "modules")
HF_MODULES_CACHE = Path(os.getenv("HF_MODULES_CACHE", DEFAULT_HF_MODULES_CACHE))

# Location of extracted archives.
# EXTRACTED_DATASETS_DIR is only the sub-directory name; the default full path is
# "<HF_DATASETS_CACHE>/downloads/<EXTRACTED_DATASETS_DIR>" and can be replaced as a
# whole through the HF_DATASETS_EXTRACTED_DATASETS_PATH environment variable.
EXTRACTED_DATASETS_DIR = "extracted"
DEFAULT_EXTRACTED_DATASETS_PATH = os.path.join(HF_DATASETS_CACHE, "downloads", EXTRACTED_DATASETS_DIR)
EXTRACTED_DATASETS_PATH = Path(os.getenv("HF_DATASETS_EXTRACTED_DATASETS_PATH", DEFAULT_EXTRACTED_DATASETS_PATH))

# Batch size constants. For more info, see:
# https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations
DEFAULT_MAX_BATCH_SIZE = 10_000
Expand Down
10 changes: 8 additions & 2 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,15 @@ def cached_path(
return output_path

# Path where we extract compressed archives
# We extract in the cache dir, and get the extracted path name by hashing the original path"
# We extract in the cache dir, and get the extracted path name by hashing the original path
abs_output_path = os.path.abspath(output_path)
output_path_extracted = os.path.join(cache_dir, "extracted", hash_url_to_filename(abs_output_path))
output_path_extracted = (
os.path.join(
download_config.cache_dir, config.EXTRACTED_DATASETS_DIR, hash_url_to_filename(abs_output_path)
)
if download_config.cache_dir
else os.path.join(config.EXTRACTED_DATASETS_PATH, hash_url_to_filename(abs_output_path))
)

if (
os.path.isdir(output_path_extracted)
Expand Down
14 changes: 8 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
def set_test_cache_config(tmp_path_factory, monkeypatch):
    """Point the ``datasets.config`` cache locations at a temporary test cache.

    Every location is placed under ``tmp_path_factory.getbasetemp() / "cache"``,
    so all callers sharing the same base temp directory share one cache tree.

    Args:
        tmp_path_factory: Object whose ``getbasetemp()`` returns a path that
            supports the ``/`` operator (presumably pytest's
            ``tmp_path_factory`` fixture).
        monkeypatch: Object exposing ``setattr(target, value)`` (presumably
            pytest's ``monkeypatch`` fixture). It is called once per config
            attribute with a dotted target name and a ``str`` path.
    """
    # TODO: why does a cache dir per test function not work?
    # test_hf_cache_home = tmp_path_factory.mktemp("cache")
    test_hf_cache_home = tmp_path_factory.getbasetemp() / "cache"
    test_hf_datasets_cache = test_hf_cache_home / "datasets"

    # Keep the values as path objects until the last moment so that derived
    # locations (the extracted dir) can be built with "/"; the config
    # attributes themselves are patched with plain strings.
    patched_locations = {
        "datasets.config.HF_DATASETS_CACHE": test_hf_datasets_cache,
        "datasets.config.HF_METRICS_CACHE": test_hf_cache_home / "metrics",
        "datasets.config.HF_MODULES_CACHE": test_hf_cache_home / "modules",
        "datasets.config.EXTRACTED_DATASETS_PATH": test_hf_datasets_cache / "downloads" / "extracted",
    }
    for target, location in patched_locations.items():
        monkeypatch.setattr(target, str(location))


FILE_CONTENT = """\
Expand Down
23 changes: 23 additions & 0 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,29 @@ def test_cached_path_extract(xz_file, tmp_path, text_file):
assert extracted_file_content == expected_file_content


@pytest.mark.parametrize("default_extracted", [True, False])
@pytest.mark.parametrize("default_cache_dir", [True, False])
def test_extracted_datasets_path(default_extracted, default_cache_dir, xz_file, tmp_path, monkeypatch):
    """Check where ``cached_path`` places an extracted archive.

    The last two components of the extraction directory are compared against
    the expected pair for each combination of:

    * default vs. patched ``datasets.config`` extraction settings, and
    * default vs. explicit ``DownloadConfig.cache_dir``.
    """
    cache_dir_name = "custom_cache"
    extracted_dir_name = "custom_extracted_dir"
    extracted_path_override = tmp_path / "custom_extracted_path"

    # Only the "custom extraction" cases override the config values.
    if not default_extracted:
        monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_DIR", extracted_dir_name)
        monkeypatch.setattr("datasets.config.EXTRACTED_DATASETS_PATH", str(extracted_path_override))

    if default_cache_dir:
        # No explicit cache dir: the full extracted path from the config applies.
        download_config = DownloadConfig(extract_compressed_file=True)
        if default_extracted:
            expected_tail = ("downloads", "extracted")
        else:
            expected_tail = extracted_path_override.parts[-2:]
    else:
        # Explicit cache dir: only the extracted sub-directory name is taken from the config.
        download_config = DownloadConfig(cache_dir=tmp_path / cache_dir_name, extract_compressed_file=True)
        expected_tail = (cache_dir_name, "extracted" if default_extracted else extracted_dir_name)

    extracted_file = cached_path(xz_file, download_config=download_config)
    assert Path(extracted_file).parent.parts[-2:] == expected_tail


def test_cached_path_local(text_file):
# absolute path
text_file = str(Path(text_file).resolve())
Expand Down