Skip to content

Commit 7f58ad1

Browse files
Add support for various softmax normalization options (vllm-project#420)
Supporting PR for HabanaAI/vllm-hpu-extension#14
1 parent 892c090 commit 7f58ad1

4 files changed

Lines changed: 11 additions & 2 deletions

File tree

requirements-hpu.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ pandas
88
tabulate
99
setuptools>=61
1010
setuptools-scm>=8
11-
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@fd7f2e6
11+
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@c2801bb

vllm/attention/backends/hpu_attn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def forward(
223223
block_mapping=attn_metadata.block_mapping,
224224
block_bias=attn_metadata.attn_bias,
225225
block_scales=attn_metadata.block_scales,
226+
block_groups=attn_metadata.block_groups,
226227
scale=self.scale,
227228
matmul_qk_op=self.matmul_qk,
228229
matmul_av_op=self.matmul_av,

vllm/attention/ops/hpu_paged_attn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class HPUPagedAttentionMetadata:
2121
block_indices: Optional[torch.Tensor]
2222
block_offsets: Optional[torch.Tensor]
2323
block_scales: Optional[torch.Tensor]
24+
block_groups: Optional[torch.Tensor]
2425

2526

2627
class HPUPagedAttention:

vllm/worker/hpu_model_runner.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,7 @@ def _prepare_prompt(
907907
block_indices=block_indices,
908908
block_offsets=block_offsets,
909909
block_scales=None,
910+
block_groups=None,
910911
attn_bias=None,
911912
seq_lens_tensor=seq_lens_tensor,
912913
num_prefills=real_num_seqs,
@@ -1028,6 +1029,8 @@ def _prepare_decode(
10281029
len(block_list),
10291030
self.bucketing_global_state.decode_block_bucket_cfg)
10301031
block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
1032+
block_groups = pad_list(block_mapping, block_bucket_size,
1033+
len(block_tables))
10311034
block_mapping = pad_list(block_mapping, block_bucket_size, -1)
10321035
block_usage = pad_list(block_usage, block_bucket_size, 1)
10331036
block_scales = pad_list(block_scales, block_bucket_size, 0.0)
@@ -1038,6 +1041,9 @@ def _prepare_decode(
10381041
block_mapping = torch.tensor(block_mapping,
10391042
dtype=torch.long,
10401043
device=self.device)
1044+
block_groups = torch.tensor(block_groups,
1045+
dtype=torch.long,
1046+
device=self.device)
10411047
block_usage = torch.tensor(block_usage,
10421048
dtype=self.model_config.dtype,
10431049
device=self.device)
@@ -1060,6 +1066,7 @@ def _prepare_decode(
10601066
block_indices=block_indices,
10611067
block_offsets=block_offsets,
10621068
block_scales=block_scales,
1069+
block_groups=block_groups,
10631070
attn_bias=None,
10641071
seq_lens_tensor=None,
10651072
num_prefills=0,
@@ -1271,7 +1278,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
12711278
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
12721279
'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
12731280
'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
1274-
'block_offsets', 'block_scales'
1281+
'block_offsets', 'block_scales', 'block_groups'
12751282
])
12761283
return attention_metadata
12771284

0 commit comments

Comments
 (0)