
leloykun/flash-attention-minimal

 
 


Flash Attention Minimal

A minimal re-implementation of Flash Attention with CUDA and PyTorch. The official implementation can be quite daunting for a CUDA beginner (like myself), so this repo tries to be small and educational.

Usage

Prerequisites

  • PyTorch (with CUDA)
  • Ninja, for compiling and loading the C++/CUDA extension

Benchmark

Compare the wall-clock time between manual attention and minimal flash attention:

python bench.py

Sample output on an RTX 3060 for the forward pass (Br = Bc = 32):

=== profiling manual attention (forward pass) ===
...
Self CPU time total: 375.381ms
Self CUDA time total: 377.542ms

=== profiling minimal flash attention 1 (forward pass) ===
...
Self CPU time total: 527.162ms
Self CUDA time total: 108.211ms

=== profiling minimal flash attention 2 (forward pass) ===
...
Self CPU time total: 343.248ms
Self CUDA time total: 4.048ms

Measured by Self CUDA time, that's a 3.5x & 93x speedup for Flash Attention 1 & 2, respectively!

Sample output on an RTX 3060 for the backward pass (Br = Bc = 16):

=== profiling manual attention (backward pass) ===
...
Self CPU time total: 65.457ms
Self CUDA time total: 67.838ms

=== profiling minimal flash attention 1 (backward pass) === 
...
Self CPU time total: 1.013s
Self CUDA time total: 4.615ms

=== profiling minimal flash attention 2 (backward pass) === 
...
Self CPU time total: 1.023s
Self CUDA time total: 814.000us

Measured by Self CUDA time, that's a 15x & 83x speedup for Flash Attention 1 & 2, respectively!
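The speedups above are ratios of the Self CUDA time totals (manual time divided by flash time). The short check below recomputes them from the numbers copied out of the two sample outputs; the dictionary layout and the `speedups` helper are only for illustration and are not part of `bench.py`.

```python
# Self CUDA time totals (ms), copied from the sample outputs above.
cuda_ms = {
    "forward": {"manual": 377.542, "flash1": 108.211, "flash2": 4.048},
    "backward": {"manual": 67.838, "flash1": 4.615, "flash2": 0.814},
}


def speedups(times):
    """Speedup of each flash variant over manual attention: manual_time / variant_time."""
    return {name: times["manual"] / t for name, t in times.items() if name != "manual"}


for pass_name, times in cuda_ms.items():
    for variant, s in speedups(times).items():
        print(f"{pass_name:8s} {variant}: {s:5.1f}x")
# forward:  flash1 ~ 3.5x,  flash2 ~ 93.3x
# backward: flash1 ~ 14.7x, flash2 ~ 83.3x
```

The Self CPU time totals do not shrink the same way (for example 1.013s for the Flash Attention 1 backward pass versus 65.457ms for manual attention), so these ratios describe time spent in CUDA kernels rather than total CPU-side time.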

I don't have a GPU

Try out this online Colab demo.

Caveats

  • In the inner loop, I assign each thread to a row of the output matrix. This differs from the original implementation.
  • This thread-per-row simplification makes the matrix multiplications very slow, which is probably why this implementation becomes slower than the manual one for longer sequences and larger block sizes.
  • Q, K, and V are in float32, unlike the original implementation, which uses float16.
  • The block size is fixed at compile time to 32.
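To make the thread-per-row scheme concrete, here is a pure-Python sketch of a Flash Attention 1 style forward pass. It assumes the loop order of Algorithm 1 in the FlashAttention paper (outer loop over K/V tiles, inner loop over Q tiles) and an online softmax that keeps a running row maximum `m` and denominator `l`. Each iteration of the innermost `for r` loop stands in for one CUDA thread that owns one output row. The function names and small default block sizes are hypothetical, and this is an illustration of the technique, not a transcription of the repo's kernel.

```python
import math


def naive_attention(Q, K, V):
    """Reference softmax(Q K^T / sqrt(d)) V, computed one row at a time."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        mx = max(scores)
        p = [math.exp(s - mx) for s in scores]
        total = sum(p)
        out.append([sum(p[c] * V[c][x] for c in range(len(V))) / total
                    for x in range(d)])
    return out


def flash_attention_1_forward(Q, K, V, Br=2, Bc=2):
    """Tiled forward pass with an online softmax (hypothetical pure-Python sketch)."""
    N, d = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d)
    O = [[0.0] * d for _ in range(N)]  # normalized output rows
    l = [0.0] * N                      # running softmax denominator per row
    m = [-math.inf] * N                # running row maximum per row

    for j0 in range(0, N, Bc):                    # outer loop: one K/V tile at a time
        Kj, Vj = K[j0:j0 + Bc], V[j0:j0 + Bc]
        for i0 in range(0, N, Br):                # inner loop: one Q tile at a time
            for r in range(i0, min(i0 + Br, N)):  # each r stands in for one CUDA thread
                # This row's scores against the current K tile, computed serially.
                S = [scale * sum(a * b for a, b in zip(Q[r], k)) for k in Kj]
                m_tile = max(S)
                P = [math.exp(s - m_tile) for s in S]
                l_tile = sum(P)

                # Merge this tile's statistics into the running ones.
                m_new = max(m[r], m_tile)
                alpha = math.exp(m[r] - m_new)   # rescales the old accumulator (0 on the first tile)
                beta = math.exp(m_tile - m_new)  # rescales this tile's contribution
                l_new = alpha * l[r] + beta * l_tile

                for x in range(d):
                    pv = sum(P[c] * Vj[c][x] for c in range(len(Vj)))
                    O[r][x] = (alpha * l[r] * O[r][x] + beta * pv) / l_new
                m[r], l[r] = m_new, l_new
    return O
```

For small inputs, `flash_attention_1_forward(Q, K, V, Br=2, Bc=2)` should match `naive_attention(Q, K, V)` up to floating-point rounding for any block sizes. In this sketch every "thread" computes its own row of `S` and of `P·V` with plain serial loops, about Bc·d multiply-adds each per tile and no cooperation between threads, which is plausibly the cost the second caveat refers to.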

Todos

  • Speed up matmuls
  • Dynamically set block size
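For the "dynamically set block size" todo, one possible approach is to choose, at runtime, the largest square tile that fits in shared memory. The sketch below assumes the kernel keeps the Qi (Br × d), Kj and Vj (Bc × d each), and S (Br × Bc) tiles in shared memory, and it uses 48 KiB, the default per-block shared-memory limit on NVIDIA GPUs (larger allocations need an explicit opt-in). `smem_bytes` and `pick_block_size` are hypothetical helpers, not part of the repo. The FlashAttention paper's Algorithm 1 takes a different route, setting Bc = ⌈M/4d⌉ and Br = min(⌈M/4d⌉, d) for an on-chip SRAM of M elements.

```python
def smem_bytes(Br, Bc, d, bytes_per_elem=4):
    """Bytes for the Qi (Br x d), Kj and Vj (Bc x d each) and S (Br x Bc) tiles (assumed layout)."""
    return bytes_per_elem * (Br * d + 2 * Bc * d + Br * Bc)


def pick_block_size(d, smem_limit=48 * 1024, bytes_per_elem=4,
                    candidates=(128, 64, 32, 16, 8)):
    """Return the largest square block size B (Br = Bc = B) whose tiles fit in smem_limit."""
    for B in candidates:  # tried from largest to smallest
        if smem_bytes(B, B, d, bytes_per_elem) <= smem_limit:
            return B
    raise ValueError(f"no candidate block size fits in {smem_limit} bytes for d={d}")


for d in (32, 64, 128):
    print(d, pick_block_size(d))
```

Under these assumptions, a head dimension of 64 in float32 gets B = 32, while d = 128 drops to 16 and float16 (`bytes_per_elem=2`) would allow 64 for d = 64. With one thread per row, B would also be the number of threads launched per block.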

Contributors

  • Franz Cesista: implemented the backward pass for the Flash Attention 1 algorithm and both the forward and backward passes for the Flash Attention 2 algorithm.
  • Peter Kim: implemented the forward pass for the minimal Flash Attention 1 algorithm. See the original repo.

About

Flash Attention in 300-500 lines of CUDA/C++

Languages

  • Cuda 82.5%
  • Python 13.4%
  • C++ 4.1%