|
1 | 1 | import os |
2 | | -from itertools import chain |
| 2 | +from contextlib import contextmanager |
3 | 3 | from pathlib import Path, PurePath |
| 4 | +from typing import List |
4 | 5 | from unittest.mock import patch |
5 | 6 |
|
6 | 7 | import fsspec |
7 | 8 | import pytest |
| 9 | +from fsspec.spec import AbstractBufferedFile, AbstractFileSystem |
8 | 10 | from huggingface_hub.hf_api import DatasetInfo |
9 | 11 |
|
10 | 12 | from datasets.data_files import ( |
@@ -491,6 +493,69 @@ def test_DataFilesDict_from_hf_local_or_remote_hashing(text_file): |
491 | 493 | assert Hasher.hash(data_files1) != Hasher.hash(data_files2) |
492 | 494 |
|
493 | 495 |
|
| 496 | +@contextmanager |
| 497 | +def mock_fs(file_paths: List[str]): |
| 498 | + """context manager to set up a mock:// filesystem in sfspec containing the provided files""" |
| 499 | + |
| 500 | + dir_paths = {file_path.rsplit("/")[0] for file_path in file_paths if "/" in file_path} |
| 501 | + fs_contents = [{"name": dir_path, "type": "directory"} for dir_path in dir_paths] + [ |
| 502 | + {"name": file_path, "type": "file", "size": 10} for file_path in file_paths |
| 503 | + ] |
| 504 | + |
| 505 | + class DummyTestFS(AbstractFileSystem): |
| 506 | + protocol = "mock" |
| 507 | + _file_class = AbstractBufferedFile |
| 508 | + _fs_contents = fs_contents |
| 509 | + |
| 510 | + def __getitem__(self, name): |
| 511 | + for item in self._fs_contents: |
| 512 | + if item["name"] == name: |
| 513 | + return item |
| 514 | + raise IndexError(f"{name} not found!") |
| 515 | + |
| 516 | + def ls(self, path, detail=True, refresh=True, **kwargs): |
| 517 | + if kwargs.pop("strip_proto", True): |
| 518 | + path = self._strip_protocol(path) |
| 519 | + |
| 520 | + files = not refresh and self._ls_from_cache(path) |
| 521 | + if not files: |
| 522 | + files = [file for file in self._fs_contents if path == self._parent(file["name"])] |
| 523 | + files.sort(key=lambda file: file["name"]) |
| 524 | + self.dircache[path.rstrip("/")] = files |
| 525 | + |
| 526 | + if detail: |
| 527 | + return files |
| 528 | + return [file["name"] for file in files] |
| 529 | + |
| 530 | + @classmethod |
| 531 | + def get_test_paths(cls, start_with=""): |
| 532 | + """Helper to return directory and file paths with no details""" |
| 533 | + all = [file["name"] for file in cls._fs_contents if file["name"].startswith(start_with)] |
| 534 | + return all |
| 535 | + |
| 536 | + def _open( |
| 537 | + self, |
| 538 | + path, |
| 539 | + mode="rb", |
| 540 | + block_size=None, |
| 541 | + autocommit=True, |
| 542 | + cache_options=None, |
| 543 | + **kwargs, |
| 544 | + ): |
| 545 | + return self._file_class( |
| 546 | + self, |
| 547 | + path, |
| 548 | + mode, |
| 549 | + block_size, |
| 550 | + autocommit, |
| 551 | + cache_options=cache_options, |
| 552 | + **kwargs, |
| 553 | + ) |
| 554 | + |
| 555 | + with patch.dict(fsspec.registry.target, {"mock": DummyTestFS}): |
| 556 | + yield DummyTestFS() |
| 557 | + |
| 558 | + |
494 | 559 | @pytest.mark.parametrize( |
495 | 560 | "data_file_per_split", |
496 | 561 | [ |
@@ -541,25 +606,25 @@ def test_DataFilesDict_from_hf_local_or_remote_hashing(text_file): |
541 | 606 | {"validation": "dev/dataset.txt"}, |
542 | 607 | # With other extensions |
543 | 608 | {"train": "train.parquet", "test": "test.parquet", "validation": "valid.parquet"}, |
| 609 | + # With "dev" or "eval" without separators |
| 610 | + {"train": "developers_list.txt"}, |
| 611 | + {"train": "data/seqeval_results.txt"}, |
544 | 612 | ], |
545 | 613 | ) |
546 | 614 | def test_get_data_files_patterns(data_file_per_split): |
547 | 615 | data_file_per_split = {k: v if isinstance(v, list) else [v] for k, v in data_file_per_split.items()} |
548 | | - |
549 | | - def resolver(pattern): |
550 | | - return [PurePath(path) for path in chain(*data_file_per_split.values()) if PurePath(path).match(pattern)] |
551 | | - |
552 | | - patterns_per_split = _get_data_files_patterns(resolver) |
553 | | - assert sorted(patterns_per_split.keys()) == sorted(data_file_per_split.keys()) |
554 | | - for split, patterns in patterns_per_split.items(): |
555 | | - matched = [ |
556 | | - path |
557 | | - for path in chain(*data_file_per_split.values()) |
558 | | - for pattern in patterns |
559 | | - if PurePath(path).match(pattern) |
560 | | - ] |
561 | | - assert len(matched) == len(data_file_per_split[split]) |
562 | | - assert matched == data_file_per_split[split] |
| 616 | + with mock_fs( |
| 617 | + [file_path for split_file_paths in data_file_per_split.values() for file_path in split_file_paths] |
| 618 | + ) as fs: |
| 619 | + |
| 620 | + def resolver(pattern): |
| 621 | + return [PurePath(file_path) for file_path in fs.glob(pattern) if fs.isfile(file_path)] |
| 622 | + |
| 623 | + patterns_per_split = _get_data_files_patterns(resolver) |
| 624 | + assert sorted(patterns_per_split.keys()) == sorted(data_file_per_split.keys()) |
| 625 | + for split, patterns in patterns_per_split.items(): |
| 626 | + matched = [str(file_path) for pattern in patterns for file_path in resolver(pattern)] |
| 627 | + assert matched == data_file_per_split[split] |
563 | 628 |
|
564 | 629 |
|
565 | 630 | @pytest.mark.parametrize( |
|
0 commit comments