Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ def _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(matched_rel_
return len(hidden_directories_in_path) != len(hidden_directories_in_pattern)


def _get_data_files_patterns(pattern_resolver: Callable[[str], List[str]]) -> Dict[str, List[str]]:
def _get_data_files_patterns(
pattern_resolver: Callable[[str], List[str]], base_path: str = ""
) -> Dict[str, List[str]]:
"""
Get the default pattern from a directory or repository by testing all the supported patterns.
The first patterns to return a non-empty list of data files is returned.
Expand All @@ -242,7 +244,8 @@ def _get_data_files_patterns(pattern_resolver: Callable[[str], List[str]]) -> Di
except FileNotFoundError:
continue
if len(data_files) > 0:
splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(split_pattern))["split"] for p in data_files}
pattern = base_path + ("/" if base_path else "") + split_pattern
splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(pattern))["split"] for p in data_files}
sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
splits - set(DEFAULT_SPLITS)
)
Expand Down Expand Up @@ -462,7 +465,7 @@ def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig]
"""
resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config)
try:
return _get_data_files_patterns(resolver)
return _get_data_files_patterns(resolver, base_path=base_path)
except FileNotFoundError:
raise EmptyDatasetError(f"The directory at {base_path} doesn't contain any data files") from None

Expand Down
11 changes: 11 additions & 0 deletions tests/test_data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_get_metadata_files_patterns,
_is_inside_unrequested_special_dir,
_is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir,
get_data_patterns,
resolve_pattern,
)
from datasets.fingerprint import Hasher
Expand Down Expand Up @@ -634,3 +635,13 @@ def resolver(pattern):
patterns = _get_metadata_files_patterns(resolver)
matched = [file_path for pattern in patterns for file_path in resolver(pattern)]
assert sorted(matched) == sorted(metadata_files)


def test_get_data_patterns_from_directory_with_the_word_data_twice(tmp_path):
repo_dir = tmp_path / "directory-name-ending-with-the-word-data" # parent directory contains the word "data/"
data_dir = repo_dir / "data"
data_dir.mkdir(parents=True)
data_file = data_dir / "train-00001-of-00009.parquet"
data_file.touch()
data_file_patterns = get_data_patterns(repo_dir.as_posix())
assert data_file_patterns == {"train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]}