This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit ce1670b

bs/seq bucketing for prompt and decode (#33)
* Bucketing/Warmup WIP
* Cleanup
* Revert "Fix model_output_idx on HPU (#27)". This reverts commit 90dfa92.
* Rework selected_token_indices fix to also work with block_size padding
* Simple prompt attention POC
* Remove cumsum
* MQA/GQA support for simple prompt_attention
* Cleanup
* Fix typo
* Restore profiling runs
1 parent 2664659 commit ce1670b

File tree

5 files changed: +225, -763 lines

vllm/attention/backends/habana_attn.py

Lines changed: 22 additions & 35 deletions
@@ -2,14 +2,13 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################

-import importlib
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Type

 import torch
+import math
 import vllm.hpu.xops as xops
 from vllm.hpu.attn_bias import (AttentionBias,
-                                BlockDiagonalCausalMask,
                                 LowerTriangularMaskWithTensorBias)

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -18,7 +17,6 @@
 from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention,
                                                   HabanaPagedAttentionMetadata)
 from vllm.logger import init_logger
-from vllm.utils import is_hip

 logger = init_logger(__name__)

@@ -119,11 +117,11 @@ def __post_init__(self):
 class HabanaAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
-    |<--------------- num_prefill_tokens ----------------->|
+    |<--------------- num_prefill_tokens ----------------->|
     |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

-    Otherwise, the layout is as follows:
-    |<----------------- num_decode_tokens ------------------>|
+    Otherwise, the layout is as follows:
+    |<----------------- num_decode_tokens ------------------>|
     |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

     Generation tokens can contain padding when cuda-graph is used.
@@ -196,48 +194,37 @@ def forward(
         HabanaPagedAttention.write_to_paged_cache(key, value, key_cache,
                                                   value_cache,
                                                   attn_metadata.slot_mapping,
-                                                  attn_metadata.kv_cache_dtype,
+                                                  attn_metadata.kv_cache_dtype,
                                                   attn_metadata.prefill_metadata is not None)

         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
-                # normal attention.
-                # block tables are empty if the prompt does not have a cached
-                # prefix.
-                if self.num_kv_heads != self.num_heads:
-                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-                    # project the key and value tensors to the desired number of
-                    # heads.
-                    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
-                    query = query.view(query.shape[0], self.num_kv_heads,
-                                       self.num_queries_per_kv,
-                                       query.shape[-1])
-                    key = key[:, :,
-                              None, :].expand(key.shape[0], self.num_kv_heads,
-                                              self.num_queries_per_kv,
-                                              key.shape[-1])
-                    value = value[:, :,
-                                  None, :].expand(value.shape[0],
-                                                  self.num_kv_heads,
-                                                  self.num_queries_per_kv,
-                                                  value.shape[-1])
-
+                # TODO: move this outside of model
                 if prefill_meta.attn_bias is None:
                     if self.alibi_slopes is None:
-                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                            [seq_len] * batch_size)
+                        lens = torch.tensor(attn_metadata.prefill_metadata.seq_lens, device=query.device, dtype=torch.int32)
+                        len_mask = (torch.arange(0, seq_len, device=query.device, dtype=torch.int32)
+                                    .view(1, seq_len)
+                                    .ge(lens.unsqueeze(-1))
+                                    .view(batch_size, 1, 1, seq_len))
+                        causal_mask = torch.triu(
+                            torch.ones((batch_size, 1, seq_len, seq_len), device=query.device, dtype=torch.bool),
+                            diagonal=1
+                        )
+                        mask = causal_mask.logical_or(len_mask)
+                        attn_bias = (torch.zeros_like(mask, dtype=query.dtype)
+                                     .masked_fill_(mask, -math.inf))
                         if self.sliding_window is not None:
-                            attn_bias = attn_bias.make_local_attention(
-                                self.sliding_window)
+                            raise NotImplementedError("Sliding window is not supported on HPU")
                         prefill_meta.attn_bias = attn_bias
                     else:
                         prefill_meta.attn_bias = _make_alibi_bias(
                             self.alibi_slopes, self.num_kv_heads, batch_size,
                             seq_len, query.dtype)
-                query_shape = (batch_size, seq_len, self.num_kv_heads, self.num_queries_per_kv, self.head_size) if self.num_kv_heads != self.num_heads else (batch_size, seq_len, self.num_heads, self.head_size)
-                kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.num_queries_per_kv, self.head_size) if self.num_kv_heads != self.num_heads else (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
-                out = xops.memory_efficient_attention_forward(
+                query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
+                kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size)
+                out = xops.prompt_attention(
                     query.view(query_shape),
                     key.view(kv_shape),
                     value.view(kv_shape),
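
The new attn_bias replaces xformers' BlockDiagonalCausalMask: prompts in a bucket are right-padded to a common seq_len, so the bias combines a causal mask with a per-sequence length mask. A minimal standalone sketch of that construction (not part of the diff) follows, using assumed batch size, bucket length, and prompt lengths, and building on CPU rather than on the query's device.

# Standalone sketch of the mask construction used above; sizes are assumed.
import math
import torch

batch_size, seq_len = 2, 6          # assumed bucket sizes
seq_lens = [4, 6]                   # assumed real (unpadded) prompt lengths
device, dtype = "cpu", torch.float32

lens = torch.tensor(seq_lens, device=device, dtype=torch.int32)
# True where the key position lies beyond the real prompt length (right padding).
len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32)
            .view(1, seq_len)
            .ge(lens.unsqueeze(-1))
            .view(batch_size, 1, 1, seq_len))
# True strictly above the diagonal, i.e. for future positions.
causal_mask = torch.triu(
    torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool),
    diagonal=1)
# A position is masked if it is either in the future or in the padding.
mask = causal_mask.logical_or(len_mask)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
print(attn_bias.shape)  # torch.Size([2, 1, 6, 6])

Padded key positions and future positions both receive -inf, so softmax assigns them zero attention weight.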

vllm/hpu/xops.py

Lines changed: 30 additions & 55 deletions
@@ -5,62 +5,37 @@
 # LICENSE file in the root directory of this source tree.
 ###############################################################################

-import habana_frameworks.torch as htorch
 import torch
-import torch.nn.functional as F
-from typing import List, Optional, Tuple, Union
-from .attn_bias import AttentionBias, BlockDiagonalCausalMask
+from typing import Optional

-try:
-    from habana_frameworks.torch.hpex.kernels import FusedSDPA
-except ImportError:
-    print("Not using HPU fused scaled dot-product attention kernel.")
-    FusedSDPA = None
+import vllm.hpu.utils

-def memory_efficient_attention_forward(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attn_bias: Optional[torch.Tensor] = None,
-    p: float = 0.0,
-    scale: Optional[float] = None,
-) -> torch.Tensor:
-    assert attn_bias is not None, "Attention mask is required for prompt processing"
-    dim = query.dim()
-    is_causal = isinstance(attn_bias, BlockDiagonalCausalMask)
-    if FusedSDPA and (is_causal or attn_bias is None):
-        bs = query.shape[0]
-        seq_len_q = query.shape[1]
-        seq_len_kv = key.shape[1]
-        heads = query.shape[-2] if dim != 5 else query.shape[-3]
-        attn_groups = 1 if dim != 5 else query.shape[-2]
-        head_dim = query.shape[-1]
-        if dim == 4:
-            # [bs, seq_len, 1, heads, head_dim] -> [bs, heads, seq_len, head_dim]
-            query = query.reshape(bs, seq_len_q, heads, head_dim).permute(0, 2, 1, 3)
-            key = key.reshape(bs, seq_len_kv, heads, head_dim).permute(0, 2, 1, 3)
-            value = value.reshape(bs, seq_len_kv, heads, head_dim).permute(0, 2, 1, 3)
-        elif dim == 5:
-            # [bs, seq_len, heads, attn_groups, head_dim] -> [bs, heads, attn_groups, seq_len, head_dim]
-            query = query.reshape(bs, seq_len_q, heads, attn_groups, head_dim).permute(0, 2, 3, 1, 4)
-            key = key.reshape(bs, seq_len_kv, heads, attn_groups, head_dim).permute(0, 2, 3, 1, 4)
-            value = value.reshape(bs, seq_len_kv, heads, attn_groups, head_dim).permute(0, 2, 3, 1, 4)
-        else:
-            raise ValueError(f"Unsupported attention dimension: {dim}")
-
-        import habana_frameworks.torch.hpu as ht
-        with ht.sdp_kernel(enable_recompute=False):  # (flash_attention_recompute and q_len == 1)):
-            out = FusedSDPA.apply(
-                query, key, value, None, p, is_causal, scale
-            )
-        htorch.core.mark_step()
-        if dim == 4:
-            # [bs, heads, seq_len, head_dim] -> [bs, seq_len, heads, head_dim]
-            out = out.permute(0, 2, 1, 3).reshape(bs, seq_len_q, heads, head_dim)
-        elif dim == 5:
-            # [bs, heads, attn_groups, seq_len, head_dim] -> [bs, seq_len, heads, attn_groups, head_dim]
-            out = out.permute(0, 3, 1, 2, 4).reshape(bs, seq_len_q, heads, attn_groups, head_dim)
-    else:
-        raise NotImplementedError(f'Only FusedSDPA causal or non-masked attention is supported.\nFusedSDPA support: {FusedSDPA is not None}\nis_causal: {is_causal}\nmask_present: {attn_bias is not None}')

-    return out
+@vllm.hpu.utils.with_mark_steps
+def prompt_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor] = None,
+    p: float = 0.0,
+    scale: Optional[float] = None,
+) -> torch.Tensor:
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+    query_heads = query.size(1)
+    kv_heads = key.size(1)
+    if query_heads != kv_heads:
+        query = query.unflatten(1, (kv_heads, -1))
+        key = key.unflatten(1, (kv_heads, 1))
+        value = value.unflatten(1, (kv_heads, 1))
+        attn_bias = attn_bias.unsqueeze(2)
+    attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
+    if attn_bias is not None:
+        attn_weights.add_(attn_bias)
+    attn_weights = torch.softmax(attn_weights, dim=-1)
+    attn_weights = torch.matmul(attn_weights, value)
+    if query_heads != kv_heads:
+        attn_weights = attn_weights.flatten(1, 2)
+    attn_weights = attn_weights.transpose(1, 2)
+    return attn_weights
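
For reference, a hedged usage sketch of the new kernel: it assumes an HPU build where vllm.hpu.xops and its with_mark_steps decorator import cleanly, and the sizes below are made up. Shapes mirror the call site in habana_attn.py: query is [bs, seq_len, num_heads, head_size], key/value are [bs, seq_len_kv, num_kv_heads, head_size], and attn_bias broadcasts as [bs, 1, seq_len, seq_len]. With num_heads != num_kv_heads, the MQA/GQA branch groups query heads over KV heads via unflatten and broadcasts key, value, and the bias across each group.

# Hypothetical usage sketch; only runs where vllm.hpu.xops is importable (HPU build).
import torch
import vllm.hpu.xops as xops

bs, seq_len, num_heads, num_kv_heads, head_size = 2, 6, 8, 2, 64  # assumed sizes
query = torch.randn(bs, seq_len, num_heads, head_size)
key = torch.randn(bs, seq_len, num_kv_heads, head_size)
value = torch.randn(bs, seq_len, num_kv_heads, head_size)
attn_bias = torch.zeros(bs, 1, seq_len, seq_len)  # e.g. the mask built in habana_attn.py

out = xops.prompt_attention(query, key, value,
                            attn_bias=attn_bias,
                            p=0.0,                     # accepted but unused: no dropout is applied
                            scale=head_size ** -0.5)   # scale must be passed; the body multiplies by it
print(out.shape)  # expected: [bs, seq_len, num_heads, head_size]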

vllm/model_executor/sampling_metadata.py

Lines changed: 0 additions & 9 deletions
@@ -192,12 +192,6 @@ def _prepare_seq_groups(
     # Total number of prompts from given sequence groups.
     num_prompts = 0

-    # FIXME: On HPU prompts are right-padded. We need to take that into account
-    # when updating model_output_idx
-    if is_hpu() and len(seq_lens) > 0:
-        assert seq_lens == query_lens, 'Prompt chunking is not yet supported on HPU!'
-        max_seq_len = max(seq_lens)
-
     for i, seq_group_metadata in enumerate(seq_group_metadata_list):
         seq_ids = list(seq_group_metadata.seq_data.keys())
         sampling_params = seq_group_metadata.sampling_params
@@ -225,12 +219,10 @@
             prompt_logprob_len = (query_len - num_prefill_sample
                                   if do_sample else query_len)
             sample_len = num_prefill_sample if do_sample else 0
-            padding_len = 0 if not is_hpu() else max_seq_len - seq_len
         else:
             # Decode
             prompt_logprob_len = 0
             sample_len = len(seq_ids) if do_sample else 0
-            padding_len = 0

         # Update indices to select from the model output.
         """
@@ -249,7 +241,6 @@
         selected_token_indices.extend(
             range(model_output_idx, model_output_idx + sample_len))
         model_output_idx += sample_len
-        model_output_idx += padding_len

         # We now find indices for logprob computation and sampling.
         """
