Skip to content

Conversation

@nstarman
Copy link
Contributor

@nstarman nstarman commented Oct 12, 2025

This PR enables power users to write their own flattening/unflattening procedures.
The advantage is that Module can be sped up to be as fast as custom pytree structures and jax arrays.
I'm attaching a Jupyter Notebook showing that in this example we achieve a ~50% speedup, which is ~100% of the overhead, making raw arrays, simple pytrees, and Module all equivalently fast.

I'm happy to add documentation / tests.

overhead.ipynb.zip

@patrick-kidger
Copy link
Owner

So Equinox looks to avoid special-casing any module methods. In particular it may be the case that someone already has a method called tree_unflatten etc, for some purpose unrelated to the tree map'ing of the module itself.

(The style you have here is actually what we used to have back in the distant early days of Equinox, and moved away from it.)

I'm not super what a better alternative is. Perhaps JAX might allow re-registering a PyTree with different flatten/unflatten rules.

@nstarman
Copy link
Contributor Author

nstarman commented Oct 12, 2025

What about special-casing equinox-prefixed versions of these methods? eqx_tree_unflatten, etc. (they could even be private _eqx_tree_unflatten). That would be easy to support for fast-paths.

@nstarman
Copy link
Contributor Author

@patrick-kidger if I can't crack #1119 (or even if I can, IMO it'd be nice to be able to customize) would _eqx_tree_unflatten be fine to add to Module?

@patrick-kidger
Copy link
Owner

So I'm really leaning against adding something like that to eqx.Module. Part of the design thesis of eqx.Module, as compared to jax.tree_util, is that custom flatten/unflatten functions are error-prone and simply aren't necessary – it suffices to just set dataclass fields instead.

Supposing #1119 comes good, what would be your use-case?

@nstarman
Copy link
Contributor Author

nstarman commented Oct 21, 2025

E.g. not have the _MISSING (/flatten_sentinel) logic if I'm sure modules will be fully initialized. Avoid the wrapper stuff if needed. Write a mypyc transpiled mixin class with the (un)flattening logic that removes most (not all, because this would need @mypyc_attr(allow_interpreted_subclasses=True)) of the python overhead. For frequently-in-hot-loop classes like unxt.Quantity it would be nice to be able to achieve JAX speeds.

@nstarman
Copy link
Contributor Author

nstarman commented Oct 21, 2025

And with #1119, it would be cool to show the user what the (un)flattening code is doing by attaching the generated functions to the classes!

class ModuleMeta:
    def __new__(...):
        cls._eqx_tree_flatten, cls._eqx_tree_unflatten = generate_functions(cls)
        jax.tree_util.register_stuff(cls._eqx_tree_flatten, cls._eqx_tree_unflatten)

class MyClass(eqx.Module):
    attr1: float

MyClass._eqx_tree_flatten?
>>> def flatten(self):
...  return (self.attr1,), ()

@nstarman
Copy link
Contributor Author

nstarman commented Nov 5, 2025

I've rebased this PR on #1119.
And I've changed it so that the (un)flattening methods start with _eqx_

New approximate timings on that performance notebook:

This PR allows for an ~88% improvement ((12.5 - 8.3) / (12.5 - 7.7)) if a power user wants to write these methods. The default (from #1119) is a ~50% improvement.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants