
Faster complex matmul#2571

Merged
angeloskath merged 10 commits into ml-explore:main from CC-Yeh:complex_matmul
Oct 3, 2025

Conversation

CC-Yeh (Contributor) commented Sep 6, 2025

Proposed changes

  • Introduce cblas_cgemm

  • Metal: Make gemv compatible with complex64_t

  • Metal: Add complex64 BlockMMA specialization to simplify gemm integration.

  • CUDA: Make gemv and gemm compatible with complex64_t.

Only tuned on a small chip; people with larger chips will need to tune the tile sizes.

Closes #2076
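For reference, cblas_cgemm computes C <- alpha * A @ B + beta * C on single-precision complex operands. A minimal NumPy sketch of those semantics that the new CPU path must reproduce (shapes and values here are illustrative, not MLX code):

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 4, 5, 3

# complex64 operands, matching MLX's complex64 / BLAS's single complex type
A = (rng.standard_normal((M, K)) + 1j * rng.standard_normal((M, K))).astype(np.complex64)
B = (rng.standard_normal((K, N)) + 1j * rng.standard_normal((K, N))).astype(np.complex64)
C = (rng.standard_normal((M, N)) + 1j * rng.standard_normal((M, N))).astype(np.complex64)

alpha, beta = np.complex64(1.0), np.complex64(0.5)

# cgemm semantics: C <- alpha * A @ B + beta * C
ref = alpha * (A @ B) + beta * C
```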

Benchmarks

Metal

Average 6x faster for new gemv (bench_gemv.py)

[benchmark plot]

Average 1.5x faster for new gemm (bench_gemm.py)

[benchmark plot]

CUDA

Average 6x faster for new gemv
[benchmark plots]

gemm: ~1.7x slower; needs help from someone with CUDA expertise.

[benchmark plot]

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

awni (Member) commented Sep 8, 2025

This looks like a very nice improvement. There are two issues we need to work through:

  1. The JIT tests are broken on macOS. Are you able to resolve that?
  2. We need a strategy for CUDA. cuBlas has support for complex 64-bit type, so maybe we can just route to that?

Are you up for making those two fixes? If you are unable to make the CUDA change we can help with that, let us know.

CC-Yeh (Contributor, Author) commented Sep 8, 2025

This looks like a very nice improvement. There are two issues we need to work through:

  1. The JIT tests are broken on macOS. Are you able to resolve that?

Hope the latest commit will fix this.

  1. We need a strategy for CUDA. cuBlas has support for complex 64-bit type, so maybe we can just route to that?

I’d love to add it. I don’t have access to an NVIDIA GPU, so I can’t test CUDA locally.

awni (Member) commented Sep 8, 2025

Hi Dan - could you DM me and we can discuss work-arounds? You can reach me on X or at awni at apple.com.

raishish commented:

I was planning to work on this issue as well but @Dan-Yeh beat me to it lol.

[benchmark plots: nn, nt, and tn transpose configs]

Anyway, I benchmarked the CPU backend on my M2 Air (cblas_cgemm vs the current op-based Karatsuba algorithm).

Observations:

  • Karatsuba is a bit quicker (20-40%) in some cases for the nn transpose config. I suppose I can scale this further to see if the trend continues.
  • However, for the nt and tn configs, the cblas implementation is much faster overall. My guess is that because it can intrinsically account for the transpose operation, it makes for a more efficient implementation.

I don't think it makes sense to retain the Karatsuba approach for CPU. In hindsight, it makes sense that cblas would run faster than the naive Karatsuba decomposition, but I'm glad I could actually benchmark it to satisfy my curiosity.

CC: @awni
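For context, a complex matmul naively needs 4 real matmuls; the Karatsuba/Gauss trick mentioned above gets that down to 3. A plausible NumPy reconstruction of the op-based route (not MLX's actual code):

```python
import numpy as np

def complex_mm_karatsuba(A, B):
    """Complex matmul using 3 real matmuls instead of 4 (Karatsuba/Gauss trick)."""
    Ar, Ai = A.real, A.imag
    Br, Bi = B.real, B.imag
    t1 = Ar @ Br
    t2 = Ai @ Bi
    t3 = (Ar + Ai) @ (Br + Bi)
    # real part: Ar@Br - Ai@Bi; imag part: Ar@Bi + Ai@Br == t3 - t1 - t2
    return (t1 - t2) + 1j * (t3 - t1 - t2)

rng = np.random.default_rng(1)
A = rng.standard_normal((8, 8)) + 1j * rng.standard_normal((8, 8))
B = rng.standard_normal((8, 8)) + 1j * rng.standard_normal((8, 8))
out = complex_mm_karatsuba(A, B)
```

The savings come at the cost of extra adds and temporaries, which is one reason a fused cgemm can win despite doing 4 real multiplies' worth of work internally.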

CC-Yeh (Contributor, Author) commented Sep 13, 2025

@awni

  1. We need a strategy for CUDA. cuBlas has support for complex 64-bit type, so maybe we can just route to that?

I think CUDA_C_32F is already supported.
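For what it's worth, CUDA_C_32F is cuBLAS's interleaved single-precision complex type: each element is a (real, imag) pair of float32, the same memory layout as complex64, so buffers can be handed to cuBLAS without conversion. A small NumPy illustration of the layout:

```python
import numpy as np

# complex64 stores each element as an interleaved (real, imag) float32 pair,
# matching cuBLAS's CUDA_C_32F layout.
z = np.array([1 + 2j, 3 - 4j], dtype=np.complex64)

assert z.itemsize == 8          # two float32 per element
pairs = z.view(np.float32)      # interleaved reinterpretation, no copy
# pairs == [1., 2., 3., -4.]
```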

CC-Yeh requested a review from awni, September 13, 2025 16:43
awni (Member) commented Sep 13, 2025

Oh ok.. in that case we may just need to update the routing conditions so we don't route to the custom gemv when the input has complex type.

CC-Yeh force-pushed the complex_matmul branch 15 times, most recently from f632bbc to a210b14, September 17, 2025 21:51
CC-Yeh (Contributor, Author) commented Sep 19, 2025

Interestingly, the ops-based solution outperforms cublasLt in 18/21 cases; the exceptions are skinny or small matrices.

[benchmark plot]

CC-Yeh force-pushed the complex_matmul branch 3 times, most recently from 26c9421 to 2a2f6b1, September 19, 2025 22:05
CC-Yeh (Contributor, Author) commented Sep 19, 2025

Hey @awni, can you take a look again?

I believe the CPU, Metal, and CUDA parts are in good shape. One note: CUDA::gemm runs slower than the prior ops-based route and could use optimization by a CUDA specialist.

awni (Member) commented Sep 22, 2025

One note: CUDA::gemm runs slower than the prior ops-based route and could use optimization by a CUDA specialist.

It's most likely because it's no longer hitting tensor cores 🤔 . I am not really sure what's expected there when using complex. Tensor cores do matmuls in lower precision (tf32).
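A rough way to see the TF32 effect: the sketch below simulates TF32's 10 explicit mantissa bits by truncating float32 inputs before a matmul (actual tensor-core rounding and accumulation differ; this only illustrates the precision gap):

```python
import numpy as np

def truncate_to_tf32(x):
    """Simulate TF32 storage: float32 exponent, but only 10 mantissa bits.

    Drops the low 13 of float32's 23 mantissa bits by masking the raw bits.
    """
    x = np.asarray(x, dtype=np.float32)
    bits = x.view(np.int32) & np.int32(~0x1FFF)
    return bits.view(np.float32)

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 64)).astype(np.float32)
B = rng.standard_normal((64, 64)).astype(np.float32)

exact = A.astype(np.float64) @ B.astype(np.float64)
fp32 = A @ B
tf32 = truncate_to_tf32(A) @ truncate_to_tf32(B)

err_fp32 = np.max(np.abs(fp32 - exact)) / np.max(np.abs(exact))
err_tf32 = np.max(np.abs(tf32 - exact)) / np.max(np.abs(exact))
```

The TF32 error is orders of magnitude larger than plain float32 but still small in relative terms, which is the usual justification for routing gemms through tensor cores.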

Comment on lines +111 to +112:

    array bias_arr = astype(*bias, out.dtype(), s);
    out = add(out, bias_arr, s);
A reviewer (Member) commented:

This is problematic here. We can't do MLX ops inside a primitive's evaluation.

Instead of doing the fallback this way, just fall back to the full addmm. So the condition could be

if (bias && a.dtype() != complex64)
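A sketch of that routing suggestion in Python, with hypothetical route names for illustration (the real dispatch lives in the C++ primitive):

```python
import numpy as np

def addmm_route(a_dtype, has_bias):
    """Hypothetical routing sketch: only take the fused-bias gemm path when
    the dtype is not complex64; otherwise fall back to the full addmm
    (alpha * A @ B + beta * C), avoiding MLX ops inside the primitive."""
    if has_bias and a_dtype != np.complex64:
        return "gemm_with_fused_bias"
    if has_bias:
        return "full_addmm_fallback"
    return "plain_gemm"
```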

awni (Member) commented Sep 22, 2025

Hey @Dan-Yeh this is looking good. Could you please rebase, address the latest remarks and then I will run the tests?

CC-Yeh force-pushed the complex_matmul branch 2 times, most recently from d558ac7 to 85d4b0e, September 22, 2025 19:12
CC-Yeh (Contributor, Author) commented Sep 22, 2025

@awni Done!

CC-Yeh requested a review from awni, September 23, 2025 15:36
CC-Yeh (Contributor, Author) commented Oct 1, 2025

Hi @awni
Just removed some redundancies and rebased again.
Can you take a look? Thanks.

angeloskath (Member) commented:
Thank you @Dan-Yeh . I will do some more checks and merge tonight. Thanks for the awesome PR and for your patience.

angeloskath (Member) left a review comment:

It's most likely because it's no longer hitting tensor cores 🤔 .

I enabled TF32 for the complex gemm and now it is 2x faster than before 🚀.

Thanks @Dan-Yeh for the awesome PR, I will merge after the tests clear.

angeloskath merged commit 22a5da7 into ml-explore:main, Oct 3, 2025
7 checks passed
faisalmemon pushed a commit to faisalmemon/mlx that referenced this pull request Oct 30, 2025

Successfully merging this pull request may close these issues.

Use cblas_cgemm for CPU complex matmul

4 participants