Skip to content

Commit 4c46349

Browse files
committed
more tests
1 parent 84d8397 commit 4c46349

File tree

4 files changed

+150
-58
lines changed

4 files changed

+150
-58
lines changed

src/datasets/builder.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ def incomplete_dir(dirname):
673673
self._fs.rm(dirname, recursive=True)
674674
if is_local:
675675
# LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory
676-
os.rename(tmp_dir, dirname)
676+
os.rename(self._fs._strip_protocol(tmp_dir), self._fs._strip_protocol(dirname))
677677
else:
678678
self._fs.mv(tmp_dir, dirname, recursive=True)
679679
finally:
@@ -691,8 +691,7 @@ def incomplete_dir(dirname):
691691
f"total: {size_str(self.info.size_in_bytes)}) to {self._cache_dir}..."
692692
)
693693
else:
694-
_protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1]
695-
_dest = self._cache_dir if is_local else _protocol + "://" + self._cache_dir
694+
_dest = self._fs._strip_protocol(self._cache_dir) if is_local else self._cache_dir
696695
print(
697696
f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} to {_dest}..."
698697
)
@@ -834,7 +833,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **pr
834833
self.info.download_size = dl_manager.downloaded_size
835834

836835
def download_post_processing_resources(self, dl_manager):
837-
for split in self.info.splits:
836+
for split in self.info.splits or []:
838837
for resource_name, resource_file_name in self._post_processing_resources(split).items():
839838
if not not is_remote_filesystem(self._fs):
840839
raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}")
@@ -1234,9 +1233,9 @@ def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None
12341233
split_info = split_generator.split_info
12351234

12361235
file_format = file_format or "arrow"
1237-
fname = f"{self.name}-{split_generator.name}.{file_format}"
1238-
protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1]
1239-
fpath = protocol + "://" + path_join(self._cache_dir, fname)
1236+
suffix = "-00000-of-00001" if file_format == "parquet" else ""
1237+
fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}"
1238+
fpath = path_join(self._cache_dir, fname)
12401239

12411240
generator = self._generate_examples(**split_generator.gen_kwargs)
12421241

@@ -1315,9 +1314,9 @@ def _prepare_split(self, split_generator, file_format=None):
13151314
path_join = os.path.join if is_local else posixpath.join
13161315

13171316
file_format = file_format or "arrow"
1318-
fname = f"{self.name}-{split_generator.name}.{file_format}"
1319-
protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1]
1320-
fpath = protocol + "://" + path_join(self._cache_dir, fname)
1317+
suffix = "-00000-of-00001" if file_format == "parquet" else ""
1318+
fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}"
1319+
fpath = path_join(self._cache_dir, fname)
13211320

13221321
generator = self._generate_tables(**split_generator.gen_kwargs)
13231322
writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter

tests/fsspec_fixtures.py

Lines changed: 12 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class MockFileSystem(AbstractFileSystem):
1212
def __init__(self, *args, local_root_dir, **kwargs):
1313
super().__init__()
1414
self._fs = LocalFileSystem(*args, **kwargs)
15-
self.local_root_dir = Path(local_root_dir).as_posix()
15+
self.local_root_dir = Path(local_root_dir).resolve().as_posix() + "/"
1616

1717
def mkdir(self, path, *args, **kwargs):
1818
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
@@ -26,45 +26,28 @@ def rmdir(self, path):
2626
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
2727
return self._fs.rmdir(path)
2828

29-
def ls(self, path, *args, **kwargs):
29+
def ls(self, path, detail=True, *args, **kwargs):
3030
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
31-
return self._fs.ls(path, *args, **kwargs)
32-
33-
def glob(self, path, *args, **kwargs):
34-
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
35-
return self._fs.glob(path, *args, **kwargs)
31+
out = self._fs.ls(path, detail=detail, *args, **kwargs)
32+
if detail:
33+
return [{**info, "name": info["name"][len(self.local_root_dir) :]} for info in out]
34+
else:
35+
return [name[len(self.local_root_dir) :] for name in out]
3636

3737
def info(self, path, *args, **kwargs):
3838
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
39-
return self._fs.info(path, *args, **kwargs)
40-
41-
def lexists(self, path, *args, **kwargs):
42-
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
43-
return self._fs.lexists(path, *args, **kwargs)
39+
out = dict(self._fs.info(path, *args, **kwargs))
40+
out["name"] = out["name"][len(self.local_root_dir) :]
41+
return out
4442

4543
def cp_file(self, path1, path2, *args, **kwargs):
4644
path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1))
4745
path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2))
4846
return self._fs.cp_file(path1, path2, *args, **kwargs)
4947

50-
def get_file(self, path1, path2, *args, **kwargs):
51-
path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1))
52-
path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2))
53-
return self._fs.get_file(path1, path2, *args, **kwargs)
54-
55-
def put_file(self, path1, path2, *args, **kwargs):
56-
path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1))
57-
path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2))
58-
return self._fs.put_file(path1, path2, *args, **kwargs)
59-
60-
def mv_file(self, path1, path2, *args, **kwargs):
61-
path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1))
62-
path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2))
63-
return self._fs.mv_file(path1, path2, *args, **kwargs)
64-
65-
def rm_file(self, path):
48+
def rm_file(self, path, *args, **kwargs):
6649
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
67-
return self._fs.rm_file(path)
50+
return self._fs.rm_file(path, *args, **kwargs)
6851

6952
def rm(self, path, *args, **kwargs):
7053
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
@@ -74,14 +57,6 @@ def _open(self, path, *args, **kwargs):
7457
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
7558
return self._fs._open(path, *args, **kwargs)
7659

77-
def open(self, path, *args, **kwargs):
78-
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
79-
return self._fs.open(path, *args, **kwargs)
80-
81-
def touch(self, path, *args, **kwargs):
82-
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
83-
return self._fs.touch(path, *args, **kwargs)
84-
8560
def created(self, path):
8661
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
8762
return self._fs.created(path)
@@ -90,21 +65,13 @@ def modified(self, path):
9065
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
9166
return self._fs.modified(path)
9267

93-
@classmethod
94-
def _parent(cls, path):
95-
return LocalFileSystem._parent(path)
96-
9768
@classmethod
9869
def _strip_protocol(cls, path):
9970
path = stringify_path(path)
10071
if path.startswith("mock://"):
10172
path = path[7:]
10273
return path
10374

104-
def chmod(self, path, *args, **kwargs):
105-
path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
106-
return self._fs.mkdir(path, *args, **kwargs)
107-
10875

10976
@pytest.fixture
11077
def mock_fsspec(monkeypatch):

tests/test_arrow_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
import numpy as np
88
import pyarrow as pa
9-
import pytest
109
import pyarrow.parquet as pq
10+
import pytest
1111

1212
from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, ParquetWriter, TypedSequence
1313
from datasets.features import Array2D, ClassLabel, Features, Image, Value

tests/test_builder.py

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,24 @@
88
from unittest.mock import patch
99

1010
import numpy as np
11+
import pyarrow as pa
12+
import pyarrow.parquet as pq
1113
import pytest
1214
from multiprocess.pool import Pool
1315

1416
from datasets.arrow_dataset import Dataset
1517
from datasets.arrow_writer import ArrowWriter
16-
from datasets.builder import BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
18+
from datasets.builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
1719
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
1820
from datasets.download.download_manager import DownloadMode
1921
from datasets.features import Features, Value
2022
from datasets.info import DatasetInfo, PostProcessedInfo
2123
from datasets.iterable_dataset import IterableDataset
2224
from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo
2325
from datasets.streaming import xjoin
26+
from datasets.utils.file_utils import is_local_path
2427

25-
from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_faiss
28+
from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_beam, require_faiss
2629

2730

2831
class DummyBuilder(DatasetBuilder):
@@ -57,6 +60,35 @@ def _generate_examples(self):
5760
yield i, {"text": "foo"}
5861

5962

63+
class DummyArrowBasedBuilder(ArrowBasedBuilder):
64+
def _info(self):
65+
return DatasetInfo(features=Features({"text": Value("string")}))
66+
67+
def _split_generators(self, dl_manager):
68+
return [SplitGenerator(name=Split.TRAIN)]
69+
70+
def _generate_tables(self):
71+
for i in range(10):
72+
yield i, pa.table({"text": ["foo"] * 10})
73+
74+
75+
class DummyBeamBasedBuilder(BeamBasedBuilder):
76+
def _info(self):
77+
return DatasetInfo(features=Features({"text": Value("string")}))
78+
79+
def _split_generators(self, dl_manager):
80+
return [SplitGenerator(name=Split.TRAIN)]
81+
82+
def _build_pcollection(self, pipeline):
83+
import apache_beam as beam
84+
85+
def _process(item):
86+
for i in range(10):
87+
yield f"{i}_{item}", {"text": "foo"}
88+
89+
return pipeline | "Initialize" >> beam.Create(range(10)) | "Extract content" >> beam.FlatMap(_process)
90+
91+
6092
class DummyGeneratorBasedBuilderWithIntegers(GeneratorBasedBuilder):
6193
def _info(self):
6294
return DatasetInfo(features=Features({"id": Value("int8")}))
@@ -690,6 +722,41 @@ def test_cache_dir_for_data_dir(self):
690722
self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
691723

692724

725+
def test_arrow_based_download_and_prepare(tmp_path):
726+
builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
727+
builder.download_and_prepare()
728+
assert os.path.exists(
729+
os.path.join(
730+
tmp_path,
731+
builder.name,
732+
"default",
733+
"0.0.0",
734+
f"{builder.name}-train.arrow",
735+
)
736+
)
737+
assert builder.info.features, Features({"text": Value("string")})
738+
assert builder.info.splits["train"].num_examples, 100
739+
assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json"))
740+
741+
742+
@require_beam
743+
def test_beam_based_download_and_prepare(tmp_path):
744+
builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
745+
builder.download_and_prepare()
746+
assert os.path.exists(
747+
os.path.join(
748+
tmp_path,
749+
builder.name,
750+
"default",
751+
"0.0.0",
752+
f"{builder.name}-train.arrow",
753+
)
754+
)
755+
assert builder.info.features, Features({"text": Value("string")})
756+
assert builder.info.splits["train"].num_examples, 100
757+
assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json"))
758+
759+
693760
@pytest.mark.parametrize(
694761
"split, expected_dataset_class, expected_dataset_length",
695762
[
@@ -846,3 +913,62 @@ def test_builder_config_version(builder_class, kwargs, tmp_path):
846913
cache_dir = str(tmp_path)
847914
builder = builder_class(cache_dir=cache_dir, **kwargs)
848915
assert builder.config.version == "2.0.0"
916+
917+
918+
def test_builder_with_filesystem(mockfs):
919+
builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options)
920+
assert builder.cache_dir.startswith("mock://")
921+
assert is_local_path(builder._cache_downloaded_dir)
922+
assert isinstance(builder._fs, type(mockfs))
923+
assert builder._fs.storage_options == mockfs.storage_options
924+
925+
926+
def test_builder_with_filesystem_download_and_prepare(mockfs):
927+
builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options)
928+
builder.download_and_prepare()
929+
assert mockfs.exists(f"{builder.name}/default/0.0.0/dataset_info.json")
930+
assert mockfs.exists(f"{builder.name}/default/0.0.0/{builder.name}-train.arrow")
931+
assert not mockfs.exists(f"{builder.name}/default/0.0.0.incomplete")
932+
933+
934+
def test_builder_with_filesystem_download_and_prepare_reload(mockfs, caplog):
935+
builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options)
936+
mockfs.makedirs(f"{builder.name}/default/0.0.0")
937+
DatasetInfo().write_to_directory(f"{builder.name}/default/0.0.0", fs=mockfs)
938+
mockfs.touch(f"{builder.name}/default/0.0.0/{builder.name}-train.arrow")
939+
caplog.clear()
940+
builder.download_and_prepare()
941+
assert "Found cached dataset" in caplog.text
942+
943+
944+
def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path):
945+
builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
946+
builder.download_and_prepare(file_format="parquet")
947+
assert builder.info.splits["train"].num_examples, 100
948+
parquet_path = os.path.join(
949+
tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet"
950+
)
951+
assert os.path.exists(parquet_path)
952+
assert pq.ParquetFile(parquet_path) is not None
953+
954+
955+
def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path):
956+
builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
957+
builder.download_and_prepare(file_format="parquet")
958+
assert builder.info.splits["train"].num_examples, 100
959+
parquet_path = os.path.join(
960+
tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet"
961+
)
962+
assert os.path.exists(parquet_path)
963+
assert pq.ParquetFile(parquet_path) is not None
964+
965+
966+
def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path):
967+
builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
968+
builder.download_and_prepare(file_format="parquet")
969+
assert builder.info.splits["train"].num_examples, 100
970+
parquet_path = os.path.join(
971+
tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet"
972+
)
973+
assert os.path.exists(parquet_path)
974+
assert pq.ParquetFile(parquet_path) is not None

0 commit comments

Comments (0)