Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,10 @@ def _get_data_files_patterns(
except FileNotFoundError:
continue
if len(data_files) > 0:
pattern = base_path + ("/" if base_path else "") + split_pattern
splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(pattern))["split"] for p in data_files}
splits: Set[str] = {
string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern)))["split"]
for p in data_files
}
Copy link
Member

@albertvillanova albertvillanova Oct 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are matching just in the basename, then what is the point of having 2 kinds of patterns?

  • ALL_SPLIT_PATTERNS: data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*
  • ALL_DEFAULT_PATTERNS: **/*[{sep}/]{keyword}[{sep}/]**

Maybe I'm missing something, but why do we need the former? I would naively say the latter contains the former.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only ALL_SPLIT_PATTERNS are parsed to infer custom split names,

while ALL_DEFAULT_PATTERNS only detects the default train/validation/test splits.

Copy link
Member

@albertvillanova albertvillanova Oct 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, and what is the point of the directory data/ in ALL_SPLIT_PATTERNS if we only match the basename?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for the old push_to_hub to work: it pushes custom splits using this pattern in the data/ directory.
The new push_to_hub has some YAML to specify the pattern to use, so get_data_patterns isn't called.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, all clear now. Thanks.

sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
splits - set(DEFAULT_SPLITS)
)
Expand Down
33 changes: 26 additions & 7 deletions tests/test_data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,10 @@ def mock_fs(file_paths: List[str]):
["data", "data/train.txt", "data.test.txt"]
```
"""

dir_paths = {file_path.rsplit("/", 1)[0] for file_path in file_paths if "/" in file_path}
file_paths = [file_path.split("://")[-1] for file_path in file_paths]
dir_paths = {
"/".join(file_path.split("/")[: i + 1]) for file_path in file_paths for i in range(file_path.count("/"))
}
fs_contents = [{"name": dir_path, "type": "directory"} for dir_path in dir_paths] + [
{"name": file_path, "type": "file", "size": 10} for file_path in file_paths
]
Expand All @@ -529,6 +531,7 @@ def ls(self, path, detail=True, refresh=True, **kwargs):
return DummyTestFS


@pytest.mark.parametrize("base_path", ["", "mock://", "my_dir"])
@pytest.mark.parametrize(
"data_file_per_split",
[
Expand Down Expand Up @@ -598,20 +601,36 @@ def ls(self, path, detail=True, refresh=True, **kwargs):
{"test": "test00001.txt"},
],
)
def test_get_data_files_patterns(data_file_per_split):
def test_get_data_files_patterns(base_path, data_file_per_split):
data_file_per_split = {k: v if isinstance(v, list) else [v] for k, v in data_file_per_split.items()}
file_paths = [file_path for split_file_paths in data_file_per_split.values() for file_path in split_file_paths]
data_file_per_split = {
split: [
base_path + ("/" if base_path and base_path[-1] != "/" else "") + file_path
for file_path in data_file_per_split[split]
]
for split in data_file_per_split
}
file_paths = sum(data_file_per_split.values(), [])
DummyTestFS = mock_fs(file_paths)
fs = DummyTestFS()

def resolver(pattern):
return [file_path for file_path in fs.glob(pattern) if fs.isfile(file_path)]
pattern = base_path + ("/" if base_path and base_path[-1] != "/" else "") + pattern
return [
file_path[len(fs._strip_protocol(base_path)) :].lstrip("/")
for file_path in fs.glob(pattern)
if fs.isfile(file_path)
]

patterns_per_split = _get_data_files_patterns(resolver)
patterns_per_split = _get_data_files_patterns(resolver, base_path=base_path)
assert list(patterns_per_split.keys()) == list(data_file_per_split.keys()) # Test split order with list()
for split, patterns in patterns_per_split.items():
matched = [file_path for pattern in patterns for file_path in resolver(pattern)]
assert matched == data_file_per_split[split]
expected = [
fs._strip_protocol(file_path)[len(fs._strip_protocol(base_path)) :].lstrip("/")
for file_path in data_file_per_split[split]
]
assert matched == expected


@pytest.mark.parametrize(
Expand Down
20 changes: 20 additions & 0 deletions tests/test_upstream_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
load_dataset_builder,
)
from datasets.config import METADATA_CONFIGS_FIELD
from datasets.data_files import get_data_patterns
from datasets.packaged_modules.folder_based_builder.folder_based_builder import (
FolderBasedBuilder,
FolderBasedBuilderConfig,
Expand Down Expand Up @@ -884,3 +885,22 @@ def test_load_dataset_with_metadata_file(self, temporary_repo, text_file_with_me
generator = builder._generate_examples(**gen_kwargs)
result = [example for _, example in generator]
assert len(result) == 1

def test_get_data_patterns(self, temporary_repo, tmp_path):
    """End-to-end check that ``get_data_patterns`` infers the sharded split pattern
    from a parquet shard uploaded under ``data/`` in a Hub dataset repo."""
    # Local layout mirroring an old-style push_to_hub upload:
    #   <tmp>/test_get_data_patterns/data/train-00001-of-00009.parquet
    local_repo = tmp_path / "test_get_data_patterns"
    (local_repo / "data").mkdir(parents=True)
    (local_repo / "data" / "train-00001-of-00009.parquet").touch()
    with temporary_repo() as repo_id:
        self._api.create_repo(repo_id, token=self._token, repo_type="dataset")
        self._api.upload_folder(
            repo_id=repo_id,
            folder_path=str(local_repo),
            token=self._token,
            repo_type="dataset",
        )
        # The shard name must be matched by the ALL_SPLIT_PATTERNS glob for "train".
        expected = {
            "train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]
        }
        assert get_data_patterns(f"hf://datasets/{repo_id}") == expected