Shapeless support for zeros/ones_like#2726
Conversation
mlx/ops.h
Outdated
| } | ||
| template <typename T> | ||
| array full_like(const array& a, T val, StreamOrDevice s = {}) { | ||
| return full_like(a, array(val), to_stream(s)); |
There was a problem hiding this comment.
I think the output should inherit the data type of the input here (e.g. full_like). So you would need to set array(val, a.dtype()).
There was a problem hiding this comment.
Thanks for the suggestion! My intention with full_like(a, val) was:
-
aprovides the shape -
valprovides both the fillvalueand thedtype
So full_like(a, val) felt like the most direct expression of “give me an array with the same shape as a, filled with val (value + type)”.
What do you think?
There was a problem hiding this comment.
Yea I think it could go either way. When it's somewhat ambiguous (should the type come from val or a) then I would default to the same behavior as other frameworks. Numpy (and Jax) usually are the first I check (and they tend to have the same behavior). In this case Numpy takes the type from a and not from val.
There was a problem hiding this comment.
Cool that makes sense! thanks
mlx/ops.cpp
Outdated
| auto dtype = vals.dtype(); | ||
| return full_like(a, std::move(vals), dtype, to_stream(s)); |
python/tests/test_compile.py
Outdated
| def test_shapeless_compile_full_like(self): | ||
| x = mx.zeros((1, 1, 32)) | ||
|
|
||
| def zeros_fun(x): | ||
| return mx.zeros_like(x) | ||
|
|
||
| def ones_fun(x): | ||
| return mx.ones_like(x) | ||
|
|
||
| self.assertEqual(mx.compile(zeros_fun, shapeless=True)(x).shape, (1, 1, 32)) | ||
| self.assertEqual(mx.compile(ones_fun, shapeless=True)(x).shape, (1, 1, 32)) |
There was a problem hiding this comment.
The point of shapeless compile is that it works even if you change the shape. So it would be good to add a second call where you change the shape and make sure you get the right shape back.
|
This looks nice. Please address the comments then we can merge it. Also could you add a test for |
|
@awni Can you take a look again? Thanks! |
awni
left a comment
There was a problem hiding this comment.
Looks great, thanks! Will merge when tests clear.
|
@CC-Yeh the tests are failing. Are you able to address that? |
|
@awni should be fixed now! |
Proposed changes
Closes #2599
Checklist
pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes