Merged

68 commits
7be0b44
Compress kvcache work
micmelesse Jun 11, 2024
356d243
fix causal. use cache_seqlens
micmelesse Jun 26, 2024
3e3dfc1
clean and test what works
micmelesse Jun 26, 2024
5a3cb0d
some configs work on new_kv but fails on 1,8
micmelesse Jun 27, 2024
e611433
cache overwrite correct
micmelesse Jun 28, 2024
737b701
new_kv works more or less
micmelesse Jun 28, 2024
52eb402
test local
micmelesse Jun 28, 2024
619c9ad
work on paged kv attention
micmelesse Jul 1, 2024
2d49406
prefill paged attention
micmelesse Jul 1, 2024
e5d13ef
fix has_batch_idx and skip local and rotary emb
micmelesse Jul 2, 2024
6aa7caf
save
micmelesse Jul 2, 2024
d46e730
save
micmelesse Jul 2, 2024
63eb390
save
micmelesse Jul 3, 2024
f0193a7
save
micmelesse Jul 3, 2024
0e5223c
handle new_kv when paged kv cache
micmelesse Jul 8, 2024
962dc8a
all except has_batch_idx works
micmelesse Jul 9, 2024
b334464
major options are green
micmelesse Jul 10, 2024
0f3091c
test all
micmelesse Jul 10, 2024
c5be670
add tests
micmelesse Jul 10, 2024
10fd70b
save
micmelesse Jul 10, 2024
4c10a6b
clean up
micmelesse Jul 10, 2024
753093d
minor clean up
micmelesse Jul 10, 2024
431fd7a
simplest config
micmelesse Jul 10, 2024
70fce1e
save debug true
micmelesse Jul 10, 2024
3d73e88
save
micmelesse Jul 10, 2024
8a393f6
refactor slightly
micmelesse Jul 11, 2024
f681687
save work
micmelesse Jul 12, 2024
f6a546f
need key masking
micmelesse Jul 12, 2024
4f44741
force hip
micmelesse Jul 16, 2024
ed1cbcc
use is_hip
micmelesse Jul 17, 2024
1d46be3
save
micmelesse Jul 17, 2024
4802a37
fix cache_seq_len issue
micmelesse Jul 17, 2024
cd4617d
work on new_kv
micmelesse Jul 17, 2024
3f2b171
pass new_kv data
micmelesse Jul 18, 2024
475153f
save
micmelesse Jul 18, 2024
9864bc2
benchmark fwd only
micmelesse Jul 19, 2024
4d5faad
disable debug
micmelesse Jul 19, 2024
e76c5fb
pandas pdf
micmelesse Jul 19, 2024
9e7ce16
save
micmelesse Jul 19, 2024
4d1eeeb
set methods
micmelesse Jul 19, 2024
3a0cf22
record number of heads
micmelesse Jul 23, 2024
cd6bb74
use configs
micmelesse Jul 23, 2024
088fbc7
flexible dim, n-heads, head dim
micmelesse Jul 23, 2024
5ea6525
better benchmarking
micmelesse Jul 24, 2024
af3f4ee
basic inplace update working
micmelesse Jul 25, 2024
9d52279
works up to 64
micmelesse Jul 25, 2024
e0595e9
new_kv supported!
micmelesse Jul 25, 2024
8072212
test case for has_batch_idx
micmelesse Jul 25, 2024
62fdb92
has_batch_idx works!
micmelesse Jul 25, 2024
5f64128
save
micmelesse Jul 25, 2024
0bd947f
save
micmelesse Jul 25, 2024
75b5076
save
micmelesse Jul 26, 2024
23d08f1
save ref
micmelesse Jul 29, 2024
efefa81
fix mqa and gqa by duplicating
micmelesse Jul 30, 2024
77fc391
GQA and MQA working by kernel modifications
micmelesse Jul 30, 2024
8ea183b
fix new_kv with gqa
micmelesse Jul 30, 2024
5edf575
cache index
micmelesse Jul 30, 2024
f4f476d
deal with nans on fwd_splitk
micmelesse Jul 31, 2024
0e58a7c
save
micmelesse Jul 31, 2024
076f5fe
causal working on basic case
micmelesse Aug 1, 2024
1defaaf
causal works!
micmelesse Aug 1, 2024
9004132
alibi works!
micmelesse Aug 1, 2024
4b795dd
clean up
micmelesse Aug 1, 2024
ad6413c
clean prefill changes
micmelesse Aug 1, 2024
6415d9a
remove bwd stuff
micmelesse Aug 2, 2024
485ba55
limit decode test to test_op_fwd
micmelesse Aug 2, 2024
6b6e533
add ref
micmelesse Aug 5, 2024
e081f43
use bfloat
micmelesse Aug 5, 2024
9 changes: 7 additions & 2 deletions .github/workflows/amd_tests.yml
@@ -57,7 +57,12 @@ jobs:
- name: Build
run: |
python setup.py install
- name: Test
- name: AMD Kernel Tests
run: |
pytest flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd
pytest flash_attn/flash_attn_triton_kernel_prefill_amd.py
- name: Flash Attention Tests
run: |
pytest tests/test_flash_attn.py::test_flash_attn_kvcache
pytest tests/test_flash_attn.py::test_flash_attn_output
pytest tests/test_flash_attn.py::test_flash_attn_varlen_output
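
For reference (not part of the diff): a minimal sketch of running the same test selection locally, assuming the package was installed with python setup.py install on a ROCm machine as in the Build step above.

# Illustrative only: run the test selection used by the updated CI steps.
import sys
import pytest

TESTS = [
    "flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd",
    "flash_attn/flash_attn_triton_kernel_prefill_amd.py",
    "tests/test_flash_attn.py::test_flash_attn_kvcache",
    "tests/test_flash_attn.py::test_flash_attn_output",
    "tests/test_flash_attn.py::test_flash_attn_varlen_output",
]

if __name__ == "__main__":
    # pytest.main returns a non-zero exit code on failure, mirroring CI behavior.
    sys.exit(pytest.main(TESTS))
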
3 changes: 2 additions & 1 deletion .gitignore
@@ -26,9 +26,10 @@ var/
# Dev
venv

# Other
# AMD
.eggs
.vscode
core
scripts
log*
*csv
321 changes: 207 additions & 114 deletions benchmarks/benchmark_flash_attention.py

Large diffs are not rendered by default.
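
The benchmark changes themselves are not rendered above. As a generic illustration only (not the PR's benchmark code), forward-only timing of this kind can be done with triton.testing.do_bench; flash_attn_func, the shapes, and the FLOP estimate below are assumptions made for the sketch.

# Generic sketch: time a forward-only attention call on a ROCm/Triton build.
import torch
import triton
from flash_attn import flash_attn_func  # assumes this package is installed

batch, seqlen, nheads, head_dim = 4, 2048, 16, 64
q, k, v = [torch.randn(batch, seqlen, nheads, head_dim, device="cuda", dtype=torch.bfloat16)
           for _ in range(3)]

# do_bench returns the runtime of the callable in milliseconds.
ms = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True))

# 4*b*h*s^2*d FLOPs for full attention; causal masking skips roughly half of that.
flops = 4 * batch * nheads * seqlen * seqlen * head_dim // 2
print(f"fwd: {ms:.3f} ms, ~{flops / (ms * 1e-3) / 1e12:.1f} TFLOP/s")
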

196 changes: 64 additions & 132 deletions flash_attn/flash_attn_triton_interface_amd.py
@@ -1,8 +1,9 @@
from .flash_attn_triton_kernel_amd import MetaData, attention, get_shape_from_layout, _attn_bwd_preprocess, _attn_bwd
import torch
import triton
from .flash_attn_triton_kernel_prefill_amd import MetaData, attention_prefill, get_shape_from_layout, _attn_bwd_preprocess, _attn_bwd
from .flash_attn_triton_kernel_decode_amd import attention_decode

DEBUG=False
DEBUG = False

def fwd(q,
k,
@@ -31,7 +32,7 @@ def fwd(q,
print("gen_:", gen_)

if dropout_p != 0.0:
raise ValueError("dropout is not supported on HIP")
raise ValueError("dropout is not supported on AMD yet")

if o is None:
o = torch.empty_like(q)
@@ -60,7 +61,7 @@ def fwd(q,
input_metadata.check_args(q, k, v, o)

# Perform the forward attention computation
tri_out, encoded_softmax = attention(q, k, v, o, input_metadata)
tri_out, encoded_softmax = attention_prefill(q, k, v, o, input_metadata)

softmax_lse = encoded_softmax
softmax_p = encoded_softmax
@@ -93,9 +94,23 @@ def varlen_fwd(
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)
print("cu_seqlens_q:", cu_seqlens_q)
print("cu_seqlens_k:", cu_seqlens_k)
print("block_table_:", block_table_)
print("alibi_slopes:", alibi_slopes)
print("max_seqlen_q:", max_seqlen_q)
print("max_seqlen_k:", max_seqlen_k)
print("dropout_p:", dropout_p)
print("softmax_scale:", softmax_scale)
print("zero_tensors:", zero_tensors)
print("causal:", causal)
print("window_size_left:", window_size_left)
print("window_size_right:", window_size_right)
print("return_softmax:", return_softmax)
print("gen_:", gen_)

if dropout_p != 0.0:
raise ValueError("dropout is not supported on HIP")
raise ValueError("dropout is not supported on AMD yet")

if o is None:
o = torch.empty_like(q)
@@ -123,14 +138,14 @@ def varlen_fwd(
input_metadata.check_args(q, k, v, o)

# Perform the forward attention computation
tri_out, encoded_softmax = attention(q, k, v, o, input_metadata)
tri_out, encoded_softmax = attention_prefill(q, k, v, o, input_metadata)

softmax_lse = encoded_softmax
softmax_p = encoded_softmax

return tri_out, q , k , v, o, softmax_lse, softmax_p, torch.get_rng_state()

def fwd_kvcache(
q,
k_cache,
v_cache,
@@ -149,150 +164,67 @@ def fwd_kvcache(
window_size_right,
rotary_interleaved,
num_splits):
pass


def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, softmax_scale, causal, window_size_left,
window_size_right, deterministic, gen_, rng_state):

if DEBUG:
print("flash_attn_triton_amd.py::bwd")
print("q:", q.shape)
print("k:", k.shape)
print("v:", v.shape)
print("softmax_lse:", softmax_lse)
print("dq:", dq.shape)
print("dk:", dk.shape)
print("dv:", dv.shape)
print()
print("flash_attn_triton_amd.py::fwd_kvcache")
print("q:", q, q.shape)
print("k_cache:", k_cache, k_cache.shape)
print("v_cache:", v_cache, v_cache.shape)
print("k:", k, k.shape if k is not None else None)
print("v:", v, v.shape if v is not None else None)
print("cache_seqlens:", cache_seqlens, cache_seqlens.size())
print("rotary_cos:", rotary_cos)
print("rotary_sin:", rotary_sin)
print("cache_batch_idx:", cache_batch_idx)
print("block_table:", block_table, block_table.shape if block_table is not None else None)
print("alibi_slopes:", alibi_slopes)
print("dropout_p:", dropout_p)
print("out:", out)
print("softmax_scale:", softmax_scale)
print("causal:", causal)
print("window_size_left:", window_size_left)
print("window_size_right:", window_size_right)
print("deterministic:", deterministic)
print("gen_:", gen_)
print("rng_state:", rng_state)

print("rotary_interleaved:", rotary_interleaved)
print("num_splits:", num_splits)

if out is None:
out = torch.empty_like(q)

# Ensure the tensors have requires_grad=True
q.requires_grad_()
k.requires_grad_()
v.requires_grad_()
out.requires_grad_()

# Create metadata object
metadata = MetaData(sm_scale=softmax_scale)
metadata.max_seqlens_q = q.shape[1]
metadata.max_seqlens_k = k.shape[1]
metadata.layout = "bshd"
# fill metadata
input_metadata = MetaData(sm_scale=softmax_scale)
input_metadata.layout = "bshd"
input_metadata.max_seqlens_q = q.shape[1]
input_metadata.max_seqlens_k = k_cache.shape[1]
input_metadata.cache_seqlens = cache_seqlens
input_metadata.cache_batch_idx = cache_batch_idx

if metadata == 'bshd':
q = q.transpose(1, 2).clone()
k = k.transpose(1, 2).clone()
v = v.transpose(1, 2).clone()
if k is not None and v is not None:
input_metadata.new_kv = True
input_metadata.seqlen_new = k.shape[1]
input_metadata.k_new = k
input_metadata.v_new = v

batch = q.shape[0]
nheads_q = q.shape[1]
BLOCK_DMODEL = q.shape[3]

# Setup metadata
if causal:
metadata.need_causal()
input_metadata.need_causal()

# if bias is not None:
# metadata.need_bias(bias, q.shape[0], q.shape[1], q.shape[2], k.shape[2])

return_softmax = True
if alibi_slopes is not None:
metadata.need_alibi(alibi_slopes, batch, nheads_q)

if dropout_p > 0.0:
metadata.need_dropout(dropout_p, return_softmax)
batch, _ , nheads_q, _= q.shape
input_metadata.need_alibi(alibi_slopes, batch, nheads_q)

# Check arguments
metadata.check_args(q, k, v, out)

# write your own version backward
M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) # this passed from
# launch kernel
tri_out = attention_decode(q, k_cache, v_cache, input_metadata)

if torch.version.hip is not None:
BLOCK = 64
else:
BLOCK = 128
o = out
do = dout
sm_scale = softmax_scale
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
seqlen_q = q.shape[2]
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_CTX, N_HEAD = q.shape[:3]
PRE_BLOCK = 128
# NUM_WARPS, NUM_STAGES = 4, 1
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (sm_scale * RCP_LN2)
if DEBUG:
print("N_CTX:", N_CTX)
# assert N_CTX % PRE_BLOCK == 0
print()
print("tri_out:", tri_out, tri_out.shape)

return tri_out, None

delta = torch.empty_like(M)
_, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]
# padded_head = (Lk != ctx.BLOCK_DMODEL)
grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0])
_attn_bwd_preprocess[grid_preprocess](
o,
do,
delta,
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
do.stride(0),
do.stride(1),
do.stride(2),
do.stride(3),
seqlen_q,
head_dim=Lk,
BLOCK_M=BLOCK,
D_HEAD=BLOCK_DMODEL,
)
grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD)
_attn_bwd[grid](
q,
arg_k,
v,
sm_scale,
alibi_slopes,
do,
dq,
dk,
dv,
M,
delta,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
N_HEAD,
N_CTX,
BLOCK_DMODEL= BLOCK_DMODEL,
BLOCK_M1=BLOCK_M1,
BLOCK_N1=BLOCK_N1,
BLOCK_M2=BLOCK_M2,
BLOCK_N2=BLOCK_N2,
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,
USE_ALIBI=False if alibi_slopes is None else True,
)

return dq, dk, dv, None
def bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, softmax_scale, causal, window_size_left,
window_size_right, deterministic, gen_, rng_state):
raise ValueError("bwd is not supported on AMD yet")


def varlen_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, *args, **kwargs):
pass
raise ValueError("varlen_bwd is not supported on AMD yet")
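
For context (not part of the diff): a minimal sketch of exercising the new kv-cache decode path through the public flash_attn_with_kvcache API, which the test_flash_attn_kvcache test enabled in CI drives. The shapes, dtype, and device below are illustrative assumptions, and the sketch presumes a ROCm build where fwd_kvcache dispatches to attention_decode as shown above.

# Hypothetical usage sketch of the kv-cache decode path; shapes are (batch, seqlen, nheads, head_dim).
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, head_dim = 2, 8, 64
cache_len, new_len = 128, 1

q = torch.randn(batch, new_len, nheads, head_dim, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(batch, cache_len, nheads, head_dim, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn(batch, cache_len, nheads, head_dim, device="cuda", dtype=torch.bfloat16)
k_new = torch.randn(batch, new_len, nheads, head_dim, device="cuda", dtype=torch.bfloat16)
v_new = torch.randn(batch, new_len, nheads, head_dim, device="cuda", dtype=torch.bfloat16)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")

# k_new/v_new are written into the caches at cache_seqlens (the new_kv path above),
# then attention is computed against the updated cache.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k_new, v=v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
)
print(out.shape)  # (batch, new_len, nheads, head_dim)
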