
Commit 00dcb05

rebase main
Signed-off-by: zhenwei <[email protected]>
1 parent: c3b1ed8

File tree

3 files changed (+25 additions, -33 deletions)

vllm/attention/backends/hpu_attn.py

Lines changed: 0 additions & 1 deletion
@@ -76,7 +76,6 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     is_prompt: bool
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
-    context_lens_tensor: Optional[torch.Tensor]


 class HPUAttentionImpl(AttentionImpl, torch.nn.Module):

vllm/utils.py

Lines changed: 0 additions & 5 deletions
@@ -352,11 +352,6 @@ def reset(self):
         self._index = 0


-@cache
-def is_fake_hpu() -> bool:
-    return os.environ.get('VLLM_USE_FAKE_HPU', '0') != '0'
-
-
 @cache
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""

vllm/worker/hpu_model_runner.py

Lines changed: 25 additions & 27 deletions
@@ -46,7 +46,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import (bind_kv_cache, is_fake_hpu, is_pin_memory_available,
+from vllm.utils import (bind_kv_cache, is_pin_memory_available,
                         make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
@@ -345,8 +345,22 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
         mask = mask >= metadata.block_usage.unsqueeze(-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
-        block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
-                                                    num_classes=batch_size)
+        if os.environ.get('VLLM_USE_FAKE_HPU',
+                          '0') == '0' and htorch.utils.internal.is_lazy():
+            block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
+                                                        num_classes=batch_size)
+        else:
+            # Unfortunately one_hot on CPU/torch.compile mode/eager mode
+            # doesn't handle out of bounds classes so we need to convert
+            # all negative values to 0 (block_mapping) or bs (block_groups)
+            block_groups = metadata.block_groups.to(torch.long)
+            block_mapping = torch.nn.functional.relu(block_groups)
+            block_mapping = torch.nn.functional.one_hot(block_mapping,
+                                                        num_classes=batch_size)
+            oob_values = block_groups.lt(0)
+            block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
+            block_groups.masked_fill_(oob_values, batch_size)
+            metadata = metadata._replace(block_groups=block_groups)
         block_mapping = block_mapping.to(dtype)
         metadata = metadata._replace(block_mapping=block_mapping,
                                      attn_bias=attn_bias)
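
The new else branch exists because torch.nn.functional.one_hot rejects negative class indices on CPU, in eager mode, and under torch.compile, while negative block-group entries are used as padding. A minimal self-contained sketch of the same workaround, with made-up tensor values for illustration:

import torch

# Out-of-bounds-safe one_hot: negative block-group entries mark padding.
batch_size = 4
block_groups = torch.tensor([0, 2, -1, 3, -1], dtype=torch.long)

# Clamp negatives to 0 so one_hot accepts them, then zero out those rows
# and redirect the padded groups to index `batch_size` instead.
block_mapping = torch.nn.functional.relu(block_groups)
block_mapping = torch.nn.functional.one_hot(block_mapping,
                                            num_classes=batch_size)
oob = block_groups.lt(0)
block_mapping = block_mapping.masked_fill(oob.unsqueeze(-1), 0)
block_groups = block_groups.masked_fill(oob, batch_size)

print(block_mapping)   # rows 2 and 4 (the padded entries) are all zeros
print(block_groups)    # tensor([0, 2, 4, 3, 4])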
@@ -365,8 +379,9 @@ def _set_block_scales(self, metadata, device):
     def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                          dtype):
         if attn_metadata.is_prompt:
-            attn_metadata = self._set_attn_bias(attn_metadata, batch_size,
-                                                seq_len, device, dtype)
+            meta = attn_metadata
+            attn_metadata = self._set_attn_bias(meta, batch_size, seq_len,
+                                                device, dtype)
         else:
             meta = attn_metadata
             attn_metadata = self._set_block_mapping(meta, batch_size, device,
@@ -925,11 +940,6 @@ def _prepare_prompt(

         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, True)
-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.long,
-                                           device='cpu')
-        context_lens_tensor = context_lens_tensor.to(self.device,
-                                                     non_blocking=True)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
             block_list=None,
@@ -941,7 +951,6 @@ def _prepare_prompt(
             block_groups=None,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
-            context_lens_tensor=context_lens_tensor,
             num_prefills=real_num_seqs,
             num_prefill_tokens=sum_query_len,
             num_decode_tokens=0,
@@ -967,7 +976,6 @@ def _prepare_prompt(
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-        output=None,
     ) -> PrepareDecodeMetadata:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -998,9 +1006,8 @@ def _prepare_decode(

             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
-                if output is None:
-                    generation_token = seq_data.get_last_token_id()
-                    input_tokens.append([generation_token])
+                generation_token = seq_data.get_last_token_id()
+                input_tokens.append([generation_token])

                 seq_len = seq_data.get_len()
                 position = seq_len - 1
@@ -1011,9 +1018,6 @@ def _prepare_decode(
                 seq_lens.append(seq_len)

                 block_table = seq_group_metadata.block_tables[seq_id]
-                num_fully_occupied_blocks = position // self.block_size
-                block_table = block_table[:num_fully_occupied_blocks + 1]
-
                 if len(block_table) == 0:
                     block_number = _PAD_BLOCK_ID
                 else:
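
For context, the deleted lines trimmed each block table down to the blocks the current position can actually touch; a tiny worked sketch of that arithmetic with made-up values (the decode path now keeps the table untrimmed, apart from the sliding-window cut further below):

# What the removed trimming computed: with block_size=16 and position=17
# (the 18th token), 17 // 16 = 1 fully occupied block, so only the first
# two table entries were kept.
block_size = 16
position = 17
block_table = [8, 3, 5, 9]
num_fully_occupied_blocks = position // block_size      # 1
trimmed = block_table[:num_fully_occupied_blocks + 1]   # [8, 3]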
@@ -1033,14 +1037,9 @@ def _prepare_decode(
                     block_table = block_table[-sliding_window_blocks:]
                 block_tables.append(block_table)

-        if output is None:
-            input_tokens = torch.tensor(input_tokens,
-                                        dtype=torch.long,
-                                        device=self.device)
-        else:
-            real_batch_size = len(seq_group_metadata_list)
-            input_tokens = output[:real_batch_size]
-
+        input_tokens = torch.tensor(input_tokens,
+                                    dtype=torch.long,
+                                    device=self.device)
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.long,
                                        device=self.device)
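
With the output fast path removed here and in the loop above, decode input ids are always gathered from each sequence's last generated token and materialized as a fresh tensor. A minimal sketch under that reading (token values and device handling are made up):

import torch

# Decode inputs are built from the last token of every sequence, one token
# per row, instead of reusing a previous sampler-output tensor.
last_token_ids = [11, 42, 7]
input_tokens = torch.tensor([[t] for t in last_token_ids], dtype=torch.long)
print(input_tokens.shape)  # torch.Size([3, 1])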
@@ -1112,7 +1111,6 @@ def _prepare_decode(
             block_groups=block_groups,
             attn_bias=None,
             seq_lens_tensor=None,
-            context_lens_tensor=None,
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=num_decode_tokens,
