Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlx/backend/metal/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ void ScaledDotProductAttention::eval_gpu(

// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
copies.reserve(inputs.size());
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
Expand Down
11 changes: 11 additions & 0 deletions python/tests/test_fast_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,17 @@ def test_sdpa_broadcast_mask(self):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

def test_sdpa_noncontiguous_inputs(self):
mask = mx.ones(shape=(4, 1, 7, 7), dtype=mx.bool_)
mx.random.seed(0)
q = mx.random.normal(shape=(4, 7, 32, 64)).swapaxes(1, 2)

k = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
v = mx.random.normal(shape=(4, 7, 8, 64)).swapaxes(1, 2)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
ref = mlx_ref_attn(q, k, v, scale=1.0, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

def test_sdpa_promote_mask(self):
mask = mx.array(2.0, mx.bfloat16)
D = 64
Expand Down