
Commit 802b3d2

Enable Fwd and Backward
Enable fwd and varlen_fwd on AMD (#63)
* flash_attn_func works. Compress. This is a combination of 12 commits: add scripts save add our kernel import our kernel round trip use bshd layout figure out segfault fix show backward failure with prints save backward work run forward only test smallest config on everything add test fix remove pre-commit install triton skip dropout pin d 32 factor d just run power of 2 remove timeout run serially clean up clean up 2
* Varlen works. This is a combination of 6 commits: save some tests passing enable more enable everything move around alibi works
* keep interface and kernel separate
* clean up

enable flash_attn_with_kvcache (#68)
* Compress kvcache work. This is a combination of 11 commits: kvcache work. This is a combination of 4 commits: kvcache is not supported save save decode save clean up merge save cases save save save save key mask on triton side fix q size issue test combos save
* fix causal. use cache_seqlens
* clean and test what works
* some configs work on new_kv but fails on 1,8
* cache overwrite correct
* new_kv works more or less
* test local
* work on paged kv attention
* prefill paged attention
* fix has_batch_idx and skip local and rotary emb
* save
* save
* save
* save
* handle new_kv when paged kv cache
* all except has_batch_idx works
* major options are green
* test all
* add tests
* save
* clean up
* minor clean up
* simplest config
* save debug true
* save
* refactor slightly
* save work
* need key masking
* force hip
* use is_hip
* save
* fix cache_seq_len issue
* work on new_kv
* pass new_kv data
* save
* benchmark fwd only
* disable debug
* pandas pdf
* save
* set methods
* record number of heads
* use configs
* flexible dim, n-heads, head dim
* better benchmarking
* basic inplace update working
* works up to 64
* new_kv supported!
* test case for has_batch_idx
* has_batch_idx works!
* save
* save
* save
* save ref
* fix mqa and gqa by duplicating
* GQA and MQA working by kernel modifications
* fix new_kv with gqa
* cache index
* deal with nans on fwd_splitk
* save
* causal working on basic case
* causal works!
* alibi works!
* clean up
* clean prefill changes
* remove bwd stuff
* limit decode test to test_op_fwd
* add ref
* use bfloat

Fixes after rebase
Fixes after rebase
rebase fixes
deal with kvcache failure
new run for branch cancel-in-progress
fix varlen_fwd bug

enable packed layouts and all configs (#72)

Clean up for Upstream (#81)
* Clean. This is a combination of 4 commits: clean 1 clean 2 clean more match main typo fix
* use is_hip()
* clean up more
* skip odd d only
* fix bug
* skip randomly
* use Flag
* update readme
* remove quantization
* remove bwd
* minor
* print
* remove verbose print
* quantize zero's out the d stride

Enable Vanilla Bwd and Refactor (#86)
* Vanilla BWD. This is a combination of 79 commits: save test_flash_attn_output use impl functions pass layout add ref move around impls fix stride issue save oai kernel add baseline impl save bwd kernel working remove old impl remove block_ptrs from bwd pass padded dmodel and apply masking. the old test cases work but cases with small d don't work save save more prints rename M to L save add notes add old_bwd back fa failure fails in kernels too isolate new bwd and keep old bwd in place clean up softmax_lse does not match reference LOG flag softmax_lse with LN2 move qk_scale to loop pass ln2 to fwd just print kernel input test softmax output from forward test exp_scores_triton save all the ref create ref USE_EXP2 path return scores mask scores when returning them. Basic impl test passes scores and output match show max_diff return score needs to be adjusted as we find new maxes all good outputs. old style RCP2 example prep bwd_impl test save try openai save fix softmax_lse bug test_op_bwd_impl starting to work! new kernel. exp2 works but exp is failing fix bwd exp2 add m and n masks. small cases still don't work match old and new kernel prints compare old and new print inputs save old kernel match on dv dq works compare to pytorch including softmax in forward fix bwd impl bug small sizes in bwd impl work old bwd test pass. Moving on to kernel tests dq, dk and dv are filled in place if given. Need to match cast to match fa fix non bug fix dv mismatch. use_exp2 was set to true in fwd fix case up 128 refactor and clean up a bit more issue is that dq and dk are not zeros dq must be zeroed out ignore segfaults fa ref and my ref match! all tests run use tolerance 1e-3 we need to figure out preprocessing save clean up save test delta diff move old impl out new preprocess function preprocessing_use_o flag working _bwd_preprocess_use_p basic cases pass all green fwd exp2 usage is done right before exp
* refactor
* refactor 2
* refactor 3
* fix bug
* try ci
* add flag
* rename to utils
* skip test_op_fwd_decode_int4_kv
* reduce head size
* try again
* go back to old head sizes
* Use Strides. This is a combination of 11 commits: use strides in bwd add layout test in forward fix shape layout function smaller tests save fix varlen error no headsize passed to bwd deal with varlen layout save save save save
* use gen scripts
* varlen fwd passing
* core fwd ref impl
* fix minor bugs
* wrap varlen launcher attention_forward_pytorch_ref_impl
* varlen backward ref added
* add offsets for varlen
* fix delta bug
* varlen bwd working
* save
* runs on MI200
* just test basics
* save
* fix bug
* fix varlen int64 bug
* add ref
* test_impl working with causal
* fix qkvpacked issue
* qkvpacked run tests
* remove test_backward
* save
* just test output
* dump into tensors
* softmax_lse layout for varlen
* small cases working
* bwd thd green, although maybe some oom
* forward out and lse are good. Something wrong with backward ref
* make varlen ref work
* save work, ref is working mostly
* 91 failed, 6542 passed, 6336 skipped, 1 warning
* ref is all green
* debug flag in utils
* found bad softmax_lse in varlen fwd
* fix bug in softmax_lse. strides in varlen were not right
* add causal tests and 32x32 bwd does not have segfault
* save
* fix oom by reducing block size for small heads
* bwd ref with causal working
* test impl
* causal test passes
* causal working
* fix tests
* nicer bench
* fix qvpacked error
* fix varlen qvpacked bug
* fix minor bug
* bench prefill and prefill_old using the same script
* autotune configs for fwd
* autotune flag
* clean up decode impl
* clean up
* clean up more
* bench everything by default and return time
* clean up readmes

REBASE: fix interface changes in rebase
rename test to test_flash_attn_triton_amd
REBASE: fix unpad diffs
minor clean up in setup
FLASH_ATTENTION_TRITON_AMD flags
bench fwd and bwd
fix sequence_parallel
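Several of the squashed commits above revolve around the USE_EXP2 path and the softmax_lse bookkeeping ("softmax_lse with LN2", "pass ln2 to fwd", "fwd exp2 usage is done right before exp"). As a rough illustration of the identity those commits rely on, here is a hedged PyTorch sketch; the tensor names and shapes are made up for the example and are not taken from the kernels.

```
import math
import torch

# Illustration only: exp(x) can be rewritten as exp2(x * log2(e)), which is the
# substitution the USE_EXP2 path is based on. The log-sum-exp (softmax_lse)
# then has to be converted back to the natural-log base with a factor of ln(2).
torch.manual_seed(0)
scores = torch.randn(4, 128)  # hypothetical attention scores for one query block

# reference: natural-base softmax numerator and log-sum-exp
ref_max = scores.max(dim=-1, keepdim=True).values
ref_p = torch.exp(scores - ref_max)
ref_lse = ref_max.squeeze(-1) + torch.log(ref_p.sum(dim=-1))

# exp2 variant: scale by log2(e) before exp2, rescale the lse by ln(2) afterwards
RCP_LN2 = 1.0 / math.log(2.0)  # log2(e)
p2 = torch.exp2((scores - ref_max) * RCP_LN2)
lse2 = ref_max.squeeze(-1) + math.log(2.0) * torch.log2(p2.sum(dim=-1))

assert torch.allclose(ref_p, p2, atol=1e-6)
assert torch.allclose(ref_lse, lse2, atol=1e-6)
```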
1 parent 34a3656 commit 802b3d2

10 files changed

Lines changed: 287 additions & 161 deletions

File tree

.github/workflows/amd_tests.yml

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+name: AMD Perf Kernel Tests
+
+on:
+  workflow_dispatch:
+  pull_request:
+    branches: [main_perf]
+  merge_group:
+    branches: [main_perf]
+    types: [checks_requested]
+  push:
+    branches: [main_perf, micmelesse/upstream_pr]
+
+concurrency:
+  group: ${{ github.ref }}
+  cancel-in-progress: true
+
+permissions: read-all
+
+jobs:
+  Runner-Preparation-AMD:
+    runs-on: ubuntu-latest
+    timeout-minutes: 30
+    outputs:
+      matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }}
+    steps:
+      - name: Prepare runner matrix
+        id: set-matrix
+        run: |
+          if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
+            echo '::set-output name=matrix-HIP::[["self-hosted", "rocm"]]'
+          else
+            echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
+          fi
+
+  Integration-Tests-AMD:
+    needs: Runner-Preparation-AMD
+    if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != ''
+    runs-on: ${{ matrix.runner }}
+    strategy:
+      matrix:
+        runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}}
+    container:
+      image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2
+      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Install Triton
+        run: |
+          pip uninstall -y triton
+          pip install matplotlib pandas pytest
+          git clone https://github.com/triton-lang/triton
+          cd triton
+          git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
+          pip install --verbose -e python
+          cd ..
+      - name: Build
+        run: |
+          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+          python setup.py install
+      - name: Flash Attention Tests Using Reference Impl
+        run: |
+          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+          export FLASH_ATTENTION_TRITON_AMD_REF=1
+          pytest tests/test_flash_attn_triton_amd.py
+      - name: Flash Attention Tests
+        run: |
+          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+          export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0
+          pytest tests/test_flash_attn_triton_amd.py
+      - name: AMD Kernel Tests
+        run: |
+          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+          export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0
+          pytest -v -s flash_attn/flash_attn_triton_amd/test.py
+      - name: AMD Kernel Bench
+        run: |
+          export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+          python flash_attn/flash_attn_triton_amd/bench.py

.gitignore

Lines changed: 7 additions & 1 deletion
@@ -22,10 +22,16 @@ var/
 *.egg-info/
 .installed.cfg
 *.egg
-.eggs/
+.eggs

 # IDE-related
 .idea/

 # Dev
 venv
+scripts
+*.log
+core.*
+*.csv
+*.png
+*.html

README.md

Lines changed: 30 additions & 29 deletions
@@ -164,48 +164,49 @@ git checkout main_perf
 FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
 ```

-To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing.
-```
-FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
-```
-
-You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`
-```
-FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE
-```
+#### Triton Backend
+The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.

-###### Docker
-You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image.
-```
-FROM rocm/pytorch:latest
+It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes.

-WORKDIR /workspace
+These features are supported in Fwd and Bwd
+1) Fwd and Bwd with causal masking
+2) Variable sequence lengths
+3) Arbitrary Q and KV sequence lengths
+4) Arbitrary head sizes

-# install triton
-RUN pip install triton==3.2.0
+These features are supported in Fwd for now. We will add them to backward soon.
+1) Multi and grouped query attention
+2) ALiBi and matrix bias

-# install flash attention
-ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+These features are in development
+1) Paged Attention
+2) Sliding Window
+3) Rotary embeddings
+4) Dropout
+5) Performance Improvements

-RUN git clone https://github.com/ROCm/flash-attention.git &&\
-    cd flash-attention &&\
-    git checkout main_perf &&\
-    python setup.py install
+#### Getting Started
+To get started with the triton backend for AMD, follow the steps below.

-# set working dir
-WORKDIR /workspace/flash-attention
-```
+First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88).

-To build the docker file
 ```
-docker build -t fa_triton .
+git clone https://github.com/triton-lang/triton
+cd triton
+git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
+pip install --verbose -e python
 ```
+Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.

-To run the docker image
 ```
-docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton
+export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+cd flash-attention
+python setup.py install
+pytest tests/test_flash_attn.py
 ```

+
 ## How to use FlashAttention

 The main functions implement scaled dot product attention (softmax(Q @ K^T *
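For context on the features the updated README lists (forward and backward with causal masking, fixed or variable sequence lengths), below is a minimal usage sketch of `flash_attn_func`. It assumes the package was built with `FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"` as described above, and the shapes and dtype are only illustrative.

```
import torch
from flash_attn import flash_attn_func

# Illustrative shapes only; flash_attn_func expects (batch, seqlen, nheads, headdim) tensors.
batch, seqlen, nheads, headdim = 2, 1024, 16, 64
device, dtype = "cuda", torch.float16  # ROCm builds of PyTorch expose HIP devices under "cuda"

q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)

out = flash_attn_func(q, k, v, causal=True)  # forward with causal masking
out.sum().backward()                         # backward pass, which this commit enables on AMD
```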

flash_attn/flash_attn_triton_amd/README.md

Lines changed: 5 additions & 2 deletions
@@ -25,10 +25,13 @@ We are working on the following things
 ##### Getting Started
 To get started with the triton backend for AMD, follow the steps below.

-First install the recommended Triton version
+First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88).

 ```
-pip install triton==3.2.0
+git clone https://github.com/triton-lang/triton
+cd triton
+git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
+pip install --verbose -e python
 ```
 Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.
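The CI workflow earlier in this commit exports `FLASH_ATTENTION_TRITON_AMD_ENABLE` both when building and when running the tests, so the flag is also consulted at runtime. A small sketch of setting it from Python before importing the package; whether setting it this late in the process is sufficient for your setup is an assumption, not something this commit states.

```
import os

# Assumed pattern: select the Triton backend before flash_attn is imported,
# mirroring the `export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"` lines in the CI steps above.
os.environ.setdefault("FLASH_ATTENTION_TRITON_AMD_ENABLE", "TRUE")
os.environ.setdefault("FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "0")  # flag also used in the workflow

import flash_attn
print(flash_attn.__version__)
```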

flash_attn/flash_attn_triton_amd/bench.py

Lines changed: 47 additions & 14 deletions
@@ -58,20 +58,53 @@
     "flash_attn_with_kvcache": ["ck", "triton"],
 }

-VALID_MODES = ['fwd', 'bwd', 'full']
-SUPPORTED_MODES = {
-    "flash_attn_func": ["fwd", "bwd", "full"],
-    "flash_attn_fp8_func": ["fwd", "bwd", "full"],
-    "flash_attn_kvpacked_func": ["fwd", "bwd", "full"],
-    "flash_attn_qkvpacked_func": ["fwd", "bwd", "full"],
-    "flash_attn_qkvpacked_fp8_func": ["fwd", "bwd", "full"],
-    "flash_attn_varlen_func": ["fwd", "bwd", "full"],
-    "flash_attn_varlen_fp8_func": ["fwd", "bwd", "full"],
-    "flash_attn_varlen_kvpacked_func": ["fwd", "bwd", "full"],
-    "flash_attn_varlen_qkvpacked_func": ["fwd", "bwd", "full"],
-    "flash_attn_varlen_qkvpacked_fp8_func": ["fwd", "bwd", "full"],
-    "flash_attn_with_kvcache": ["fwd"],
-}
+def get_benchmark_configs(args, varlen=False):
+    """
+    Returns benchmark configurations based on whether variable-length sequences are used.
+    """
+    if args.custom_config:
+        hk = args.hq if not args.hk else args.hk
+        sk = args.sq if not args.sk else args.sk
+        return [(args.b, args.hq, hk, args.sq, sk)]
+    elif varlen:
+        return [
+            (2, 16, 4, 1024, 1024),
+            (8, 16, 2, 2048, 2048),
+            (4, 16, 8, 4096, 4096),
+            (2, 16, 4, 8192, 8192),
+            (2, 16, 8, 16384, 16384),
+            (2, 48, 12, 1024, 1024),
+            (2, 48, 24, 2048, 2048),
+            (2, 48, 8, 4096, 4096),
+            (2, 48, 4, 8192, 8192),
+            (2, 48, 2, 16384, 16384),
+            (2, 64, 32, 1024, 1024),
+            (4, 64, 16, 2048, 2048),
+            (4, 64, 8, 4096, 4096),
+            (4, 64, 32, 8192, 8192),
+            (4, 128, 16, 16384, 16384),
+        ]
+    else:
+        return [
+            (16, 16, 16, 1024, 1024),
+            (8, 16, 16, 2048, 2048),
+            (4, 16, 16, 4096, 4096),
+            (2, 16, 16, 8192, 8192),
+            (1, 16, 16, 16384, 16384),
+            (2, 48, 48, 1024, 1024),
+            (2, 48, 48, 2048, 1024),
+            (2, 48, 48, 4096, 8192),
+            (2, 48, 48, 8192, 4096),
+            (2, 48, 48, 16384, 8192),
+            (8, 16, 16, 1989, 15344),
+            (4, 16, 16, 4097, 163),
+            (2, 16, 16, 8122, 2159),
+            (1, 16, 16, 16281, 7),
+            (2, 48, 48, 1021, 1020),
+            (2, 48, 48, 2001, 2048),
+            (2, 48, 48, 3996, 9639),
+            (2, 48, 48, 8181, 1021),
+        ]

 @dataclass
 class EnvVariableConfig:
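With the static mode table replaced by `get_benchmark_configs`, each benchmark case is a `(batch, hq, hk, sq, sk)` tuple, as the `custom_config` branch shows. A hedged sketch of how such a tuple might be turned into benchmark inputs; the tensor construction and head size below are assumptions for illustration, not code from bench.py.

```
import torch

# One tuple from get_benchmark_configs(args): (batch, query heads, kv heads, q seqlen, k seqlen).
batch, hq, hk, sq, sk = (4, 16, 8, 4096, 4096)
headdim = 128  # assumed head size; bench.py chooses its own

q = torch.randn(batch, sq, hq, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, sk, hk, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, sk, hk, headdim, device="cuda", dtype=torch.float16)
# q, k, v would then be fed to the function under test, e.g. flash_attn_func(q, k, v).
```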

flash_attn/flash_attn_triton_amd/bwd_prefill.py

Lines changed: 53 additions & 29 deletions
@@ -2,12 +2,7 @@
 import torch
 import triton
 import triton.language as tl
-from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, get_strides_from_layout, is_fp8, write_dropout_mask, create_dropout_mask
-
-# TODO: move this into utils.py so it's shared among kernels
-# NOTE: triton fails to import tl.constexprs so create them here for the file
-tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH)
-tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP)
+from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG

 @triton.jit
 def _bwd_preprocess(
@@ -89,6 +84,7 @@ def _bwd_preprocess(
     tl.store(delta_ptrs, delta, mask=mask_m)


+
 @triton.jit
 def _bwd_kernel_one_col_block(
     Q,
@@ -419,9 +415,11 @@ def _bwd_kernel(
     l_offset = L + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam
     delta_offset = Delta + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam

-    if DROPOUT:
-        batch_philox_offset = philox_offset_base + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm
-        dropout_offset = Dropout_mask + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm
+    # output tensor offsets
+    dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
+    dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
+    if SEQUENCE_PARALLEL:
+        dq_offset = DQ + stride_dq_all * start_n + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
     else:
         batch_philox_offset = 0
         dropout_offset = 0
@@ -600,12 +598,7 @@ def attention_prefill_backward_triton_impl(
     philox_seed: Optional[int],
     philox_offset: Optional[int],
     use_exp2: bool,
-    sequence_parallel: bool = True,
-    # fp8
-    descale_q: Optional[torch.Tensor] = None,
-    descale_k: Optional[torch.Tensor] = None,
-    descale_v: Optional[torch.Tensor] = None,
-    descale_do: Optional[torch.Tensor] = None,
+    sequence_parallel = False,
 ):
     if DEBUG:
         print()
@@ -656,6 +649,8 @@ def attention_prefill_backward_triton_impl(
     stride_kz, stride_kh, stride_kn, stride_kk = k_strides
     stride_vz, stride_vh, stride_vn, stride_vk = v_strides
     stride_oz, stride_oh, stride_om, stride_ok = o_strides
+    stride_dq_all = q.numel()
+    batch_headsize = batch * nheads_q
     is_varlen = layout == "thd"
     group_size = nheads_q // nheads_k
     use_dropout = (dropout_p > 0.0)
@@ -687,13 +682,33 @@ def attention_prefill_backward_triton_impl(
     ACTUAL_BLOCK_DMODEL = head_size

     do = do.contiguous()
-
-    # deal with dq
     if sequence_parallel:
-        dq = dq.unsqueeze(0).repeat(num_blocks_n, *([1] * len(q.shape))) # we do repeat instead of expand because we need to write data so views are not enough
-        stride_dq_all = dq.stride()[0]
+        # replicate q for each parallel sequence
+        replicas = num_blocks_n
+        dq_shape = (replicas,) + q.shape
+    else:
+        dq_shape = q.shape
+
+    is_qkvpacked = False
+    if dq is None or dk is None or dv is None:
+        dq = torch.zeros(dq_shape, device=q.device, dtype=q.dtype)
+        dk = torch.empty_like(k)
+        dv = torch.empty_like(v)
+    elif (not dq.is_contiguous()) or (not dq.is_contiguous()) or (not dq.is_contiguous()):
+        if DEBUG:
+            print("Not contigious and setting is packed to True")
+        is_qkvpacked = True
+        dq_og = dq
+        dq = dq.contiguous()
+        dk_og = dk
+        dk = dk.contiguous()
+        dv_og = dv
+        dv = dv.contiguous()
+
+    # NOTE: the kernel does inplace accumlation so dq has to be zeros. This avoids the case where we are passed empty dq and it is not all zeros
+    dq.zero_()

-    # assert contiguous
+    # assert contigious
     assert do.is_contiguous()
     assert q.is_contiguous()
     assert k.is_contiguous()
@@ -798,17 +813,26 @@ def attention_prefill_backward_triton_impl(
         FP8_MAX=FP8_MAX
     )

-    if sequence_parallel:
+    if len(dq.shape) == 5:
         dq = dq.sum(dim=0)

     if DEBUG:
-        print("attention_prefill_backward_triton_impl outputs")
-        print("dv:", dv, dv.shape)
-        print("dk:", dk, dk.shape)
+        print("_bwd_kernel outputs")
         print("dq:", dq, dq.shape)
-        if use_dropout:
-            print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None)
-            print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
-            write_dropout_mask(dropout_mask, "dropout_mask_bwd")
+        print("dk:", dk, dk.shape)
+        print("dv:", dv, dv.shape)
+        print("delta:", delta, delta.shape)
+
+    if is_qkvpacked:
+        if DEBUG:
+            print("Copying back to original tensors due to ispacked")
+
+        # copy back results to og tensors
+        dq_og.copy_(dq)
+        dk_og.copy_(dk)
+        dv_og.copy_(dv)
+        return dq_og, dk_og, dv_og, delta, None, None
+    else:
+        return dq, dk, dv, delta, None, None
+

-    return delta
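The hunks above allocate `dq` with an extra leading dimension when `sequence_parallel` is set, zero it because the kernel accumulates in place, and reduce it with `dq.sum(dim=0)` after the launch. Below is a small PyTorch sketch of that pattern, detached from the Triton kernel; the per-block gradient computation here is a stand-in, not the real kernel.

```
import torch

num_blocks_n = 4                 # hypothetical number of K/V column blocks
q = torch.randn(2, 8, 128, 64)   # (batch, heads, seqlen_q, headdim); shapes illustrative

# One dq replica per column block, zero-initialised because each block accumulates in place.
dq = torch.zeros((num_blocks_n,) + q.shape, dtype=q.dtype)

for block in range(num_blocks_n):
    partial = torch.randn_like(q)  # stand-in for the dq contribution of this K/V block
    dq[block] += partial           # in-place accumulation into this block's replica

# After the launch the replicas are reduced, matching `if len(dq.shape) == 5: dq = dq.sum(dim=0)`.
dq = dq.sum(dim=0)
assert dq.shape == q.shape
```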
