Commit 802b3d2
committed
Enable Fwd and Backward
Enable Fwd and Backward
Enable Fwd and Backward
Enable fwd and varlen_fwd on AMD (#63)
* flash_attn_func works
Compress
This is a combination of 12 commits.
add scripts
save
add our kernel
import our kernel
round trip
use bshd layout
figure out segfault
fix
show backward failure with prints
save backward work
run forward only
test smallest config on everything
add test
fix
remove pre commit
install triton
skip dropout
pin d
32 factor d
just run power of 2
remove timeout
run serially
clean up
clean up 2
* Varlen works
This is a combination of 6 commits.
save
some tests passing
enable more
enable everything
move around
alibi works
* keep interface and kernel seperate
* clean up
enable flash_attn_with_kvcache (#68)
* Compress kvcache work
This is a combination of 11 commits.
kvcache work
This is a combination of 4 commits.
kvcache is not supported
save
save decode
save
clean up merge
save cases
save
save
save
save
key mask on triton side
fix q size issue
test combos
save
* fix causal. use cache_seqlens
* clean and test what works
* some configs work on new_kv but fails on 1,8
* cache overwrite correct
* new_kv works more or less
* test local
* work on paged kv attention
* prefill paged attention
* fix has_batch_idx and skip local and rotatary emb
* save
* save
* save
* save
* handle new_kv when paged kv cache
* all except has_batch_idx works
* major options are green
* test all
* add tests
* save
* clean up
* minor clean up
* simplest config
* save debug true
* save
* refactor slightly
* save work
* need key masking
* force hip
* use is_hip
* save
* fix cache_seq_len issue
* work on new_kv
* pass new_kv data
* save
* benchmark fwd only
* disable debug
* pandas pdf
* save
* set methods
* record number of heads
* use configs
* flexiable dim, n-heads, headofdim
* better benchmarking
* basic inplace update working
* works upto 64
* new_kv supported!
* test case for has_batch_idx
* has_batch_idx works!
* save
* save
* save
* save ref
* fix mqa and gqa by duplicating
* GQA and MQA working by kernel modifications
* fix new_kv with gqa
* cache index
* deal with nans on fwd_splitk
* save
* causal working on basic case
* causal works!
* alibi works!
* clean up
* clean prefill changes
* remove bwd stuff
* limit decode test to test_op_fwd
* add ref
* use bfloat
Fixes after rebase
Fixes after rebase
rebase fixes
deal with kvcache failure
new run for branch
cancel-in-progress
fix varlen_fwd bug
enable packed layouts and all configs (#72)
Clean up for Upstream (#81)
* Clean
Clean
This is a combination of 4 commits.
clean 1
clean 2
clean more
match main
typo fix
* use is_hip()
* clean up more
* skip odd d only
* fix bug
* skip randomly
* use Flag
* update readme
* remove quantization
* remove bwd
* minor
* print
* remove verbose print
* qunatize zero's out the d stride
Enable Vanilla Bwd and Refactor (#86)
* Vanilla BWD
Vanilla BWD
This is a combination of 79 commits.
save test_flash_attn_output
use impl functions
pass layout
add ref
move arround impls
fix stride issue
save oai kernel
add baseline impl
save bwd kernel working
remove old impl
remove block_ptrs from bwd
pass padded dmodel and apply masking. the old test cases work but cases with small d don't work
save
save
more prints
rename to M to L
save
add notes
add old_bwd back
fa failure fails in kernels too
isolate new bwd and keep old bwd in place
clean up
softmax_lse doesnot match refernce
LOG flag
softmax_lse with LN2
move qk_scale to loop
pass ln2 to fwd
just print kernel input
test softmax output from forward
test exp_scores_triton
save all the ref
create ref USE_EXP2 path
return scores
mask scores when returning them. Basic impl test passes
scores and output match
show max_diff
return score needs to be adjusted as we find new maxes
all good outputs. old style RCP2 example
prep bwd_impl test
save
try openai
save
fix softmax_lse bug
test_op_bwd_impl starting to work!
new kernel. exp2 works but exp is faliing
fix bwd exp2
add m and n masks. small cases still don't work
match old and new kernel prints
compare old and new
print inputs
save
old kernel match on dv
dq works
compare to pytorch including softmax in forward
fix bwd impl bug
small sizes in bwd impl work
old bwd test pass. Moving on to kernel tests
dq, dk and dv are filled in place if given. Need to match cast to match fa
fix non bug
fix dv mismatch. use_exp2 was set to true in fwd
fix case up 128
refactor and clean up a bit more
issue is that dq and dk are not zeros
dq must be zeroed out
ignore segfaults
fa ref and my ref match!
all tests run
use tolerance 1e-3
we need to figure out preprocessing
save
clean up
save
test delta diff
move old impl out
new preprocess function
preprocessing_use_o flag
working _bwd_preprocess_use_p
basic cases pass
all green
fwd exp2 usage is done right before exp
* refactor
* refactor 2
* refactor 3
* fix bug
* try ci
* add flag
* rename to utils
* skip test_op_fwd_decode_int4_kv
* reduce head size
* try again
* go back to old head sizes
* Use Strides
Use Strides
This is a combination of 11 commits.
use strides in bwd
add layout test in forward
fix shape layout function
smaller tests
save
fix varlen error
no headsize passed to bwd
deal with varlen layout
save
save
save
save
* use gen scripts
* varlen fwd passing
* core fwd ref impl
* fix minor bugs
* wrap varlen- launcher attention_forward_pytorch_ref_impl
* varlen backward ref added
* add offsets for varlen
* fix delta bug
* varlen bwd working
* save
* runs on Mi200
* just test basics
* save
* fix bug
* fix varlen in64 bug
* add ref
* test_impl working with causal
* fix qkvpacked issue
* qkvpacked run tests
* remove test_backward
* save
* just test output
* dump into tensors
* softmaxlse layout for varlen
* small cases working
* bwd thd green. although maybe some oom
* forward out and lse are good. Something wrong with backward ref
* make varlen ref work
* save work, ref is working mostly
* 91 failed, 6542 passed, 6336 skipped, 1 warning
* ref is all green
* debug flag in utils
* found bad softmax_lse in varlen fwd
* fix bug in softmax lse. strides in varlen werenot right
* add causal tests and 32*32 bwd doesnot have segfault
* save
* fix oom by reducing block size for small heads
* bwd ref with causal working
* test impl
* causal test passes
* causal working
* fix tests
* nicer bench
* fix qvpacked error
* fix varlen qvpacked bug
* fix minor bug
* bench prefill and prefill_old using the same script
* autotune configs for fwd
* autotune flag
* clean up decode impl
* clean up
* clean up more
* bench everything by default and return time
* clean up readmes
REBASE: fix interface changes in rebase
rename test to test_flash_attn_triton_amd
REBASE: fix unpad diffs
minor clean up in setup
FLASH_ATTENTION_TRITON_AMD flags
bench fwd and bwd
fix sequence_parallel1 parent 34a3656 commit 802b3d2
10 files changed
Lines changed: 287 additions & 161 deletions
File tree
- .github/workflows
- flash_attn/flash_attn_triton_amd
- tests
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
25 | | - | |
| 25 | + | |
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
29 | 29 | | |
30 | 30 | | |
31 | 31 | | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
164 | 164 | | |
165 | 165 | | |
166 | 166 | | |
167 | | - | |
168 | | - | |
169 | | - | |
170 | | - | |
171 | | - | |
172 | | - | |
173 | | - | |
174 | | - | |
175 | | - | |
| 167 | + | |
| 168 | + | |
176 | 169 | | |
177 | | - | |
178 | | - | |
179 | | - | |
180 | | - | |
| 170 | + | |
181 | 171 | | |
182 | | - | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
183 | 177 | | |
184 | | - | |
185 | | - | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
186 | 181 | | |
187 | | - | |
188 | | - | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
189 | 188 | | |
190 | | - | |
191 | | - | |
192 | | - | |
193 | | - | |
| 189 | + | |
| 190 | + | |
194 | 191 | | |
195 | | - | |
196 | | - | |
197 | | - | |
| 192 | + | |
198 | 193 | | |
199 | | - | |
200 | 194 | | |
201 | | - | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
202 | 199 | | |
| 200 | + | |
203 | 201 | | |
204 | | - | |
205 | 202 | | |
206 | | - | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
207 | 207 | | |
208 | 208 | | |
| 209 | + | |
209 | 210 | | |
210 | 211 | | |
211 | 212 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
25 | 25 | | |
26 | 26 | | |
27 | 27 | | |
28 | | - | |
| 28 | + | |
29 | 29 | | |
30 | 30 | | |
31 | | - | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
32 | 35 | | |
33 | 36 | | |
34 | 37 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
58 | 58 | | |
59 | 59 | | |
60 | 60 | | |
61 | | - | |
62 | | - | |
63 | | - | |
64 | | - | |
65 | | - | |
66 | | - | |
67 | | - | |
68 | | - | |
69 | | - | |
70 | | - | |
71 | | - | |
72 | | - | |
73 | | - | |
74 | | - | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
75 | 108 | | |
76 | 109 | | |
77 | 110 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2 | 2 | | |
3 | 3 | | |
4 | 4 | | |
5 | | - | |
6 | | - | |
7 | | - | |
8 | | - | |
9 | | - | |
10 | | - | |
| 5 | + | |
11 | 6 | | |
12 | 7 | | |
13 | 8 | | |
| |||
89 | 84 | | |
90 | 85 | | |
91 | 86 | | |
| 87 | + | |
92 | 88 | | |
93 | 89 | | |
94 | 90 | | |
| |||
419 | 415 | | |
420 | 416 | | |
421 | 417 | | |
422 | | - | |
423 | | - | |
424 | | - | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
425 | 423 | | |
426 | 424 | | |
427 | 425 | | |
| |||
600 | 598 | | |
601 | 599 | | |
602 | 600 | | |
603 | | - | |
604 | | - | |
605 | | - | |
606 | | - | |
607 | | - | |
608 | | - | |
| 601 | + | |
609 | 602 | | |
610 | 603 | | |
611 | 604 | | |
| |||
656 | 649 | | |
657 | 650 | | |
658 | 651 | | |
| 652 | + | |
| 653 | + | |
659 | 654 | | |
660 | 655 | | |
661 | 656 | | |
| |||
687 | 682 | | |
688 | 683 | | |
689 | 684 | | |
690 | | - | |
691 | | - | |
692 | 685 | | |
693 | | - | |
694 | | - | |
| 686 | + | |
| 687 | + | |
| 688 | + | |
| 689 | + | |
| 690 | + | |
| 691 | + | |
| 692 | + | |
| 693 | + | |
| 694 | + | |
| 695 | + | |
| 696 | + | |
| 697 | + | |
| 698 | + | |
| 699 | + | |
| 700 | + | |
| 701 | + | |
| 702 | + | |
| 703 | + | |
| 704 | + | |
| 705 | + | |
| 706 | + | |
| 707 | + | |
| 708 | + | |
| 709 | + | |
695 | 710 | | |
696 | | - | |
| 711 | + | |
697 | 712 | | |
698 | 713 | | |
699 | 714 | | |
| |||
798 | 813 | | |
799 | 814 | | |
800 | 815 | | |
801 | | - | |
| 816 | + | |
802 | 817 | | |
803 | 818 | | |
804 | 819 | | |
805 | | - | |
806 | | - | |
807 | | - | |
| 820 | + | |
808 | 821 | | |
809 | | - | |
810 | | - | |
811 | | - | |
812 | | - | |
| 822 | + | |
| 823 | + | |
| 824 | + | |
| 825 | + | |
| 826 | + | |
| 827 | + | |
| 828 | + | |
| 829 | + | |
| 830 | + | |
| 831 | + | |
| 832 | + | |
| 833 | + | |
| 834 | + | |
| 835 | + | |
| 836 | + | |
| 837 | + | |
813 | 838 | | |
814 | | - | |
| |||
0 commit comments