Conversation
(force-pushed from 302eefa to 6fb7d08)
|
This looks like a very nice improvement. There are two issues we need to work through:
Are you up for making those two fixes? If you are unable to make the CUDA change we can help with that, let us know. |
I hope the latest commit fixes this.
I’d love to add it. I don’t have access to an NVIDIA GPU, so I can’t test CUDA locally. |
|
Hi Dan - could you DM me and we can discuss work-arounds? You can reach me on X or at awni at apple.com. |
|
I was planning to work on this issue as well but @Dan-Yeh beat me to it lol.
Anyway, I benchmarked the CPU backend on my M2 Air. Observations:
I don't think it makes sense to retain the Karatsuba approach for the CPU. In hindsight it makes sense that cblas would run faster than the naive Karatsuba decomposition, but I'm glad I could actually benchmark it to satisfy my curiosity. CC: @awni |
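For context, the Karatsuba (Gauss) trick discussed above computes a complex matmul with three real matmuls instead of the naive four. A minimal NumPy sketch of that decomposition (illustrative only, not MLX's code):

```python
import numpy as np

def complex_matmul_karatsuba(a, b):
    """Complex matmul via 3 real matmuls (Gauss/Karatsuba trick)
    instead of the naive 4. Illustrative sketch only."""
    ar, ai = a.real, a.imag
    br, bi = b.real, b.imag
    t1 = ar @ br                  # real * real
    t2 = ai @ bi                  # imag * imag
    t3 = (ar + ai) @ (br + bi)    # combined product
    return (t1 - t2) + 1j * (t3 - t1 - t2)

rng = np.random.default_rng(0)
a = (rng.standard_normal((8, 8))
     + 1j * rng.standard_normal((8, 8))).astype(np.complex64)
b = (rng.standard_normal((8, 8))
     + 1j * rng.standard_normal((8, 8))).astype(np.complex64)
assert np.allclose(complex_matmul_karatsuba(a, b), a @ b, atol=1e-4)
```

The trick trades one multiplication for extra additions, which is exactly why a well-tuned cblas call with four plain products can still win on real hardware.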
I think |
|
Oh, ok. In that case we may just need to update the routing conditions so we don't route to the custom gemv when the input has a complex type. |
(force-pushed from f632bbc to a210b14)
(force-pushed from 26c9421 to 2a2f6b1)
|
Hey @awni, can you take a look again? I believe the CPU, Metal, and CUDA parts are in good shape. One note: CUDA::gemm runs slower than the prior ops-based route and could use optimization by a CUDA specialist. |
It's most likely because it's no longer hitting tensor cores 🤔 . I am not really sure what's expected there when using complex. Tensor cores do matmuls in lower precision (tf32). |
mlx/backend/cuda/matmul.cpp (outdated):
    array bias_arr = astype(*bias, out.dtype(), s);
    out = add(out, bias_arr, s);
This is problematic here. We can't do MLX ops inside a primitive's evaluation.
Instead of doing the fallback this way, just fall back to the full addmm. The condition could be:
if (bias && a.dtype() != complex64)
|
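The fallback suggested above routes the whole bias case to addmm instead of issuing MLX ops inside the primitive's eval. A NumPy sketch of addmm's semantics, alpha * (a @ b) + beta * c (the semantics only, not MLX's implementation):

```python
import numpy as np

def addmm(c, a, b, alpha=1.0, beta=1.0):
    """Fused matmul-plus-bias: alpha * (a @ b) + beta * c.
    Sketch of the semantics only, not MLX's implementation."""
    return alpha * (a @ b) + beta * c

rng = np.random.default_rng(1)
a = rng.standard_normal((4, 3)).astype(np.float32)
b = rng.standard_normal((3, 5)).astype(np.float32)
bias = rng.standard_normal((5,)).astype(np.float32)
# Routing to addmm reproduces gemm followed by a broadcast bias add.
assert np.allclose(addmm(bias, a, b), a @ b + bias)
```

Falling back this way keeps the dtype cast and the add inside one primitive's evaluation path rather than chaining graph-level ops.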
Hey @Dan-Yeh this is looking good. Could you please rebase, address the latest remarks and then I will run the tests? |
(force-pushed from d558ac7 to 85d4b0e)
|
@awni Done! |
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
|
Hi @awni |
|
Thank you @Dan-Yeh . I will do some more checks and merge tonight. Thanks for the awesome PR and for your patience. |
angeloskath left a comment:
> It's most likely because it's no longer hitting tensor cores 🤔 .
I enabled TF32 for the complex gemm and now it is 2x faster than before 🚀.
Thanks @Dan-Yeh for the awesome PR, I will merge after the tests clear.
Proposed changes
- Introduce cblas_cgemm
- Metal: Make gemv compatible with complex64_t
- Metal: Add complex64 BlockMMA specialization to simplify gemm integration.
- CUDA: Make gemv and gemm compatible with complex64_t. Only tuned on a small chip; will need people with larger chips to tune the tile size.
Closes #2076
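As a reference for the cblas_cgemm item above: the BLAS cgemm routine computes C <- alpha * A @ B + beta * C on complex64 operands. A NumPy sketch of that contract (no-transpose case only; this mimics cgemm's math, it does not call cblas):

```python
import numpy as np

def cgemm_reference(alpha, a, b, beta, c):
    """Reference for the BLAS cgemm contract on complex64:
    C <- alpha * A @ B + beta * C (no-transpose case, for brevity)."""
    return (alpha * (a @ b) + beta * c).astype(np.complex64)

def cmat(rng, m, n):
    """Random complex64 matrix (helper for this example only)."""
    return (rng.standard_normal((m, n))
            + 1j * rng.standard_normal((m, n))).astype(np.complex64)

rng = np.random.default_rng(2)
a, b, c = cmat(rng, 4, 3), cmat(rng, 3, 2), cmat(rng, 4, 2)
out = cgemm_reference(np.complex64(2 + 1j), a, b, np.complex64(0.5), c)
assert out.dtype == np.complex64
assert np.allclose(out, (2 + 1j) * (a @ b) + 0.5 * c, atol=1e-4)
```

Delegating to cgemm lets the CPU backend reuse a vendor-tuned kernel instead of decomposing the complex product by hand.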
Benchmarks
Metal
- Average 6x faster for new gemv (bench_gemv.py)
- Average 1.5x faster for new gemm (bench_gemm.py)

CUDA
- Average 6x faster for new gemv
- gemm: ~1.7 times slower, need help from someone with CUDA expertise
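The benchmark scripts (bench_gemv.py / bench_gemm.py) are not included here. A minimal timing harness in the same spirit, with the matrix size and iteration count as assumptions rather than the values actually used:

```python
import timeit
import numpy as np

# Minimal gemv timing harness in the spirit of the PR's bench_gemv.py.
# The 512x512 size and 100 iterations are assumptions for illustration,
# not the settings used in the actual benchmarks.
rng = np.random.default_rng(3)
m = (rng.standard_normal((512, 512))
     + 1j * rng.standard_normal((512, 512))).astype(np.complex64)
v = (rng.standard_normal(512)
     + 1j * rng.standard_normal(512)).astype(np.complex64)

n_iters = 100
elapsed = timeit.timeit(lambda: m @ v, number=n_iters)
per_call_us = elapsed / n_iters * 1e6
print(f"complex64 gemv 512x512: {per_call_us:.1f} us/iter")
```

Averaging over many iterations, as above, is the usual way to smooth out launch overhead when comparing the old ops-based route against the new kernels.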
Checklist
Put an x in the boxes that apply.
- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes