Skip to content

Commit aa49caa

Browse files
committed
Support for python 3.14
Fixes #5027
1 parent 7710c30 commit aa49caa

File tree

6 files changed

+84
-272
lines changed

6 files changed

+84
-272
lines changed

.github/workflows/flax_test.yml

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ jobs:
6363
runs-on: ubuntu-latest
6464
strategy:
6565
matrix:
66-
python-version: ['3.11', '3.12', '3.13']
66+
python-version: ['3.11', '3.12', '3.13', '3.14']
6767
steps:
6868
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
6969
- name: Set up Python ${{ matrix.python-version }}
@@ -170,3 +170,29 @@ jobs:
170170
"description": "'$status'",
171171
"context": "github-actions/Build"
172172
}'
173+
174+
# This is a temporary workflow to test flax on Python 3.14 and
175+
# skipping deps like tensorstore, tensorflow etc
176+
tests-python314:
177+
name: Run Tests on Python 3.14
178+
needs: [pre-commit, commit-count]
179+
runs-on: ubuntu-24.04-16core
180+
steps:
181+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
182+
- name: Setup uv
183+
uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
184+
with:
185+
version: "0.9.2"
186+
python-version: "3.14"
187+
activate-environment: true
188+
enable-cache: true
189+
190+
- name: Install dependencies
191+
run: |
192+
rm -fr .venv
193+
uv sync --extra testing --extra docs
194+
- name: Test with pytest
195+
run: |
196+
export XLA_FLAGS='--xla_force_host_platform_device_count=4'
197+
find tests/ -name "*.py" | grep -vE 'io_test|tensorboard' | xargs pytest -n auto
198+

flax/linen/kw_only_dataclasses.py

Lines changed: 8 additions & 214 deletions
Original file line numberDiff line numberDiff line change
@@ -12,230 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Support for keyword-only fields in dataclasses for Python versions <3.10.
15+
"""This module is kept for backward compatibility.
1616
17-
This module provides wrappers for `dataclasses.dataclass` and
18-
`dataclasses.field` that simulate support for keyword-only fields for Python
19-
versions before 3.10 (which is the version where dataclasses added keyword-only
20-
field support). If this module is imported in Python 3.10+, then
21-
`kw_only_dataclasses.dataclass` and `kw_only_dataclasses.field` will simply be
22-
aliases for `dataclasses.dataclass` and `dataclasses.field`.
23-
24-
For earlier Python versions, when constructing a dataclass, any fields that have
25-
been marked as keyword-only (including inherited fields) will be moved to the
26-
end of the constuctor's argument list. This makes it possible to have a base
27-
class that defines a field with a default, and a subclass that defines a field
28-
without a default. E.g.:
29-
30-
>>> from flax.linen import kw_only_dataclasses
31-
>>> @kw_only_dataclasses.dataclass
32-
... class Parent:
33-
... name: str = kw_only_dataclasses.field(default='', kw_only=True)
34-
35-
>>> @kw_only_dataclasses.dataclass
36-
... class Child(Parent):
37-
... size: float # required.
38-
39-
>>> import inspect
40-
>>> print(inspect.signature(Child.__init__))
41-
(self, size: float, name: str = '') -> None
42-
43-
44-
(If we used `dataclasses` rather than `kw_only_dataclasses` for the above
45-
example, then it would have failed with TypeError "non-default argument
46-
'size' follows default argument.")
47-
48-
WARNING: fields marked as keyword-only will not *actually* be turned into
49-
keyword-only parameters in the constructor; they will only be moved to the
50-
end of the parameter list (after all non-keyword-only parameters).
17+
Previous code targeting Python versions <3.10 is removed and wired to
18+
built-in dataclasses module.
5119
"""
5220

5321
import dataclasses
54-
import functools
55-
import inspect
56-
from types import MappingProxyType
22+
import warnings
5723
from typing import Any, TypeVar
5824

59-
import typing_extensions as tpe
60-
6125
import flax
6226

6327
M = TypeVar('M', bound='flax.linen.Module')
6428
FieldName = str
6529
Annotation = Any
6630
Default = Any
31+
KW_ONLY = dataclasses.KW_ONLY
32+
field = dataclasses.field
33+
dataclass = dataclasses.dataclass
6734

68-
69-
class _KwOnlyType:
70-
"""Metadata tag used to tag keyword-only fields."""
71-
72-
def __repr__(self):
73-
return 'KW_ONLY'
74-
75-
76-
KW_ONLY = _KwOnlyType()
77-
78-
79-
def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs):
80-
"""Wrapper for dataclassess.field that adds support for kw_only fields.
81-
82-
Args:
83-
metadata: A mapping or None, containing metadata for the field.
84-
kw_only: If true, the field will be moved to the end of `__init__`'s
85-
parameter list.
86-
**kwargs: Keyword arguments forwarded to `dataclasses.field`
87-
88-
Returns:
89-
A `dataclasses.Field` object.
90-
"""
91-
if kw_only is not dataclasses.MISSING and kw_only:
92-
if (
93-
kwargs.get('default', dataclasses.MISSING) is dataclasses.MISSING
94-
and kwargs.get('default_factory', dataclasses.MISSING)
95-
is dataclasses.MISSING
96-
):
97-
raise ValueError('Keyword-only fields with no default are not supported.')
98-
if metadata is None:
99-
metadata = {}
100-
metadata[KW_ONLY] = True
101-
return dataclasses.field(metadata=metadata, **kwargs)
102-
103-
104-
@tpe.dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
105-
def dataclass(cls=None, extra_fields=None, **kwargs):
106-
"""Wrapper for dataclasses.dataclass that adds support for kw_only fields.
107-
108-
Args:
109-
cls: The class to transform (or none to return a decorator).
110-
extra_fields: A list of `(name, type, Field)` tuples describing extra fields
111-
that should be added to the dataclass. This is necessary for linen's
112-
use-case of this module, since the base class (linen.Module) is *not* a
113-
dataclass. In particular, linen.Module class is used as the base for both
114-
frozen and non-frozen dataclass subclasses; but the frozen status of a
115-
dataclass must match the frozen status of any base dataclasses.
116-
**kwargs: Additional arguments for `dataclasses.dataclass`.
117-
118-
Returns:
119-
`cls`.
120-
"""
121-
122-
def wrap(cls):
123-
return _process_class(cls, extra_fields=extra_fields, **kwargs)
124-
125-
return wrap if cls is None else wrap(cls)
126-
127-
128-
def _process_class(cls: type[M], extra_fields=None, **kwargs):
129-
"""Transforms `cls` into a dataclass that supports kw_only fields."""
130-
if '__annotations__' not in cls.__dict__:
131-
cls.__annotations__ = {}
132-
133-
# The original __dataclass_fields__ dicts for all base classes. We will
134-
# modify these in-place before turning `cls` into a dataclass, and then
135-
# restore them to their original values.
136-
base_dataclass_fields = {} # dict[cls, cls.__dataclass_fields__.copy()]
137-
138-
# The keyword only fields from `cls` or any of its base classes.
139-
kw_only_fields: dict[FieldName, tuple[Annotation, Default]] = {}
140-
141-
# Scan for KW_ONLY marker.
142-
kw_only_name = None
143-
for name, annotation in cls.__annotations__.items():
144-
if annotation is KW_ONLY:
145-
if kw_only_name is not None:
146-
raise TypeError('Multiple KW_ONLY markers')
147-
kw_only_name = name
148-
elif kw_only_name is not None:
149-
if not hasattr(cls, name):
150-
raise ValueError(
151-
'Keyword-only fields with no default are not supported.'
152-
)
153-
default = getattr(cls, name)
154-
if isinstance(default, dataclasses.Field):
155-
default.metadata = MappingProxyType({**default.metadata, KW_ONLY: True})
156-
else:
157-
default = field(default=default, kw_only=True)
158-
setattr(cls, name, default)
159-
if kw_only_name:
160-
del cls.__annotations__[kw_only_name]
161-
162-
# Inject extra fields.
163-
if extra_fields:
164-
for name, annotation, default in extra_fields:
165-
if not (isinstance(name, str) and isinstance(default, dataclasses.Field)):
166-
raise ValueError(
167-
'Expected extra_fields to a be a list of '
168-
'(name, type, Field) tuples.'
169-
)
170-
setattr(cls, name, default)
171-
cls.__annotations__[name] = annotation
172-
173-
# Extract kw_only fields from base classes' __dataclass_fields__.
174-
for base in reversed(cls.__mro__[1:]):
175-
if not dataclasses.is_dataclass(base):
176-
continue
177-
base_annotations = base.__dict__.get('__annotations__', {})
178-
base_dataclass_fields[base] = dict(
179-
getattr(base, '__dataclass_fields__', {})
180-
)
181-
for base_field in list(dataclasses.fields(base)):
182-
field_name = base_field.name
183-
if base_field.metadata.get(KW_ONLY) or field_name in kw_only_fields:
184-
kw_only_fields[field_name] = (
185-
base_annotations.get(field_name),
186-
base_field,
187-
)
188-
del base.__dataclass_fields__[field_name]
189-
190-
# Remove any keyword-only fields from this class.
191-
cls_annotations = cls.__dict__['__annotations__']
192-
for name, annotation in list(cls_annotations.items()):
193-
value = getattr(cls, name, None)
194-
if (
195-
isinstance(value, dataclasses.Field) and value.metadata.get(KW_ONLY)
196-
) or name in kw_only_fields:
197-
del cls_annotations[name]
198-
kw_only_fields[name] = (annotation, value)
199-
200-
# Add keyword-only fields at the end of __annotations__, in the order they
201-
# were found in the base classes and in this class.
202-
for name, (annotation, default) in kw_only_fields.items():
203-
setattr(cls, name, default)
204-
cls_annotations.pop(name, None)
205-
cls_annotations[name] = annotation
206-
207-
create_init = '__init__' not in vars(cls) and kwargs.get('init', True)
208-
209-
# Apply the dataclass transform.
210-
transformed_cls: type[M] = dataclasses.dataclass(cls, **kwargs)
211-
212-
# Restore the base classes' __dataclass_fields__.
213-
for _cls, fields in base_dataclass_fields.items():
214-
_cls.__dataclass_fields__ = fields
215-
216-
if create_init:
217-
dataclass_init = transformed_cls.__init__
218-
# use sum to count the number of init fields that are not keyword-only
219-
expected_num_args = sum(
220-
f.init and not f.metadata.get(KW_ONLY, False)
221-
for f in dataclasses.fields(transformed_cls)
222-
)
223-
224-
@functools.wraps(dataclass_init)
225-
def init_wrapper(self, *args, **kwargs):
226-
num_args = len(args)
227-
if num_args > expected_num_args:
228-
# we add + 1 to each to account for `self`, matching python's
229-
# default error message
230-
raise TypeError(
231-
f'__init__() takes {expected_num_args + 1} positional '
232-
f'arguments but {num_args + 1} were given'
233-
)
234-
235-
dataclass_init(self, *args, **kwargs)
236-
237-
init_wrapper.__signature__ = inspect.signature(dataclass_init) # type: ignore
238-
transformed_cls.__init__ = init_wrapper # type: ignore[method-assign]
239-
240-
# Return the transformed dataclass
241-
return transformed_cls
35+
warnings.warn("This module is deprecated, please use Python built-in dataclasses module")

flax/linen/module.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import enum
2020
import functools
2121
import inspect
22-
import sys
2322
import threading
2423
import typing
2524
import weakref
@@ -57,7 +56,6 @@
5756
union_filters,
5857
)
5958
from flax.ids import FlaxId, uuid
60-
from flax.linen import kw_only_dataclasses
6159
from flax.typing import (
6260
RNGSequences,
6361
PRNGKey,
@@ -1061,7 +1059,7 @@ def _customized_dataclass_transform(cls, kw_only: bool):
10611059
3. Generate a hash function (if not provided by cls).
10621060
"""
10631061
# Check reserved attributes have expected type annotations.
1064-
annotations = dict(cls.__dict__.get('__annotations__', {}))
1062+
annotations = inspect.get_annotations(cls)
10651063
if annotations.get('parent', _ParentType) != _ParentType:
10661064
raise errors.ReservedModuleAttributeError(annotations)
10671065
if annotations.get('name', str) not in ('str', str, Optional[str]):
@@ -1081,42 +1079,29 @@ def _customized_dataclass_transform(cls, kw_only: bool):
10811079
(
10821080
'parent',
10831081
_ParentType,
1084-
kw_only_dataclasses.field(
1082+
dataclasses.field(
10851083
repr=False, default=_unspecified_parent, kw_only=True
10861084
),
10871085
),
10881086
(
10891087
'name',
10901088
Optional[str],
1091-
kw_only_dataclasses.field(default=None, kw_only=True),
1089+
dataclasses.field(default=None, kw_only=True),
10921090
),
10931091
]
10941092

1095-
if kw_only:
1096-
if tuple(sys.version_info)[:3] >= (3, 10, 0):
1097-
for (
1098-
name,
1099-
annotation, # pytype: disable=invalid-annotation
1100-
default,
1101-
) in extra_fields:
1102-
setattr(cls, name, default)
1103-
cls.__annotations__[name] = annotation
1104-
dataclasses.dataclass( # type: ignore[call-overload]
1105-
unsafe_hash='__hash__' not in cls.__dict__,
1106-
repr=False,
1107-
kw_only=True,
1108-
)(cls)
1109-
else:
1110-
raise TypeError('`kw_only` is not available before Py 3.10.')
1111-
else:
1112-
# Now apply dataclass transform (which operates in-place).
1113-
# Do generate a hash function only if not provided by the class.
1114-
kw_only_dataclasses.dataclass(
1115-
cls,
1116-
unsafe_hash='__hash__' not in cls.__dict__,
1117-
repr=False,
1118-
extra_fields=extra_fields,
1119-
) # pytype: disable=wrong-keyword-args
1093+
for (
1094+
name,
1095+
annotation, # pytype: disable=invalid-annotation
1096+
default,
1097+
) in extra_fields:
1098+
setattr(cls, name, default)
1099+
cls.__annotations__[name] = annotation
1100+
dataclasses.dataclass( # type: ignore[call-overload]
1101+
unsafe_hash='__hash__' not in cls.__dict__,
1102+
repr=False,
1103+
kw_only=kw_only,
1104+
)(cls)
11201105

11211106
cls.__hash__ = _wrap_hash(cls.__hash__) # type: ignore[method-assign]
11221107

pyproject.toml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ testing = [
4040
"clu",
4141
"clu<=0.0.9; python_version<'3.10'",
4242
"einops",
43-
"gymnasium[atari]",
43+
"gymnasium[atari]; python_version<'3.14'",
4444
"jaxlib",
4545
"jaxtyping",
4646
"jraph>=0.0.6dev0",
@@ -61,11 +61,11 @@ testing = [
6161
"tensorflow_text>=2.11.0; platform_system!='Darwin' and python_version < '3.13'",
6262
"tensorflow_datasets",
6363
"tensorflow>=2.12.0; python_version<'3.13'", # to fix Numpy np.bool8 deprecation error
64-
"tensorflow>=2.20.0; python_version>='3.13'",
64+
"tensorflow>=2.20.0; python_version>='3.13' and python_version<'3.14'",
6565
"torch",
6666
"treescope>=0.1.1; python_version>='3.10'",
6767
"cloudpickle>=3.0.0",
68-
"ale-py>=0.10.2",
68+
"ale-py>=0.10.2; python_version<'3.14'",
6969
]
7070
docs = [
7171
"sphinx==6.2.1",
@@ -237,3 +237,11 @@ quote-style = "single"
237237
[tool.uv]
238238
# Ignore uv.lock and always upgrade the package to the latest
239239
upgrade-package = ["jax", "jaxlib", "orbax-checkpoint"]
240+
241+
[tool.uv.sources]
242+
torch = { index = "pytorch" }
243+
244+
[[tool.uv.index]]
245+
name = "pytorch"
246+
url = "https://download.pytorch.org/whl/cpu"
247+
explicit = true

0 commit comments

Comments
 (0)