Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions mlx/backend/cuda/gemms/cublas_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,19 +248,33 @@ void CublasGemm::run(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
encoder,
out,
a,
b,
batch_shape,
a_batch_strides,
b_batch_strides,
alpha);
return;
}

encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);

execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
nullptr,
alpha);
}

void CublasGemm::run(
Expand Down
6 changes: 4 additions & 2 deletions mlx/backend/cuda/gemms/cublas_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class CublasGemm {
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
const Strides& b_batch_strides,
float alpha = 1.0f);

void run(
cu::CommandEncoder& encoder,
Expand All @@ -87,7 +88,8 @@ class CublasGemm {
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
const Strides& b_batch_strides,
float alpha);

void run_batched(
cu::CommandEncoder& encoder,
Expand Down
6 changes: 4 additions & 2 deletions mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ void CublasGemm::run_batched(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
Expand All @@ -27,7 +28,8 @@ void CublasGemm::run_batched(
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr);
nullptr,
alpha);
a_it.step();
b_it.step();
}
Expand Down
6 changes: 4 additions & 2 deletions mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ void CublasGemm::run_batched(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
Expand Down Expand Up @@ -226,7 +227,8 @@ void CublasGemm::run_batched(
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
nullptr,
alpha);
}

void CublasGemm::run_batched(
Expand Down
9 changes: 6 additions & 3 deletions mlx/backend/cuda/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ void gemm_and_bias(
array& out,
const array& a,
const array& b,
void* bias = nullptr) {
void* bias = nullptr,
float alpha = 1.0f) {
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);

Expand Down Expand Up @@ -94,7 +95,8 @@ void gemm_and_bias(
if (bias) {
gemm.set_bias(bias);
}
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
gemm.run(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
}

} // namespace
Expand Down Expand Up @@ -169,7 +171,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
out,
a,
b,
c.data<void>());
c.data<void>(),
alpha_);
return;
}

Expand Down
226 changes: 112 additions & 114 deletions python/tests/test_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,124 +594,123 @@ def test_addmm(self):
np.random.seed(0)
# Batched matmul
alpha = 0.5
beta = 2.0
for beta in (1.0, 2.0):
# c must broadcast to the output shape
with self.assertRaises(ValueError):
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))

# c must broadcast to the output shape
with self.assertRaises(ValueError):
mx.addmm(mx.zeros((2, 2, 2)), mx.zeros((2, 2)), mx.zeros((2, 2)))
# Regular batched case
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)

# Regular batched case
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)

a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_mlx = mx.array(b_npy)

# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (32, 1, 128), (1, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

for c_shape in ((1,), (32, 1, 128), (1, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
b_np_t = np.transpose(b_npy, (0, 2, 1))
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))

b_np_t = np.transpose(b_npy, (0, 2, 1))
b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)

d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)

a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)

for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)

for c_shape in ((1,), (128,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (128,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)

for c_shape in ((1,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

# Split K specializtion
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
# Split K specializtion
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)

a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)

for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)

d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)

self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
expected = 1.5 * a + 0.5 * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
expected = beta * a + alpha * (b @ a)
self.assertTrue(mx.allclose(expected, out))

# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
expected = beta * c + alpha * (a @ b)
self.assertTrue(mx.allclose(expected, out))

def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
Expand All @@ -724,33 +723,32 @@ def make_addmm(alpha, beta):
shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))

alpha = 2.0
beta = 0.5

f_test = make_addmm(alpha, beta)
f_ref = make_ref_addmm(alpha, beta)

for B, M, N, K in shapes:
cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))
for beta in (1.0, 0.5):
f_test = make_addmm(alpha, beta)
f_ref = make_ref_addmm(alpha, beta)

out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)
for B, M, N, K in shapes:
cotan = mx.ones((B, M, N))
c = mx.random.normal((B, M, N))
a = mx.random.normal((B, M, K))
b = mx.random.normal((B, K, N))

out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
[cotan],
)

self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())

for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())

def test_empty_matmul(self):
a = mx.array([[], []]).T
Expand Down