-
-
Notifications
You must be signed in to change notification settings - Fork 177
perf: enable module flatten/unflatten fastpath #1117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
So Equinox looks to avoid special-casing any module methods. In particular it may be the case that someone already has a method called (The style you have here is actually what we used to have back in the distant early days of Equinox, and moved away from it.) I'm not super what a better alternative is. Perhaps JAX might allow re-registering a PyTree with different flatten/unflatten rules. |
|
What about special-casing equinox-prefixed versions of these methods? |
|
@patrick-kidger if I can't crack #1119 (or even if I can, IMO it'd be nice to be able to customize) would |
|
So I'm really leaning against adding something like that to Supposing #1119 comes good, what would be your use-case? |
|
E.g. not have the |
|
And with #1119, it would be cool to show the user what the (un)flattening code is doing by attaching the generated functions to the classes! class ModuleMeta:
def __new__(...):
cls._eqx_tree_flatten, cls._eqx_tree_unflatten = generate_functions(cls)
jax.tree_util.register_stuff(cls._eqx_tree_flatten, cls._eqx_tree_unflatten)
class MyClass(eqx.Module):
attr1: float
MyClass._eqx_tree_flatten?
>>> def flatten(self):
... return (self.attr1,), () |
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
Signed-off-by: nstarman <[email protected]>
|
I've rebased this PR on #1119. New approximate timings on that performance notebook:
This PR allows for an ~88% improvement ((12.5 - 8.3) / (12.5 - 7.7)) if a power user wants to write these methods. The default (from #1119) is a ~50% improvement. |
This PR enables power users to write their own flattening/unflattening procedures.
The advantage is that
Modulecan be sped up to be as fast as custom pytree structures and jax arrays.I'm attaching a Jupyter Notebook showing that in this example we achieve a ~50% speedup, which is ~100% of the overhead, making raw arrays, simple pytrees, and
Moduleall equivalently fast.I'm happy to add documentation / tests.
overhead.ipynb.zip