Commit 893a660
revise axes_scan to flatten argument pytrees only once
A user has a custom pytree node with the unusual behavior that it introduces
new arrays when flattening. That is, it's as if we had:
```python
# a custom object with two leaf arrays
custom_tree_object = SomeObject(jax_arrray1, jax_array2)
# convert leaves to ShapedArrays
custom_tree_object2 = jax.tree.map(core.typeof, custom_tree_object)
# flatten, should only see ShapedArrays, right?
leaves, treedef = jax.tree.flatten(custom_tree_object2)
print(leaves)
# [ShapedArray(...), ShapedArray(...), np.array(...)]
```
This change makes the `flax.nn.scan` function robust to such behavior. Without it, we were passing non-AbstractValues into JAX where JAX required AbstractValues.
I don't think we want to support this in general, but this fix seemed like the most
expedient way to roll fowrard jax-ml/jax#29273
PiperOrigin-RevId: 7681751181 parent 8233415 commit 893a660
1 file changed
+9
-8
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
147 | 147 | | |
148 | 148 | | |
149 | 149 | | |
150 | | - | |
151 | | - | |
152 | | - | |
153 | | - | |
154 | | - | |
155 | | - | |
156 | | - | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
157 | 159 | | |
158 | | - | |
159 | 160 | | |
160 | 161 | | |
161 | 162 | | |
| |||
0 commit comments