Skip to content

Commit 1014b25

Browse files
committed
Added workaround by P.Hawkins
1 parent 0ad978f commit 1014b25

File tree

3 files changed

+30
-0
lines changed

3 files changed

+30
-0
lines changed

.github/workflows/flax_test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ jobs:
191191
run: |
192192
rm -fr .venv
193193
uv sync --extra testing --extra docs
194+
# temporary: install jax nightly
195+
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
194196
- name: Test with pytest
195197
run: |
196198
export XLA_FLAGS='--xla_force_host_platform_device_count=4'

flax/linen/module.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import enum
2020
import functools
2121
import inspect
22+
import sys
2223
import threading
2324
import typing
2425
import weakref
@@ -1097,6 +1098,12 @@ def _customized_dataclass_transform(cls, kw_only: bool):
10971098
) in extra_fields:
10981099
setattr(cls, name, default)
10991100
cls.__annotations__[name] = annotation
1101+
1102+
# TODO: a workaround for the issue:
1103+
# https://github.com/google/flax/pull/5087#issuecomment-3536610568
1104+
if (sys.version_info.major, sys.version_info.minor) in [(3, 12), (3, 13)]:
1105+
setattr(cls, '__annotations__', cls.__annotations__)
1106+
11001107
dataclasses.dataclass( # type: ignore[call-overload]
11011108
unsafe_hash='__hash__' not in cls.__dict__,
11021109
repr=False,

tests/linen/linen_module_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,6 +2315,27 @@ class Foo(nn.Module):
23152315
Foo(1, None)
23162316
Foo(a=1, parent=None) # type: ignore[call-arg]
23172317

2318+
def test_failure_with_sequencelayer(self):
2319+
# This is a minimal reproducer of the failure seen with
2320+
# SequenceLayer project and Flax Linen when enabled support for 3.14
2321+
# See PR: https://github.com/google/flax/pull/5087
2322+
# Code below is based on
2323+
# https://github.com/google/flax/pull/5087#issuecomment-3535067361
2324+
import abc
2325+
from collections.abc import Iterator
2326+
from typing import Protocol
2327+
2328+
class CheckpointableIterator(Iterator, Protocol):
2329+
pass
2330+
2331+
class Steppable(metaclass=abc.ABCMeta):
2332+
pass
2333+
2334+
isinstance(Steppable, Iterator)
2335+
2336+
class SequenceLayer(nn.Module, Steppable):
2337+
pass
2338+
23182339
def test_module_path_empty(self):
23192340
rngkey = jax.random.key(0)
23202341
scope = Scope({}, {'params': rngkey}, mutable=['params'])

0 commit comments

Comments
 (0)