Skip to content

Commit 3d066ad

Browse files
committed
add tests
1 parent c52f40f commit 3d066ad

File tree

5 files changed

+188
-24
lines changed

5 files changed

+188
-24
lines changed

src/datasets/info.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -313,23 +313,24 @@ def copy(self) -> "DatasetInfo":
313313

314314
def _to_yaml_dict(self) -> dict:
315315
yaml_dict = {}
316-
for field in dataclasses.fields(self):
317-
if field.name in self._INCLUDED_INFO_IN_YAML:
318-
value = getattr(self, field.name)
316+
dataset_info_dict = asdict(self)
317+
for key in dataset_info_dict:
318+
if key in self._INCLUDED_INFO_IN_YAML:
319+
value = getattr(self, key)
319320
if hasattr(value, "_to_yaml_list"): # Features, SplitDict
320-
yaml_dict[field.name] = value._to_yaml_list()
321+
yaml_dict[key] = value._to_yaml_list()
321322
elif hasattr(value, "_to_yaml_string"): # Version
322-
yaml_dict[field.name] = value._to_yaml_string()
323+
yaml_dict[key] = value._to_yaml_string()
323324
else:
324-
yaml_dict[field.name] = value
325+
yaml_dict[key] = value
325326
return yaml_dict
326327

327328
@classmethod
328329
def _from_yaml_dict(cls, yaml_data: dict) -> "DatasetInfo":
329330
yaml_data = copy.deepcopy(yaml_data)
330-
if "features" in yaml_data:
331+
if yaml_data.get("features") is not None:
331332
yaml_data["features"] = Features._from_yaml_list(yaml_data["features"])
332-
if "splits" in yaml_data:
333+
if yaml_data.get("splits") is not None:
333334
yaml_data["splits"] = SplitDict._from_yaml_list(yaml_data["splits"])
334335
field_names = {f.name for f in dataclasses.fields(cls)}
335336
return cls(**{k: v for k, v in yaml_data.items() if k in field_names})
@@ -346,11 +347,10 @@ def write_to_directory(self, dataset_infos_dir, overwrite=False, pretty_print=Fa
346347
if os.path.exists(dataset_infos_path):
347348
# for backward compatibility, let's update the JSON file if it exists
348349
with open(dataset_infos_path, "w", encoding="utf-8") as f:
349-
json.dump(
350-
{config_name: asdict(dset_info) for config_name, dset_info in total_dataset_infos.items()},
351-
f,
352-
indent=4 if pretty_print else None,
353-
)
350+
dataset_infos_dict = {
351+
config_name: asdict(dset_info) for config_name, dset_info in total_dataset_infos.items()
352+
}
353+
json.dump(dataset_infos_dict, f, indent=4 if pretty_print else None)
354354
# Dump the infos in the YAML part of the README.md file
355355
if os.path.exists(dataset_readme_path):
356356
dataset_metadata = DatasetMetadata.from_readme(Path(dataset_readme_path))
@@ -365,6 +365,9 @@ def write_to_directory(self, dataset_infos_dir, overwrite=False, pretty_print=Fa
365365
dataset_metadata["dataset_infos"] = dataset_metadata["dataset_infos"][0]
366366
# no need to include the configuration name when there's only one configuration
367367
dataset_metadata["dataset_infos"].pop("config_name", None)
368+
else:
369+
for config_name, dataset_info_yaml_dict in zip(total_dataset_infos, dataset_metadata["dataset_infos"]):
370+
dataset_info_yaml_dict["config_name"] = config_name
368371
dataset_metadata.to_readme(Path(dataset_readme_path))
369372

370373
@classmethod
@@ -383,7 +386,7 @@ def from_directory(cls, dataset_infos_dir):
383386
# Load the info from the YAML part of README.md
384387
if os.path.exists(os.path.join(dataset_infos_dir, "README.md")):
385388
dataset_metadata = DatasetMetadata.from_readme(Path(dataset_infos_dir) / "README.md")
386-
if isinstance(dataset_metadata.get("dataset_infos"), (list, dict)) and dataset_metadata["dataset_infos"]:
389+
if isinstance(dataset_metadata.get("dataset_infos"), (list, dict)):
387390
if isinstance(dataset_metadata["dataset_infos"], list):
388391
dataset_infos_dict = {
389392
dataset_info_yaml_dict.get("config_name", "default"): DatasetInfo._from_yaml_dict(
@@ -392,11 +395,10 @@ def from_directory(cls, dataset_infos_dir):
392395
for dataset_info_yaml_dict in dataset_metadata["dataset_infos"]
393396
}
394397
else:
395-
dataset_infos_dict = {
396-
dataset_metadata["dataset_infos"].get("config_name", "default"): DatasetInfo._from_yaml_dict(
397-
dataset_metadata["dataset_infos"]
398-
)
399-
}
398+
dataset_info = DatasetInfo._from_yaml_dict(dataset_metadata["dataset_infos"])
399+
dataset_info.config_name = dataset_metadata["dataset_infos"].get("config_name", "default")
400+
dataset_infos_dict = {dataset_info.config_name: dataset_info}
401+
400402
return cls(**dataset_infos_dict)
401403

402404

src/datasets/splits.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818

1919
import abc
2020
import collections
21+
import copy
2122
import dataclasses
2223
import re
23-
from dataclasses import InitVar, dataclass
24+
from dataclasses import dataclass
2425
from typing import Dict, List, Optional, Union
2526

2627
from .arrow_reader import FileInstructions, make_file_instructions
@@ -33,7 +34,14 @@ class SplitInfo:
3334
name: str = ""
3435
num_bytes: int = 0
3536
num_examples: int = 0
36-
dataset_name: InitVar[Optional[str]] = None # Pseudo-field: ignored by asdict/fields when converting to/from dict
37+
38+
# Deprecated
39+
# For backward compatibility, this field needs to always be included in files like
40+
# dataset_infos.json and dataset_info.json files
41+
# To do so, we always include it in the output of datasets.utils.py_utils.asdict(split_info)
42+
dataset_name: Optional[str] = dataclasses.field(
43+
default=None, metadata={"include_in_asdict_even_if_is_default": True}
44+
)
3745

3846
@property
3947
def file_instructions(self):
@@ -560,13 +568,22 @@ def from_split_dict(cls, split_infos: Union[List, Dict], dataset_name: Optional[
560568
def to_split_dict(self):
561569
"""Returns a list of SplitInfo protos that we have."""
562570
# Return the SplitInfo, sorted by name
563-
return sorted((s for s in self.values()), key=lambda s: s.name)
571+
out = []
572+
for split_name, split_info in sorted(self.items()):
573+
split_info = copy.deepcopy(split_info)
574+
split_info.name = split_name
575+
out.append(split_info)
576+
return out
564577

565578
def copy(self):
566579
return SplitDict.from_split_dict(self.to_split_dict(), self.dataset_name)
567580

568581
def _to_yaml_list(self) -> list:
569-
return [asdict(s) for s in self.to_split_dict()]
582+
out = [asdict(s) for s in self.to_split_dict()]
583+
# we don't need the dataset_name attribute that is deprecated
584+
for split_info_dict in out:
585+
split_info_dict.pop("dataset_name", None)
586+
return out
570587

571588
@classmethod
572589
def _from_yaml_list(cls, yaml_data: list) -> "SplitDict":

src/datasets/utils/py_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def _asdict_inner(obj):
167167
result = {}
168168
for f in fields(obj):
169169
value = _asdict_inner(getattr(obj, f.name))
170-
if value != f.default or not f.init:
170+
if not f.init or value != f.default or f.metadata.get("include_in_asdict_even_if_is_default", False):
171171
result[f.name] = value
172172
return result
173173
elif isinstance(obj, tuple) and hasattr(obj, "_fields"):

tests/test_info.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import os
2+
3+
import pytest
4+
import yaml
5+
6+
from datasets.features.features import Features, Value
7+
from datasets.info import DatasetInfo, DatasetInfosDict
8+
9+
10+
@pytest.mark.parametrize(
11+
"dataset_info",
12+
[
13+
DatasetInfo(),
14+
DatasetInfo(
15+
description="foo",
16+
features=Features({"a": Value("int32")}),
17+
builder_name="builder",
18+
config_name="config",
19+
version="1.0.0",
20+
splits=[{"name": "train"}],
21+
download_size=42,
22+
),
23+
],
24+
)
25+
def test_dataset_info_dump_and_reload(tmp_path, dataset_info: DatasetInfo):
26+
tmp_path = str(tmp_path)
27+
dataset_info.write_to_directory(tmp_path)
28+
reloaded = DatasetInfo.from_directory(tmp_path)
29+
assert dataset_info == reloaded
30+
assert os.path.exists(os.path.join(tmp_path, "dataset_info.json"))
31+
32+
33+
def test_dataset_info_to_yaml_dict():
34+
dataset_info = DatasetInfo(
35+
description="foo",
36+
citation="bar",
37+
homepage="https://foo.bar",
38+
license="CC0",
39+
features=Features({"a": Value("int32")}),
40+
post_processed={},
41+
supervised_keys=tuple(),
42+
task_templates=[],
43+
builder_name="builder",
44+
config_name="config",
45+
version="1.0.0",
46+
splits=[{"name": "train", "num_examples": 42}],
47+
download_checksums={},
48+
download_size=1337,
49+
post_processing_size=442,
50+
dataset_size=1234,
51+
size_in_bytes=1337 + 442 + 1234,
52+
)
53+
dataset_info_yaml_dict = dataset_info._to_yaml_dict()
54+
assert sorted(dataset_info_yaml_dict) == sorted(DatasetInfo._INCLUDED_INFO_IN_YAML)
55+
for key in DatasetInfo._INCLUDED_INFO_IN_YAML:
56+
assert key in dataset_info_yaml_dict
57+
assert isinstance(dataset_info_yaml_dict[key], (list, dict, int, str))
58+
dataset_info_yaml = yaml.safe_dump(dataset_info_yaml_dict)
59+
reloaded = yaml.safe_load(dataset_info_yaml)
60+
assert dataset_info_yaml_dict == reloaded
61+
62+
63+
def test_dataset_info_to_yaml_dict_empty():
64+
dataset_info = DatasetInfo()
65+
dataset_info_yaml_dict = dataset_info._to_yaml_dict()
66+
assert dataset_info_yaml_dict == {}
67+
68+
69+
@pytest.mark.parametrize(
70+
"dataset_infos_dict",
71+
[
72+
DatasetInfosDict(),
73+
DatasetInfosDict({"default": DatasetInfo()}),
74+
DatasetInfosDict(
75+
{
76+
"default": DatasetInfo(
77+
description="foo",
78+
features=Features({"a": Value("int32")}),
79+
builder_name="builder",
80+
config_name="config",
81+
version="1.0.0",
82+
splits=[{"name": "train"}],
83+
download_size=42,
84+
)
85+
}
86+
),
87+
DatasetInfosDict(
88+
{
89+
"v1": DatasetInfo(dataset_size=42),
90+
"v2": DatasetInfo(dataset_size=1337),
91+
}
92+
),
93+
],
94+
)
95+
def test_dataset_infos_dict_dump_and_reload(tmp_path, dataset_infos_dict: DatasetInfosDict):
96+
tmp_path = str(tmp_path)
97+
dataset_infos_dict.write_to_directory(tmp_path)
98+
reloaded = DatasetInfosDict.from_directory(tmp_path)
99+
100+
# the config_name of the dataset_infos_dict take over the attribute
101+
for config_name, dataset_info in dataset_infos_dict.items():
102+
dataset_info.config_name = config_name
103+
# the yaml representation doesn't include fields like description or citation
104+
# so we just test that we can recover what we can from the yaml
105+
dataset_infos_dict[config_name] = DatasetInfo._from_yaml_dict(dataset_info._to_yaml_dict())
106+
assert dataset_infos_dict == reloaded
107+
108+
if dataset_infos_dict:
109+
assert os.path.exists(os.path.join(tmp_path, "README.md"))

tests/test_splits.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
3+
from datasets.splits import SplitDict, SplitInfo
4+
from datasets.utils.py_utils import asdict
5+
6+
7+
@pytest.mark.parametrize(
8+
"split_dict",
9+
[
10+
SplitDict(),
11+
SplitDict({"train": SplitInfo(name="train", num_bytes=1337, num_examples=42, dataset_name="my_dataset")}),
12+
SplitDict({"train": SplitInfo(name="train", num_bytes=1337, num_examples=42)}),
13+
SplitDict({"train": SplitInfo()}),
14+
],
15+
)
16+
def test_split_dict_to_yaml_list(split_dict: SplitDict):
17+
split_dict_yaml_list = split_dict._to_yaml_list()
18+
assert len(split_dict_yaml_list) == len(split_dict)
19+
reloaded = SplitDict._from_yaml_list(split_dict_yaml_list)
20+
for split_name, split_info in split_dict.items():
21+
# dataset_name field is deprecated, and is therefore not part of the YAML dump
22+
split_info.dataset_name = None
23+
# the split name of split_dict takes over the name of the split info object
24+
split_info.name = split_name
25+
assert split_dict == reloaded
26+
27+
28+
@pytest.mark.parametrize(
29+
"split_info", [SplitInfo(), SplitInfo(dataset_name=None), SplitInfo(dataset_name="my_dataset")]
30+
)
31+
def test_split_dict_asdict_has_dataset_name(split_info):
32+
# For backward compatibility, we need asdict(split_dict) to return split info dictionaries with the "dataset_name"
33+
# field even if it's deprecated. This way old versions of `datasets` can still reload dataset_infos.json files
34+
split_dict_asdict = asdict(SplitDict({"train": split_info}))
35+
assert "dataset_name" in split_dict_asdict["train"]
36+
assert split_dict_asdict["train"]["dataset_name"] == split_info.dataset_name

0 commit comments

Comments
 (0)