Skip to content

Shapeless support for zeros/ones_like#2726

Merged
awni merged 3 commits intoml-explore:mainfrom
CC-Yeh:full_like
Nov 7, 2025
Merged

Shapeless support for zeros/ones_like#2726
awni merged 3 commits intoml-explore:mainfrom
CC-Yeh:full_like

Conversation

@CC-Yeh
Copy link
Copy Markdown
Contributor

@CC-Yeh CC-Yeh commented Nov 2, 2025

Proposed changes

Closes #2599

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

mlx/ops.h Outdated
}
template <typename T>
array full_like(const array& a, T val, StreamOrDevice s = {}) {
return full_like(a, array(val), to_stream(s));
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the output should inherit the data type of the input here (e.g. full_like). So you would need to set array(val, a.dtype()).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! My intention with full_like(a, val) was:

  • a provides the shape

  • val provides both the fill value and the dtype

So full_like(a, val) felt like the most direct expression of “give me an array with the same shape as a, filled with val (value + type)”.

What do you think?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I think it could go either way. When it's somewhat ambiguous (should the type come from val or a) then I would default to the same behavior as other frameworks. Numpy (and Jax) usually are the first I check (and they tend to have the same behavior). In this case Numpy takes the type from a and not from val.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool that makes sense! thanks

mlx/ops.cpp Outdated
Comment on lines +314 to +315
auto dtype = vals.dtype();
return full_like(a, std::move(vals), dtype, to_stream(s));
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dtype here should be from a

Comment on lines +485 to +495
def test_shapeless_compile_full_like(self):
x = mx.zeros((1, 1, 32))

def zeros_fun(x):
return mx.zeros_like(x)

def ones_fun(x):
return mx.ones_like(x)

self.assertEqual(mx.compile(zeros_fun, shapeless=True)(x).shape, (1, 1, 32))
self.assertEqual(mx.compile(ones_fun, shapeless=True)(x).shape, (1, 1, 32))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point of shapeless compile is that it works even if you change the shape. So it would be good to add a second call where you change the shape and make sure you get the right shape back.

@awni
Copy link
Copy Markdown
Member

awni commented Nov 4, 2025

This looks nice. Please address the comments then we can merge it.

Also could you add a test for full_like especially around making sure the various overloads return the expected types? (as they look incorrect at the moment)

@CC-Yeh CC-Yeh requested a review from awni November 4, 2025 19:45
@CC-Yeh
Copy link
Copy Markdown
Contributor Author

CC-Yeh commented Nov 4, 2025

@awni Can you take a look again? Thanks!

Copy link
Copy Markdown
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks! Will merge when tests clear.

@awni
Copy link
Copy Markdown
Member

awni commented Nov 6, 2025

@CC-Yeh the tests are failing. Are you able to address that?

@CC-Yeh
Copy link
Copy Markdown
Contributor Author

CC-Yeh commented Nov 6, 2025

@CC-Yeh the tests are failing. Are you able to address that?

@awni Sorry a bit busy lately, will try to reproduce and fix it on colab.

@CC-Yeh
Copy link
Copy Markdown
Contributor Author

CC-Yeh commented Nov 6, 2025

@awni should be fixed now!

@awni awni merged commit be9e2ae into ml-explore:main Nov 7, 2025
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Make zeros_like work with shapeless compile / export.

2 participants