Skip to content

Commit 6c398c1

Browse files
authored
Fix to dict conversion of DatasetInfo/Features (#4741)
* Add custom asdict * Add test * One more test * Comment
1 parent fcfcc95 commit 6c398c1

File tree

8 files changed

+74
-13
lines changed

8 files changed

+74
-13
lines changed

src/datasets/arrow_dataset.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from collections import Counter, UserDict
2828
from collections.abc import Mapping
2929
from copy import deepcopy
30-
from dataclasses import asdict
3130
from functools import partial, wraps
3231
from io import BytesIO
3332
from math import ceil, floor
@@ -95,7 +94,7 @@
9594
from .utils._hf_hub_fixes import create_repo
9695
from .utils.file_utils import _retry, cached_path, estimate_dataset_size, hf_hub_url
9796
from .utils.info_utils import is_small_dataset
98-
from .utils.py_utils import convert_file_size_to_int, unique_values
97+
from .utils.py_utils import asdict, convert_file_size_to_int, unique_values
9998
from .utils.stratify import stratified_shuffle_split_generate_indices
10099
from .utils.tf_utils import minimal_tf_collate_fn
101100
from .utils.typing import PathLike

src/datasets/arrow_writer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import json
1717
import os
1818
import sys
19-
from dataclasses import asdict
2019
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
2120

2221
import numpy as np
@@ -39,7 +38,7 @@
3938
from .table import array_cast, cast_array_to_feature, table_cast
4039
from .utils import logging
4140
from .utils.file_utils import hash_url_to_filename
42-
from .utils.py_utils import first_non_null_value
41+
from .utils.py_utils import asdict, first_non_null_value
4342

4443

4544
logger = logging.get_logger(__name__)

src/datasets/features/features.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import re
2020
import sys
2121
from collections.abc import Iterable, Mapping
22-
from dataclasses import InitVar, _asdict_inner, dataclass, field, fields
22+
from dataclasses import InitVar, dataclass, field, fields
2323
from functools import reduce, wraps
2424
from operator import mul
2525
from typing import Any, ClassVar, Dict, List, Optional
@@ -37,7 +37,7 @@
3737
from .. import config
3838
from ..table import array_cast
3939
from ..utils import logging
40-
from ..utils.py_utils import first_non_null_value, zip_dict
40+
from ..utils.py_utils import asdict, first_non_null_value, zip_dict
4141
from .audio import Audio
4242
from .image import Image, encode_pil_image
4343
from .translation import Translation, TranslationVariableLanguages
@@ -1598,7 +1598,7 @@ def from_dict(cls, dic) -> "Features":
15981598
return cls(**obj)
15991599

16001600
def to_dict(self):
1601-
return _asdict_inner(self, dict)
1601+
return asdict(self)
16021602

16031603
def encode_example(self, example):
16041604
"""

src/datasets/fingerprint.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import shutil
66
import tempfile
77
import weakref
8-
from dataclasses import asdict
98
from functools import wraps
109
from pathlib import Path
1110
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@ -19,7 +18,7 @@
1918
from .table import ConcatenationTable, InMemoryTable, MemoryMappedTable, Table
2019
from .utils.deprecation_utils import deprecated
2120
from .utils.logging import get_logger
22-
from .utils.py_utils import dumps
21+
from .utils.py_utils import asdict, dumps
2322

2423

2524
if TYPE_CHECKING:

src/datasets/info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import dataclasses
3333
import json
3434
import os
35-
from dataclasses import asdict, dataclass, field
35+
from dataclasses import dataclass, field
3636
from typing import Dict, List, Optional, Union
3737

3838
from . import config
@@ -41,7 +41,7 @@
4141
from .tasks import TaskTemplate, task_template_from_dict
4242
from .utils import Version
4343
from .utils.logging import get_logger
44-
from .utils.py_utils import unique_values
44+
from .utils.py_utils import asdict, unique_values
4545

4646

4747
logger = get_logger(__name__)

src/datasets/utils/py_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"""
1919

2020
import contextlib
21+
import copy
2122
import functools
2223
import itertools
2324
import os
@@ -26,6 +27,7 @@
2627
import sys
2728
import types
2829
from contextlib import contextmanager
30+
from dataclasses import fields, is_dataclass
2931
from io import BytesIO as StringIO
3032
from multiprocessing import Pool, RLock
3133
from shutil import disk_usage
@@ -151,6 +153,41 @@ def string_to_dict(string: str, pattern: str) -> Dict[str, str]:
151153
return _dict
152154

153155

156+
def asdict(obj):
    """Recursively convert a dataclass instance (or a dict containing some) to plain containers.

    Conversion rules, applied at every nesting level:
      * a dataclass instance becomes a ``dict`` mapping field names to converted values;
      * a namedtuple is rebuilt as the same namedtuple type from converted items;
      * a list or tuple is rebuilt as the same type from converted items;
      * a dict (including dict subclasses) becomes a plain ``dict`` whose keys and
        values are both converted;
      * anything else is deep-copied.

    Raises:
        TypeError: if the top-level ``obj`` is neither a dict nor a dataclass instance.
    """

    # Implementation based on https://docs.python.org/3/library/dataclasses.html#dataclasses.asdict

    def _is_dataclass_instance(candidate):
        # `is_dataclass` is also true for dataclass *types*; only instances qualify here.
        # https://docs.python.org/3/library/dataclasses.html#dataclasses.is_dataclass
        return is_dataclass(candidate) and not isinstance(candidate, type)

    def _convert(node):
        if _is_dataclass_instance(node):
            return {spec.name: _convert(getattr(node, spec.name)) for spec in fields(node)}
        if isinstance(node, tuple) and hasattr(node, "_fields"):
            # namedtuple: its constructor takes positional arguments, not an iterable.
            return type(node)(*[_convert(item) for item in node])
        if isinstance(node, (list, tuple)):
            # Plain sequences (and their subclasses) are assumed to accept an iterable.
            return type(node)(_convert(item) for item in node)
        if isinstance(node, dict):
            return {_convert(key): _convert(val) for key, val in node.items()}
        # Leaf value: copy it so the result shares no mutable state with the input.
        return copy.deepcopy(node)

    if isinstance(obj, dict) or _is_dataclass_instance(obj):
        return _convert(obj)
    raise TypeError(f"{obj} is not a dict or a dataclass")
189+
190+
154191
@contextlib.contextmanager
155192
def temporary_assignment(obj, attr, value):
156193
"""Temporarily assign obj.attr to value."""

tests/features/test_features.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import datetime
2-
from dataclasses import asdict
32
from unittest import TestCase
43
from unittest.mock import patch
54

@@ -20,6 +19,7 @@
2019
)
2120
from datasets.features.translation import Translation, TranslationVariableLanguages
2221
from datasets.info import DatasetInfo
22+
from datasets.utils.py_utils import asdict
2323

2424
from ..utils import require_jax, require_tf, require_torch
2525

@@ -101,6 +101,13 @@ def test_feature_named_type(self):
101101
reloaded_features = Features.from_dict(asdict(ds_info)["features"])
102102
assert features == reloaded_features
103103

104+
def test_class_label_feature_with_no_labels(self):
    """Regression test for issue #4681: a ClassLabel with an empty ``names`` list
    must come back unchanged after going through ``asdict`` and ``Features.from_dict``.
    """
    features = Features({"label": ClassLabel(names=[])})
    # Serialize the whole DatasetInfo, then rebuild only its features entry.
    info_as_dict = asdict(DatasetInfo(features=features))
    reloaded_features = Features.from_dict(info_as_dict["features"])
    assert features == reloaded_features
110+
104111
def test_reorder_fields_as(self):
105112
features = Features(
106113
{

tests/test_py_utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from dataclasses import dataclass
12
from unittest import TestCase
23

34
import numpy as np
45
import pytest
56

6-
from datasets.utils.py_utils import NestedDataStructure, map_nested, temp_seed, temporary_assignment, zip_dict
7+
from datasets.utils.py_utils import NestedDataStructure, asdict, map_nested, temp_seed, temporary_assignment, zip_dict
78

89
from .utils import require_tf, require_torch
910

@@ -16,6 +17,12 @@ def add_one(i): # picklable for multiprocessing
1617
return i + 1
1718

1819

20+
@dataclass
class A:
    """Minimal two-field dataclass used as a fixture by the tests in this module."""

    # Integer field.
    x: int
    # String field.
    y: str
24+
25+
1926
class PyUtilsTest(TestCase):
2027
def test_map_nested(self):
2128
s1 = {}
@@ -175,3 +182,16 @@ def test_nested_data_structure_data(input_data):
175182
def test_flatten(data, expected_output):
176183
output = NestedDataStructure(data).flatten()
177184
assert output == expected_output
185+
186+
187+
def test_asdict():
    """Check ``asdict`` on a dataclass, on nested containers, and on an unsupported top-level type."""
    # A bare dataclass instance becomes a dict keyed by field name.
    # (The local is not called `input`, to avoid shadowing the builtin.)
    instance = A(x=1, y="foobar")
    expected_output = {"x": 1, "y": "foobar"}
    assert asdict(instance) == expected_output

    # Dataclass instances nested inside dicts and lists are converted as well.
    nested = {"a": {"b": A(x=10, y="foo")}, "c": [A(x=20, y="bar")]}
    expected_output = {"a": {"b": {"x": 10, "y": "foo"}}, "c": [{"x": 20, "y": "bar"}]}
    assert asdict(nested) == expected_output

    # A top-level list is neither a dict nor a dataclass instance, so it is rejected.
    with pytest.raises(TypeError):
        asdict([1, A(x=10, y="foo")])

0 commit comments

Comments
 (0)