Commit 0ab278c

[Core] Remove unnecessary copies in flash attn backend (#5138)
Parent: 7a64d24

File tree: 2 files changed, 8 additions & 7 deletions

requirements-cuda.txt

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@ ray >= 2.9
 nvidia-ml-py # for pynvml package
 torch == 2.3.0
 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
-vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0
+vllm-flash-attn == 2.5.9 # Requires PyTorch 2.3.0
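(The pin bump and the code change travel together: the `out=` argument passed to `flash_attn_varlen_func` and `flash_attn_with_kvcache` below is presumably what vllm-flash-attn 2.5.9 adds over 2.5.8.post2.)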

vllm/attention/backends/flash_attn.py

Lines changed: 7 additions & 6 deletions
@@ -317,7 +317,7 @@ def forward(
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -329,14 +329,13 @@ def forward(
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
+                    out=output[:num_prefill_tokens],
                 )
-                assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:num_prefill_tokens] = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=query,
                     k=key_cache,
                     v=value_cache,
@@ -348,11 +347,12 @@ def forward(
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
                     block_table=prefill_meta.block_tables,
+                    out=output[:num_prefill_tokens],
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = flash_attn_with_kvcache(
+            flash_attn_with_kvcache(
                 decode_query.unsqueeze(1),
                 key_cache,
                 value_cache,
@@ -361,7 +361,8 @@ def forward(
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
-            ).squeeze(1)
+                out=output[num_prefill_tokens:].unsqueeze(1),
+            )
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
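The pattern is the same at all three call sites: instead of letting the attention kernel allocate a fresh result tensor and then copying it into a slice of `output`, the preallocated slice is passed as `out=` so the kernel writes into it directly. A minimal sketch of the idea, using `torch.matmul` as a stand-in for the attention kernel (illustrative only, not vLLM code):

import torch

# Preallocated buffer covering prefill and decode tokens, mirroring
# the `output` tensor in the forward pass above (sizes are made up).
num_prefill_tokens, num_decode_tokens, hidden_size = 4, 2, 8
output = torch.empty(num_prefill_tokens + num_decode_tokens, hidden_size)

q = torch.randn(num_prefill_tokens, hidden_size)
w = torch.randn(hidden_size, hidden_size)

# Before this commit: the kernel allocates its own result tensor,
# which is then copied into the buffer -- an extra allocation plus
# an extra copy on every forward pass.
out = torch.matmul(q, w)                # stand-in for flash_attn_varlen_func
output[:num_prefill_tokens] = out       # the copy this commit removes

# After this commit: a view of the buffer is passed as `out=`, so the
# kernel writes its result in place and no intermediate tensor exists.
torch.matmul(q, w, out=output[:num_prefill_tokens])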

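One subtlety in the decode hunk: `output[num_prefill_tokens:].unsqueeze(1)` is a view, not a copy, so when `flash_attn_with_kvcache` writes into it the data lands in `output` itself; that is what lets the new code drop the old `.squeeze(1)` and assignment. A quick check of that view behavior (plain PyTorch, not vLLM code):

import torch

buf = torch.zeros(6, 8)
tail = buf[4:].unsqueeze(1)          # shape (2, 1, 8); shares storage with buf
assert tail.data_ptr() == buf[4:].data_ptr()

tail.copy_(torch.ones(2, 1, 8))      # write through the view ...
assert bool(buf[4:].eq(1).all())     # ... and the buffer itself is filled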