Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ def __init__(

# prepare data dirs
self._cache_dir_root = os.path.expanduser(cache_dir or config.HF_DATASETS_CACHE)
self._cache_downloaded_dir = (
os.path.join(cache_dir, config.DOWNLOADED_DATASETS_DIR) if cache_dir else config.DOWNLOADED_DATASETS_PATH
)
self._cache_dir = self._build_cache_dir()
if not is_remote_url(self._cache_dir_root):
os.makedirs(self._cache_dir_root, exist_ok=True)
Expand Down Expand Up @@ -482,7 +485,7 @@ def download_and_prepare(
if dl_manager is None:
if download_config is None:
download_config = DownloadConfig(
cache_dir=os.path.join(self._cache_dir_root, "downloads"),
cache_dir=self._cache_downloaded_dir,
force_download=bool(download_mode == GenerateMode.FORCE_REDOWNLOAD),
use_etag=False,
use_auth_token=use_auth_token,
Expand Down
6 changes: 5 additions & 1 deletion src/datasets/commands/dummy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,11 @@ def run(self):
print(f"Automatic dummy data generation failed for some configs of '{self._path_to_dataset}'")

def _autogenerate_dummy_data(self, dataset_builder, mock_dl_manager, keep_uncompressed) -> Optional[bool]:
dl_cache_dir = os.path.join(self._cache_dir or config.HF_DATASETS_CACHE, "downloads")
dl_cache_dir = (
os.path.join(self._cache_dir, config.DOWNLOADED_DATASETS_DIR)
if self._cache_dir
else config.DOWNLOADED_DATASETS_PATH
)
download_config = DownloadConfig(cache_dir=dl_cache_dir)
dl_manager = DummyDataGeneratorDownloadManager(
dataset_name=self._dataset_name, mock_download_manager=mock_dl_manager, download_config=download_config
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/commands/run_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def run(self):
download_mode=GenerateMode.REUSE_CACHE_IF_EXISTS
if not self._force_redownload
else GenerateMode.FORCE_REDOWNLOAD,
download_config=DownloadConfig(cache_dir=os.path.join(config.HF_DATASETS_CACHE, "downloads")),
download_config=DownloadConfig(cache_dir=config.DOWNLOADED_DATASETS_PATH),
save_infos=self._save_infos,
ignore_verifications=self._ignore_verifications,
try_from_hf_gcs=False,
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/commands/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,12 @@ def get_builders() -> Generator[DatasetBuilder, None, None]:
copyfile(dataset_infos_path, user_dataset_infos_path)
print("Dataset Infos file saved at {}".format(user_dataset_infos_path))

# If clear_cache=True, the download forlder and the dataset builder cache directory are deleted
# If clear_cache=True, the download folder and the dataset builder cache directory are deleted
if self._clear_cache:
if os.path.isdir(builder._cache_dir):
logger.warning(f"Clearing cache at {builder._cache_dir}")
rmtree(builder._cache_dir)
download_dir = os.path.join(self._cache_dir, "downloads")
download_dir = os.path.join(self._cache_dir, datasets.config.DOWNLOADED_DATASETS_DIR)
if os.path.isdir(download_dir):
logger.warning(f"Clearing cache at {download_dir}")
rmtree(download_dir)
Expand Down
4 changes: 4 additions & 0 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@
DEFAULT_HF_MODULES_CACHE = os.path.join(HF_CACHE_HOME, "modules")
HF_MODULES_CACHE = Path(os.getenv("HF_MODULES_CACHE", DEFAULT_HF_MODULES_CACHE))

DOWNLOADED_DATASETS_DIR = "downloads"
DEFAULT_DOWNLOADED_DATASETS_PATH = os.path.join(HF_DATASETS_CACHE, DOWNLOADED_DATASETS_DIR)
DOWNLOADED_DATASETS_PATH = Path(os.getenv("HF_DATASETS_DOWNLOADED_DATASETS_PATH", DEFAULT_DOWNLOADED_DATASETS_PATH))

# Batch size constants. For more info, see:
# https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations)
DEFAULT_MAX_BATCH_SIZE = 10_000
Expand Down
6 changes: 4 additions & 2 deletions src/datasets/utils/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def ship_files_with_pipeline(self, downloaded_path_or_paths, pipeline):
raise ValueError("You need to specify 'temp_location' in PipelineOptions to upload files")

def upload(local_file_path):
remote_file_path = os.path.join(remote_dir, "downloads", os.path.basename(local_file_path))
remote_file_path = os.path.join(
remote_dir, config.DOWNLOADED_DATASETS_DIR, os.path.basename(local_file_path)
)
logger.info(
"Uploading {} ({}) to {}.".format(
local_file_path, size_str(os.path.getsize(local_file_path)), remote_file_path
Expand Down Expand Up @@ -146,7 +148,7 @@ def download_custom(self, url_or_urls, custom_download):
downloaded_path(s): `str`, The downloaded paths matching the given input
url_or_urls.
"""
cache_dir = self._download_config.cache_dir or os.path.join(config.HF_DATASETS_CACHE, "downloads")
cache_dir = self._download_config.cache_dir or config.DOWNLOADED_DATASETS_PATH
max_retries = self._download_config.max_retries

def url_to_downloaded_path(url):
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def cached_path(
if download_config is None:
download_config = DownloadConfig(**download_kwargs)

cache_dir = download_config.cache_dir or os.path.join(config.HF_DATASETS_CACHE, "downloads")
cache_dir = download_config.cache_dir or config.DOWNLOADED_DATASETS_PATH
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if isinstance(url_or_filename, Path):
Expand Down
14 changes: 8 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
def set_test_cache_config(tmp_path_factory, monkeypatch):
# test_hf_cache_home = tmp_path_factory.mktemp("cache") # TODO: why a cache dir per test function does not work?
test_hf_cache_home = tmp_path_factory.getbasetemp() / "cache"
test_hf_datasets_cache = str(test_hf_cache_home / "datasets")
test_hf_metrics_cache = str(test_hf_cache_home / "metrics")
test_hf_modules_cache = str(test_hf_cache_home / "modules")
monkeypatch.setattr("datasets.config.HF_DATASETS_CACHE", test_hf_datasets_cache)
monkeypatch.setattr("datasets.config.HF_METRICS_CACHE", test_hf_metrics_cache)
monkeypatch.setattr("datasets.config.HF_MODULES_CACHE", test_hf_modules_cache)
test_hf_datasets_cache = test_hf_cache_home / "datasets"
test_hf_metrics_cache = test_hf_cache_home / "metrics"
test_hf_modules_cache = test_hf_cache_home / "modules"
monkeypatch.setattr("datasets.config.HF_DATASETS_CACHE", str(test_hf_datasets_cache))
monkeypatch.setattr("datasets.config.HF_METRICS_CACHE", str(test_hf_metrics_cache))
monkeypatch.setattr("datasets.config.HF_MODULES_CACHE", str(test_hf_modules_cache))
test_downloaded_datasets_path = test_hf_datasets_cache / "downloads"
monkeypatch.setattr("datasets.config.DOWNLOADED_DATASETS_PATH", str(test_downloaded_datasets_path))


FILE_CONTENT = """\
Expand Down
3 changes: 2 additions & 1 deletion tests/test_dummy_data_autogenerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tempfile import TemporaryDirectory
from unittest import TestCase

import datasets.config
from datasets.builder import GeneratorBasedBuilder
from datasets.commands.dummy_data import DummyDataGeneratorDownloadManager, MockDownloadManager
from datasets.features import Features, Value
Expand Down Expand Up @@ -73,7 +74,7 @@ class MockDownloadManagerWithCustomDatasetsScriptsDir(MockDownloadManager):
cache_dir=cache_dir,
load_existing_dummy_data=False, # dummy data don't exist yet
)
download_config = DownloadConfig(cache_dir=os.path.join(tmp_dir, "downloads"))
download_config = DownloadConfig(cache_dir=os.path.join(tmp_dir, datasets.config.DOWNLOADED_DATASETS_DIR))
dl_manager = DummyDataGeneratorDownloadManager(
dataset_name=dataset_builder.name,
mock_download_manager=mock_dl_manager,
Expand Down