jax.numpy dtypes deeply violate the principle of least surprise
#32871
Unfortunately, this was a design decision made for us long before JAX existed.
It gets a bit more complicated when you realize that, unlike NumPy, JAX doesn't have special types for scalars, but rather represents scalars as zero-dimensional arrays. The result of all these requirements is the surprising implementation details you bring up. We could remove those surprises, but it would break JAX's equivalence with NumPy APIs, and that would be painful and confusing enough for users that I can't see us ever going down that route.
Hello,
I didn't open this as an issue since it is not necessarily about a bug in JAX, but about a design decision that I believe should not be carried over into an eventual 1.0 release.
Specifically, it is about JAX dtypes pretending to be numpy dtypes by emulating their hash:
jax/jax/_src/numpy/scalar_types.py, lines 37 to 41 at commit 11befd4
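Below is a minimal stdlib-only sketch of the pattern. It assumes the referenced lines amount to a metaclass whose __hash__ and __eq__ defer to the wrapped NumPy scalar type. FakeNumpyFloat32, FakeJnpFloat32 and _ScalarMetaSketch are hypothetical stand-ins for np.float32, jnp.float32 and JAX's metaclass, not JAX's actual code:

```python
import typing


class FakeNumpyFloat32:
    """Hypothetical stand-in for np.float32: an ordinary class."""


class _ScalarMetaSketch(type):
    """Hypothetical metaclass whose classes impersonate their `wrapped` type."""

    def __hash__(cls) -> int:
        # Hash exactly like the wrapped type...
        return hash(cls.wrapped)

    def __eq__(cls, other: object) -> bool:
        # ...and compare equal to it (as well as to itself).
        return cls is other or cls.wrapped is other


# Hypothetical stand-in for jnp.float32: a separate class inheriting only from object.
FakeJnpFloat32 = _ScalarMetaSketch("float32", (object,), {"wrapped": FakeNumpyFloat32})

print(FakeJnpFloat32 == FakeNumpyFloat32)              # True
print(FakeJnpFloat32 is FakeNumpyFloat32)              # False
print(hash(FakeJnpFloat32) == hash(FakeNumpyFloat32))  # True

# Hash-based containers now treat the two classes as a single key.
print(len({FakeJnpFloat32, FakeNumpyFloat32}))         # 1

# typing.Union deduplicates by hash and equality, keeping the first argument.
print(typing.Union[FakeJnpFloat32, FakeNumpyFloat32] is FakeJnpFloat32)  # True
```

Whether FakeJnpFloat32 | FakeNumpyFloat32 collapses the same way may depend on the CPython version, because some versions deduplicate | operands by identity rather than by hash and equality. For that reason the sketch only checks typing.Union.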
Overriding the hash so that the two types act as if they were the same breaks many expectations and assumptions made by Python itself and by third-party type checkers, leading to completely unintuitive behavior. For example, the Python documentation for typing.Union states:
"Union type; Union[X, Y] is equivalent to X | Y and means either X or Y."
jax.numpy dtypes violate both parts of this statement ("equivalent to X | Y" and "either X or Y") as a result of their hash-hacking.
Since Python thinks they are the same type, the | operator deduplicates them to whichever type is mentioned first. This breaks the commutativity of the operator: jnp.float32 | np.float32 reduces to jnp.float32, while np.float32 | jnp.float32 reduces to np.float32.
It also breaks the equivalence of typing.Union and the | operator, since Python seems to perform some form of caching based on what the hashes are meant to express:
Here, the behavior even depends on the ordering of the statements; if the two lines are switched, both reduce to numpy.float32 instead.
This would perhaps not be so bad if jnp.float32 and np.float32 were actually the same, but they fundamentally are not. JAX dtypes aren't NumPy dtypes, not even the standard ones like float32: they don't have the same MRO, they are instances of different metaclasses, and they respond differently to different checks.
So the hash-hacking completely breaks runtime type checking. For example, with typeguard:
With beartype:
And even other basic checks:
I'm aware that jnp.issubdtype exists, but this behavior just doesn't make sense.
More context and examples of the downstream consequences of this choice are in this discussion thread.
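The order dependence described earlier for typing.Union is what any cache keyed on hash and equality produces with such types. Historically, typing has memoized subscriptions like Union[...] with functools.lru_cache. Whether that is the exact mechanism on a given Python version is an assumption here. The sketch below uses hypothetical stand-in classes and a hypothetical memoized helper, first_of:

```python
import functools


class FakeNumpyFloat32:
    """Hypothetical stand-in for np.float32."""


class _ScalarMetaSketch(type):
    """Hypothetical metaclass: hash and equality impersonate `wrapped`."""

    def __hash__(cls) -> int:
        return hash(cls.wrapped)

    def __eq__(cls, other: object) -> bool:
        return cls is other or cls.wrapped is other


# Hypothetical stand-in for jnp.float32.
FakeJnpFloat32 = _ScalarMetaSketch("float32", (object,), {"wrapped": FakeNumpyFloat32})


@functools.lru_cache(maxsize=None)
def first_of(types: tuple) -> type:
    """Hypothetical memoized helper that keeps the first-mentioned type."""
    return types[0]


# Both argument tuples hash the same and compare equal, so the second call
# is a cache hit and returns the result computed for the first ordering.
print(first_of((FakeJnpFloat32, FakeNumpyFloat32)).__name__)  # float32
print(first_of((FakeNumpyFloat32, FakeJnpFloat32)).__name__)  # float32

# With an empty cache, the reverse ordering gives the other answer.
first_of.cache_clear()
print(first_of((FakeNumpyFloat32, FakeJnpFloat32)).__name__)  # FakeNumpyFloat32
```

Swapping the first two calls flips both results, which is the same kind of ordering effect described above.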