Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,37 +280,53 @@ array copy(array a, StreamOrDevice s /* = {} */) {
{std::move(a)});
}

array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
return array(
vals.shape(),
dtype,
std::make_shared<Full>(to_stream(s)),
{astype(vals, dtype, s)});
}

array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
throw std::invalid_argument("[full] Negative dimensions not allowed.");
}
auto copied_shape = shape; // |shape| will be moved
return array(
std::move(copied_shape),
dtype,
std::make_shared<Full>(to_stream(s)),
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
return full_impl(broadcast_to(vals, std::move(shape), s), dtype, s);
}

array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
auto dtype = vals.dtype(); // |vals| will be moved
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
}

array full_like(
const array& a,
array vals,
Dtype dtype,
StreamOrDevice s /* = {} */) {
auto inputs = broadcast_arrays({a, std::move(vals)}, s);
return full_impl(std::move(inputs[1]), dtype, s);
}

array full_like(const array& a, array vals, StreamOrDevice s /* = {} */) {
return full_like(a, std::move(vals), a.dtype(), to_stream(s));
}

array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
return full(shape, array(0, dtype), to_stream(s));
}

array zeros_like(const array& a, StreamOrDevice s /* = {} */) {
return zeros(a.shape(), a.dtype(), to_stream(s));
return full_like(a, 0, a.dtype(), to_stream(s));
}

array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
return full(shape, array(1, dtype), to_stream(s));
}

array ones_like(const array& a, StreamOrDevice s /* = {} */) {
return ones(a.shape(), a.dtype(), to_stream(s));
return full_like(a, 1, a.dtype(), to_stream(s));
}

array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
Expand Down
11 changes: 11 additions & 0 deletions mlx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ array full(Shape shape, T val, StreamOrDevice s = {}) {
return full(std::move(shape), array(val), to_stream(s));
}

array full_like(const array& a, array vals, Dtype dtype, StreamOrDevice s = {});
array full_like(const array& a, array vals, StreamOrDevice s = {});
template <typename T>
array full_like(const array& a, T val, Dtype dtype, StreamOrDevice s = {}) {
return full_like(a, array(val, dtype), dtype, to_stream(s));
}
template <typename T>
array full_like(const array& a, T val, StreamOrDevice s = {}) {
return full_like(a, array(val, a.dtype()), to_stream(s));
}

/** Fill an array of the given shape with zeros. */
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
Expand Down
1 change: 1 addition & 0 deletions mlx/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,7 @@ class Full : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_NAME(Full)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
};

class Gather : public UnaryPrimitive {
Expand Down
23 changes: 22 additions & 1 deletion python/tests/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,28 @@ def fun(x):

self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))

def test_shapeless_compile_full_like(self):
x_shape = (1, 1, 32)
x = mx.zeros((x_shape))

def zeros_fun(x):
return mx.zeros_like(x)

def ones_fun(x):
return mx.ones_like(x)

compiled_zero_like = mx.compile(zeros_fun, shapeless=True)
compiled_ones_like = mx.compile(ones_fun, shapeless=True)

self.assertEqual(compiled_zero_like(x).shape, x_shape)
self.assertEqual(compiled_ones_like(x).shape, x_shape)

y_shape = (2, 2, 16)
y = mx.zeros(y_shape)

self.assertEqual(compiled_zero_like(y).shape, y_shape)
self.assertEqual(compiled_ones_like(y).shape, y_shape)

def test_compile_with_constant(self):
# Test float
@partial(mx.compile)
Expand Down Expand Up @@ -842,7 +864,6 @@ def fun(inputs):
self.assertTrue(mx.allclose(out, expected))

def test_compile_many_outputs(self):

@mx.compile
def fun(arr):
arrs = [arr] * 64
Expand Down
26 changes: 26 additions & 0 deletions tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2826,6 +2826,32 @@ TEST_CASE("test stack") {
stack({x, y}, 0), "All arrays must have the same shape and dtype");
}

TEST_CASE("test full_like") {
auto base_int = array({1, 2, 3}, {3}, int16);

auto from_array_with_dtype = full_like(base_int, array(7.5f), float16);
auto expected_float16 = array({7.5, 7.5, 7.5}, {3}, float16);
CHECK_EQ(from_array_with_dtype.dtype(), float16);
CHECK(array_equal(from_array_with_dtype, expected_float16).item<bool>());

auto from_array_default_dtype = full_like(base_int, array(4.0f));
auto expected_int16 = array({4, 4, 4}, {3}, int16);
CHECK_EQ(from_array_default_dtype.dtype(), int16);
CHECK(array_equal(from_array_default_dtype, expected_int16).item<bool>());

auto from_scalar_with_dtype = full_like(base_int, 3.25f, float32);
auto expected_float32 = array({3.25f, 3.25f, 3.25f}, {3}, float32);
CHECK_EQ(from_scalar_with_dtype.dtype(), float32);
CHECK(array_equal(from_scalar_with_dtype, expected_float32).item<bool>());

auto base_float = array({1.0f, 2.0f}, {2}, float32);
auto from_scalar_default_dtype = full_like(base_float, 2);
auto expected_base_float = array({2.0f, 2.0f}, {2}, float32);
CHECK_EQ(from_scalar_default_dtype.dtype(), float32);
CHECK(
array_equal(from_scalar_default_dtype, expected_base_float).item<bool>());
}

TEST_CASE("test eye") {
auto eye_3 = eye(3);
CHECK_EQ(eye_3.shape(), Shape{3, 3});
Expand Down