|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -"""Support for keyword-only fields in dataclasses for Python versions <3.10. |
| 15 | +"""This module is kept for backward compatibility. |
16 | 16 |
|
17 | | -This module provides wrappers for `dataclasses.dataclass` and |
18 | | -`dataclasses.field` that simulate support for keyword-only fields for Python |
19 | | -versions before 3.10 (which is the version where dataclasses added keyword-only |
20 | | -field support). If this module is imported in Python 3.10+, then |
21 | | -`kw_only_dataclasses.dataclass` and `kw_only_dataclasses.field` will simply be |
22 | | -aliases for `dataclasses.dataclass` and `dataclasses.field`. |
23 | | -
|
24 | | -For earlier Python versions, when constructing a dataclass, any fields that have |
25 | | -been marked as keyword-only (including inherited fields) will be moved to the |
26 | | -end of the constuctor's argument list. This makes it possible to have a base |
27 | | -class that defines a field with a default, and a subclass that defines a field |
28 | | -without a default. E.g.: |
29 | | -
|
30 | | ->>> from flax.linen import kw_only_dataclasses |
31 | | ->>> @kw_only_dataclasses.dataclass |
32 | | -... class Parent: |
33 | | -... name: str = kw_only_dataclasses.field(default='', kw_only=True) |
34 | | -
|
35 | | ->>> @kw_only_dataclasses.dataclass |
36 | | -... class Child(Parent): |
37 | | -... size: float # required. |
38 | | -
|
39 | | ->>> import inspect |
40 | | ->>> print(inspect.signature(Child.__init__)) |
41 | | -(self, size: float, name: str = '') -> None |
42 | | -
|
43 | | -
|
44 | | -(If we used `dataclasses` rather than `kw_only_dataclasses` for the above |
45 | | -example, then it would have failed with TypeError "non-default argument |
46 | | -'size' follows default argument.") |
47 | | -
|
48 | | -WARNING: fields marked as keyword-only will not *actually* be turned into |
49 | | -keyword-only parameters in the constructor; they will only be moved to the |
50 | | -end of the parameter list (after all non-keyword-only parameters). |
| 17 | +Previous code targeting Python versions <3.10 is removed and wired to |
| 18 | +built-in dataclasses module. |
51 | 19 | """ |
52 | 20 |
|
53 | 21 | import dataclasses |
54 | | -import functools |
55 | | -import inspect |
56 | | -from types import MappingProxyType |
| 22 | +import warnings |
57 | 23 | from typing import Any, TypeVar |
58 | 24 |
|
59 | | -import typing_extensions as tpe |
60 | | - |
61 | 25 | import flax |
62 | 26 |
|
63 | 27 | M = TypeVar('M', bound='flax.linen.Module') |
64 | 28 | FieldName = str |
65 | 29 | Annotation = Any |
66 | 30 | Default = Any |
| 31 | +KW_ONLY = dataclasses.KW_ONLY |
| 32 | +field = dataclasses.field |
| 33 | +dataclass = dataclasses.dataclass |
67 | 34 |
|
68 | | - |
69 | | -class _KwOnlyType: |
70 | | - """Metadata tag used to tag keyword-only fields.""" |
71 | | - |
72 | | - def __repr__(self): |
73 | | - return 'KW_ONLY' |
74 | | - |
75 | | - |
76 | | -KW_ONLY = _KwOnlyType() |
77 | | - |
78 | | - |
79 | | -def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs): |
80 | | - """Wrapper for dataclassess.field that adds support for kw_only fields. |
81 | | -
|
82 | | - Args: |
83 | | - metadata: A mapping or None, containing metadata for the field. |
84 | | - kw_only: If true, the field will be moved to the end of `__init__`'s |
85 | | - parameter list. |
86 | | - **kwargs: Keyword arguments forwarded to `dataclasses.field` |
87 | | -
|
88 | | - Returns: |
89 | | - A `dataclasses.Field` object. |
90 | | - """ |
91 | | - if kw_only is not dataclasses.MISSING and kw_only: |
92 | | - if ( |
93 | | - kwargs.get('default', dataclasses.MISSING) is dataclasses.MISSING |
94 | | - and kwargs.get('default_factory', dataclasses.MISSING) |
95 | | - is dataclasses.MISSING |
96 | | - ): |
97 | | - raise ValueError('Keyword-only fields with no default are not supported.') |
98 | | - if metadata is None: |
99 | | - metadata = {} |
100 | | - metadata[KW_ONLY] = True |
101 | | - return dataclasses.field(metadata=metadata, **kwargs) |
102 | | - |
103 | | - |
104 | | -@tpe.dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] |
105 | | -def dataclass(cls=None, extra_fields=None, **kwargs): |
106 | | - """Wrapper for dataclasses.dataclass that adds support for kw_only fields. |
107 | | -
|
108 | | - Args: |
109 | | - cls: The class to transform (or none to return a decorator). |
110 | | - extra_fields: A list of `(name, type, Field)` tuples describing extra fields |
111 | | - that should be added to the dataclass. This is necessary for linen's |
112 | | - use-case of this module, since the base class (linen.Module) is *not* a |
113 | | - dataclass. In particular, linen.Module class is used as the base for both |
114 | | - frozen and non-frozen dataclass subclasses; but the frozen status of a |
115 | | - dataclass must match the frozen status of any base dataclasses. |
116 | | - **kwargs: Additional arguments for `dataclasses.dataclass`. |
117 | | -
|
118 | | - Returns: |
119 | | - `cls`. |
120 | | - """ |
121 | | - |
122 | | - def wrap(cls): |
123 | | - return _process_class(cls, extra_fields=extra_fields, **kwargs) |
124 | | - |
125 | | - return wrap if cls is None else wrap(cls) |
126 | | - |
127 | | - |
128 | | -def _process_class(cls: type[M], extra_fields=None, **kwargs): |
129 | | - """Transforms `cls` into a dataclass that supports kw_only fields.""" |
130 | | - if '__annotations__' not in cls.__dict__: |
131 | | - cls.__annotations__ = {} |
132 | | - |
133 | | - # The original __dataclass_fields__ dicts for all base classes. We will |
134 | | - # modify these in-place before turning `cls` into a dataclass, and then |
135 | | - # restore them to their original values. |
136 | | - base_dataclass_fields = {} # dict[cls, cls.__dataclass_fields__.copy()] |
137 | | - |
138 | | - # The keyword only fields from `cls` or any of its base classes. |
139 | | - kw_only_fields: dict[FieldName, tuple[Annotation, Default]] = {} |
140 | | - |
141 | | - # Scan for KW_ONLY marker. |
142 | | - kw_only_name = None |
143 | | - for name, annotation in cls.__annotations__.items(): |
144 | | - if annotation is KW_ONLY: |
145 | | - if kw_only_name is not None: |
146 | | - raise TypeError('Multiple KW_ONLY markers') |
147 | | - kw_only_name = name |
148 | | - elif kw_only_name is not None: |
149 | | - if not hasattr(cls, name): |
150 | | - raise ValueError( |
151 | | - 'Keyword-only fields with no default are not supported.' |
152 | | - ) |
153 | | - default = getattr(cls, name) |
154 | | - if isinstance(default, dataclasses.Field): |
155 | | - default.metadata = MappingProxyType({**default.metadata, KW_ONLY: True}) |
156 | | - else: |
157 | | - default = field(default=default, kw_only=True) |
158 | | - setattr(cls, name, default) |
159 | | - if kw_only_name: |
160 | | - del cls.__annotations__[kw_only_name] |
161 | | - |
162 | | - # Inject extra fields. |
163 | | - if extra_fields: |
164 | | - for name, annotation, default in extra_fields: |
165 | | - if not (isinstance(name, str) and isinstance(default, dataclasses.Field)): |
166 | | - raise ValueError( |
167 | | - 'Expected extra_fields to a be a list of ' |
168 | | - '(name, type, Field) tuples.' |
169 | | - ) |
170 | | - setattr(cls, name, default) |
171 | | - cls.__annotations__[name] = annotation |
172 | | - |
173 | | - # Extract kw_only fields from base classes' __dataclass_fields__. |
174 | | - for base in reversed(cls.__mro__[1:]): |
175 | | - if not dataclasses.is_dataclass(base): |
176 | | - continue |
177 | | - base_annotations = base.__dict__.get('__annotations__', {}) |
178 | | - base_dataclass_fields[base] = dict( |
179 | | - getattr(base, '__dataclass_fields__', {}) |
180 | | - ) |
181 | | - for base_field in list(dataclasses.fields(base)): |
182 | | - field_name = base_field.name |
183 | | - if base_field.metadata.get(KW_ONLY) or field_name in kw_only_fields: |
184 | | - kw_only_fields[field_name] = ( |
185 | | - base_annotations.get(field_name), |
186 | | - base_field, |
187 | | - ) |
188 | | - del base.__dataclass_fields__[field_name] |
189 | | - |
190 | | - # Remove any keyword-only fields from this class. |
191 | | - cls_annotations = cls.__dict__['__annotations__'] |
192 | | - for name, annotation in list(cls_annotations.items()): |
193 | | - value = getattr(cls, name, None) |
194 | | - if ( |
195 | | - isinstance(value, dataclasses.Field) and value.metadata.get(KW_ONLY) |
196 | | - ) or name in kw_only_fields: |
197 | | - del cls_annotations[name] |
198 | | - kw_only_fields[name] = (annotation, value) |
199 | | - |
200 | | - # Add keyword-only fields at the end of __annotations__, in the order they |
201 | | - # were found in the base classes and in this class. |
202 | | - for name, (annotation, default) in kw_only_fields.items(): |
203 | | - setattr(cls, name, default) |
204 | | - cls_annotations.pop(name, None) |
205 | | - cls_annotations[name] = annotation |
206 | | - |
207 | | - create_init = '__init__' not in vars(cls) and kwargs.get('init', True) |
208 | | - |
209 | | - # Apply the dataclass transform. |
210 | | - transformed_cls: type[M] = dataclasses.dataclass(cls, **kwargs) |
211 | | - |
212 | | - # Restore the base classes' __dataclass_fields__. |
213 | | - for _cls, fields in base_dataclass_fields.items(): |
214 | | - _cls.__dataclass_fields__ = fields |
215 | | - |
216 | | - if create_init: |
217 | | - dataclass_init = transformed_cls.__init__ |
218 | | - # use sum to count the number of init fields that are not keyword-only |
219 | | - expected_num_args = sum( |
220 | | - f.init and not f.metadata.get(KW_ONLY, False) |
221 | | - for f in dataclasses.fields(transformed_cls) |
222 | | - ) |
223 | | - |
224 | | - @functools.wraps(dataclass_init) |
225 | | - def init_wrapper(self, *args, **kwargs): |
226 | | - num_args = len(args) |
227 | | - if num_args > expected_num_args: |
228 | | - # we add + 1 to each to account for `self`, matching python's |
229 | | - # default error message |
230 | | - raise TypeError( |
231 | | - f'__init__() takes {expected_num_args + 1} positional ' |
232 | | - f'arguments but {num_args + 1} were given' |
233 | | - ) |
234 | | - |
235 | | - dataclass_init(self, *args, **kwargs) |
236 | | - |
237 | | - init_wrapper.__signature__ = inspect.signature(dataclass_init) # type: ignore |
238 | | - transformed_cls.__init__ = init_wrapper # type: ignore[method-assign] |
239 | | - |
240 | | - # Return the transformed dataclass |
241 | | - return transformed_cls |
| 35 | +warnings.warn("This module is deprecated, please use Python built-in dataclasses module") |
0 commit comments