diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 9647d0884c..cc27bff2d1 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -394,7 +394,7 @@ void ScaledDotProductAttention::eval_gpu(
 
   // Define some copy functions to ensure the layout of the inputs is as
   // expected.
-  copies.reserve(3);
+  copies.reserve(inputs.size());
   auto copy_unless = [&copies, &s](
                          auto predicate, const array& arr) -> const array& {
     if (!predicate(arr)) {
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index 8f9cb34cfc..abc9ada9d3 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -619,6 +619,17 @@ def test_sdpa_broadcast_mask(self):
         out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
+    def test_sdpa_noncontiguous_inputs(self):
+        mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
+        mx.random.seed(0)
+        q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)
+
+        k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
+        v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
     def test_sdpa_promote_mask(self):
         mask = mx.array(2.0, mx.bfloat16)
         D = 64