Skip to content

Commit 02ecc84

Browse files
ZachNagengast, albertvillanova, and lhoestq
authored
Fix regex get_data_files formatting for base paths (#6322)
* Fix regex from formatting url base_path
* Test test_get_data_patterns from Hub
* simply match basename instead
* more tests
* minor
* remove comment

---------

Co-authored-by: Albert Villanova del Moral <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
1 parent d82f3c2 commit 02ecc84

File tree

3 files changed

+50
-9
lines changed

3 files changed

+50
-9
lines changed

src/datasets/data_files.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,10 @@ def _get_data_files_patterns(
244244
except FileNotFoundError:
245245
continue
246246
if len(data_files) > 0:
247-
pattern = base_path + ("/" if base_path else "") + split_pattern
248-
splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(pattern))["split"] for p in data_files}
247+
splits: Set[str] = {
248+
string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern)))["split"]
249+
for p in data_files
250+
}
249251
sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
250252
splits - set(DEFAULT_SPLITS)
251253
)

tests/test_data_files.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -502,8 +502,10 @@ def mock_fs(file_paths: List[str]):
502502
["data", "data/train.txt", "data.test.txt"]
503503
```
504504
"""
505-
506-
dir_paths = {file_path.rsplit("/", 1)[0] for file_path in file_paths if "/" in file_path}
505+
file_paths = [file_path.split("://")[-1] for file_path in file_paths]
506+
dir_paths = {
507+
"/".join(file_path.split("/")[: i + 1]) for file_path in file_paths for i in range(file_path.count("/"))
508+
}
507509
fs_contents = [{"name": dir_path, "type": "directory"} for dir_path in dir_paths] + [
508510
{"name": file_path, "type": "file", "size": 10} for file_path in file_paths
509511
]
@@ -529,6 +531,7 @@ def ls(self, path, detail=True, refresh=True, **kwargs):
529531
return DummyTestFS
530532

531533

534+
@pytest.mark.parametrize("base_path", ["", "mock://", "my_dir"])
532535
@pytest.mark.parametrize(
533536
"data_file_per_split",
534537
[
@@ -598,20 +601,36 @@ def ls(self, path, detail=True, refresh=True, **kwargs):
598601
{"test": "test00001.txt"},
599602
],
600603
)
601-
def test_get_data_files_patterns(data_file_per_split):
604+
def test_get_data_files_patterns(base_path, data_file_per_split):
602605
data_file_per_split = {k: v if isinstance(v, list) else [v] for k, v in data_file_per_split.items()}
603-
file_paths = [file_path for split_file_paths in data_file_per_split.values() for file_path in split_file_paths]
606+
data_file_per_split = {
607+
split: [
608+
base_path + ("/" if base_path and base_path[-1] != "/" else "") + file_path
609+
for file_path in data_file_per_split[split]
610+
]
611+
for split in data_file_per_split
612+
}
613+
file_paths = sum(data_file_per_split.values(), [])
604614
DummyTestFS = mock_fs(file_paths)
605615
fs = DummyTestFS()
606616

607617
def resolver(pattern):
608-
return [file_path for file_path in fs.glob(pattern) if fs.isfile(file_path)]
618+
pattern = base_path + ("/" if base_path and base_path[-1] != "/" else "") + pattern
619+
return [
620+
file_path[len(fs._strip_protocol(base_path)) :].lstrip("/")
621+
for file_path in fs.glob(pattern)
622+
if fs.isfile(file_path)
623+
]
609624

610-
patterns_per_split = _get_data_files_patterns(resolver)
625+
patterns_per_split = _get_data_files_patterns(resolver, base_path=base_path)
611626
assert list(patterns_per_split.keys()) == list(data_file_per_split.keys()) # Test split order with list()
612627
for split, patterns in patterns_per_split.items():
613628
matched = [file_path for pattern in patterns for file_path in resolver(pattern)]
614-
assert matched == data_file_per_split[split]
629+
expected = [
630+
fs._strip_protocol(file_path)[len(fs._strip_protocol(base_path)) :].lstrip("/")
631+
for file_path in data_file_per_split[split]
632+
]
633+
assert matched == expected
615634

616635

617636
@pytest.mark.parametrize(

tests/test_upstream_hub.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
load_dataset_builder,
2828
)
2929
from datasets.config import METADATA_CONFIGS_FIELD
30+
from datasets.data_files import get_data_patterns
3031
from datasets.packaged_modules.folder_based_builder.folder_based_builder import (
3132
FolderBasedBuilder,
3233
FolderBasedBuilderConfig,
@@ -884,3 +885,22 @@ def test_load_dataset_with_metadata_file(self, temporary_repo, text_file_with_me
884885
generator = builder._generate_examples(**gen_kwargs)
885886
result = [example for _, example in generator]
886887
assert len(result) == 1
888+
889+
def test_get_data_patterns(self, temporary_repo, tmp_path):
890+
repo_dir = tmp_path / "test_get_data_patterns"
891+
data_dir = repo_dir / "data"
892+
data_dir.mkdir(parents=True)
893+
data_file = data_dir / "train-00001-of-00009.parquet"
894+
data_file.touch()
895+
with temporary_repo() as repo_id:
896+
self._api.create_repo(repo_id, token=self._token, repo_type="dataset")
897+
self._api.upload_folder(
898+
folder_path=str(repo_dir),
899+
repo_id=repo_id,
900+
repo_type="dataset",
901+
token=self._token,
902+
)
903+
data_file_patterns = get_data_patterns(f"hf://datasets/{repo_id}")
904+
assert data_file_patterns == {
905+
"train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]
906+
}

0 commit comments

Comments (0)