Skip to content

Commit 79a1526

Browse files
Fix get_data_patterns for directories with the word data twice (#6309)
* Test get_data_patterns from directory with the word data twice
* Fix get_data_patterns
* Use glob_pattern_to_regex in entire xjoin
* Fix test by passing base_path as posix
* Use slash instead of xjoin for data files patterns
* Fix slash sep
1 parent fdc29db commit 79a1526

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

src/datasets/data_files.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ def _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(matched_rel_
227227
return len(hidden_directories_in_path) != len(hidden_directories_in_pattern)
228228

229229

230-
def _get_data_files_patterns(pattern_resolver: Callable[[str], List[str]]) -> Dict[str, List[str]]:
230+
def _get_data_files_patterns(
231+
pattern_resolver: Callable[[str], List[str]], base_path: str = ""
232+
) -> Dict[str, List[str]]:
231233
"""
232234
Get the default pattern from a directory or repository by testing all the supported patterns.
233235
The first patterns to return a non-empty list of data files is returned.
@@ -242,7 +244,8 @@ def _get_data_files_patterns(pattern_resolver: Callable[[str], List[str]]) -> Di
242244
except FileNotFoundError:
243245
continue
244246
if len(data_files) > 0:
245-
splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(split_pattern))["split"] for p in data_files}
247+
pattern = base_path + ("/" if base_path else "") + split_pattern
248+
splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(pattern))["split"] for p in data_files}
246249
sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
247250
splits - set(DEFAULT_SPLITS)
248251
)
@@ -462,7 +465,7 @@ def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig]
462465
"""
463466
resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config)
464467
try:
465-
return _get_data_files_patterns(resolver)
468+
return _get_data_files_patterns(resolver, base_path=base_path)
466469
except FileNotFoundError:
467470
raise EmptyDatasetError(f"The directory at {base_path} doesn't contain any data files") from None
468471

tests/test_data_files.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
_get_metadata_files_patterns,
1717
_is_inside_unrequested_special_dir,
1818
_is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir,
19+
get_data_patterns,
1920
resolve_pattern,
2021
)
2122
from datasets.fingerprint import Hasher
@@ -634,3 +635,13 @@ def resolver(pattern):
634635
patterns = _get_metadata_files_patterns(resolver)
635636
matched = [file_path for pattern in patterns for file_path in resolver(pattern)]
636637
assert sorted(matched) == sorted(metadata_files)
638+
639+
640+
def test_get_data_patterns_from_directory_with_the_word_data_twice(tmp_path):
641+
repo_dir = tmp_path / "directory-name-ending-with-the-word-data" # parent directory contains the word "data/"
642+
data_dir = repo_dir / "data"
643+
data_dir.mkdir(parents=True)
644+
data_file = data_dir / "train-00001-of-00009.parquet"
645+
data_file.touch()
646+
data_file_patterns = get_data_patterns(repo_dir.as_posix())
647+
assert data_file_patterns == {"train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]}

0 commit comments

Comments
 (0)