Merged
src/datasets/load.py (5 changes: 4 additions & 1 deletion)

@@ -1608,7 +1608,10 @@ def _get_modification_time(module_hash):
             }
             return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)
         cache_dir = os.path.expanduser(str(self.cache_dir or config.HF_DATASETS_CACHE))
-        cached_datasets_directory_path_root = os.path.join(cache_dir, self.name.replace("/", "___"))
+        namespace_and_dataset_name = self.name.split("/")
+        namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1])
+        cached_relative_path = "___".join(namespace_and_dataset_name)
+        cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path)
         cached_directory_paths = [
             cached_directory_path
             for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", "*", "*"))
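To make the effect of this hunk concrete, here is a minimal standalone sketch of the lookup path before and after the change; the repo id and cache root are illustrative values, not taken from the PR. The directory that download_and_prepare writes uses the snake_cased dataset name, which is why the old lookup missed it whenever the repo name contained capital letters.

import os

from datasets.naming import camelcase_to_snakecase

# Illustrative values only (not from the PR).
name = "hf-internal-testing/DatasetWithCapitalLetters"
cache_dir = os.path.expanduser("~/.cache/huggingface/datasets")

# Old lookup: only "/" was rewritten, so capital letters survived and the
# path never matched the snake_cased directory written at preparation time.
old_root = os.path.join(cache_dir, name.replace("/", "___"))

# New lookup: snake_case the dataset part; the namespace is left untouched.
parts = name.split("/")
parts[-1] = camelcase_to_snakecase(parts[-1])
new_root = os.path.join(cache_dir, "___".join(parts))

print(old_root)  # ...hf-internal-testing___DatasetWithCapitalLetters
print(new_root)  # ...hf-internal-testing___dataset_with_capital_letters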
src/datasets/packaged_modules/cache/cache.py (7 changes: 5 additions & 2 deletions)

@@ -12,7 +12,7 @@
 import datasets
 import datasets.config
 import datasets.data_files
-from datasets.naming import filenames_for_dataset_split
+from datasets.naming import camelcase_to_snakecase, filenames_for_dataset_split


 logger = datasets.utils.logging.get_logger(__name__)

@@ -36,7 +36,10 @@ def _find_hash_in_cache(
     else:
         config_id = None
     cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE))
-    cached_datasets_directory_path_root = os.path.join(cache_dir, dataset_name.replace("/", "___"))
+    namespace_and_dataset_name = dataset_name.split("/")
+    namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1])
+    cached_relative_path = "___".join(namespace_and_dataset_name)
+    cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path)
     cached_directory_paths = [
         cached_directory_path
         for cached_directory_path in glob.glob(
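Both hunks lean on camelcase_to_snakecase; for reference, a re-implementation sketch assuming the conventional two-regex conversion (datasets.naming remains the source of truth):

import re

# Sketch of camelcase_to_snakecase, assuming the usual two-regex approach.
_upper_upper = re.compile(r"([A-Z]+)([A-Z][a-z])")  # "HTTPServer" -> "HTTP_Server"
_lower_upper = re.compile(r"([a-z\d])([A-Z])")      # "capitalX" -> "capital_X"

def camelcase_to_snakecase(name: str) -> str:
    name = _upper_upper.sub(r"\1_\2", name)
    name = _lower_upper.sub(r"\1_\2", name)
    return name.lower()

assert camelcase_to_snakecase("DatasetWithCapitalLetters") == "dataset_with_capital_letters"
assert camelcase_to_snakecase("already_snake") == "already_snake"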
tests/packaged_modules/test_cache.py (23 changes: 23 additions & 0 deletions)

@@ -8,6 +8,7 @@

 SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_single_config_in_metadata"
 SAMPLE_DATASET_TWO_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_two_configs_in_metadata"
+SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME = "hf-internal-testing/DatasetWithCapitalLetters"


 def test_cache(text_dir: Path, tmp_path: Path):

@@ -133,3 +134,25 @@ def test_cache_single_config(tmp_path: Path):
             hash="auto",
         )
     assert config_name in str(excinfo.value)
+
+
+@pytest.mark.integration
+def test_cache_capital_letters(tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_capital_letters"
+    repo_id = SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME
+    dataset_name = repo_id.split("/")[-1]
+    ds = load_dataset(repo_id, cache_dir=str(cache_dir))
+    cache = Cache(cache_dir=str(cache_dir), dataset_name=dataset_name, repo_id=repo_id, version="auto", hash="auto")
+    reloaded = cache.as_dataset()
+    assert list(ds) == list(reloaded)
+    assert len(ds["train"]) == len(reloaded["train"])
+    cache = Cache(
+        cache_dir=str(cache_dir),
+        dataset_name=dataset_name,
+        repo_id=repo_id,
+        version="auto",
+        hash="auto",
+    )
+    reloaded = cache.as_dataset()
+    assert list(ds) == list(reloaded)
+    assert len(ds["train"]) == len(reloaded["train"])
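For orientation, the three-level glob both source hunks use mirrors the on-disk cache layout; the level names below (config, version, hash) are inferred from the surrounding cache code, not stated in the diff:

import glob
import os

# Assumed layout: <cache_dir>/<namespace>___<snake_cased_dataset>/<config>/<version>/<hash>/
root = os.path.join(
    "/tmp/test_cache_capital_letters",  # hypothetical cache_dir
    "hf-internal-testing___dataset_with_capital_letters",
)
for cached in glob.glob(os.path.join(root, "*", "*", "*")):
    print(cached)  # one entry per cached (config, version, hash) triple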
tests/test_load.py (14 changes: 14 additions & 0 deletions)

@@ -100,6 +100,7 @@ def _generate_examples(self, filepath, **kwargs):
 SAMPLE_DATASET_TWO_CONFIG_IN_METADATA_WITH_DEFAULT = (
     "hf-internal-testing/audiofolder_two_configs_in_metadata_with_default"
 )
+SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME = "hf-internal-testing/DatasetWithCapitalLetters"


 METRIC_LOADING_SCRIPT_NAME = "__dummy_metric1__"

@@ -1026,6 +1027,19 @@ def test_offline_dataset_module_factory_with_script(self):
         self.assertNotEqual(dataset_module_1.module_path, dataset_module_3.module_path)
         self.assertIn("Using the latest cached version of the module", self._caplog.text)

+    @pytest.mark.integration
+    def test_offline_dataset_module_factory_with_capital_letters_in_name(self):
+        repo_id = SAMPLE_DATASET_CAPITAL_LETTERS_IN_NAME
+        builder = load_dataset_builder(repo_id, cache_dir=self.cache_dir)
+        builder.download_and_prepare()
+        for offline_simulation_mode in list(OfflineSimulationMode):
+            with offline(offline_simulation_mode):
+                self._caplog.clear()
+                # provide only the repo id, with no explicit path to a remote or local file
+                dataset_module = datasets.load.dataset_module_factory(repo_id, cache_dir=self.cache_dir)
+                self.assertEqual(dataset_module.module_path, "datasets.packaged_modules.cache.cache")
+                self.assertIn("Using the latest cached version of the dataset", self._caplog.text)
+
     def test_load_dataset_from_hub(self):
         with self.assertRaises(DatasetNotFoundError) as context:
             datasets.load_dataset("_dummy")
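Outside the test harness, the scenario this test exercises looks roughly like the sketch below; the cache path and the use of HF_DATASETS_OFFLINE are assumptions for illustration.

import datasets

# Prepare once while online (hypothetical cache path).
repo_id = "hf-internal-testing/DatasetWithCapitalLetters"
builder = datasets.load_dataset_builder(repo_id, cache_dir="/tmp/hf_cache")
builder.download_and_prepare()

# Later, in a process started with HF_DATASETS_OFFLINE=1, the cached copy is
# found because the lookup path is now snake_cased the same way it was written.
ds = datasets.load_dataset(repo_id, cache_dir="/tmp/hf_cache")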