4646from vllm .sampling_params import SamplingParams
4747from vllm .sequence import (IntermediateTensors , SequenceData ,
4848 SequenceGroupMetadata )
49- from vllm .utils import (bind_kv_cache , is_fake_hpu , is_pin_memory_available ,
49+ from vllm .utils import (bind_kv_cache , is_pin_memory_available ,
5050 make_tensor_with_pad )
5151from vllm .worker .model_runner_base import (
5252 ModelRunnerBase , ModelRunnerInputBase ,
@@ -345,8 +345,22 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
345345 mask = mask >= metadata .block_usage .unsqueeze (- 1 )
346346 attn_bias = (torch .zeros_like (mask , dtype = dtype ).masked_fill_ (
347347 mask , - math .inf ))
348- block_mapping = torch .nn .functional .one_hot (metadata .block_groups ,
349- num_classes = batch_size )
348+ if os .environ .get ('VLLM_USE_FAKE_HPU' ,
349+ '0' ) == '0' and htorch .utils .internal .is_lazy ():
350+ block_mapping = torch .nn .functional .one_hot (metadata .block_groups ,
351+ num_classes = batch_size )
352+ else :
353+ # Unfortunately one_hot on CPU/torch.compile mode/eager mode
354+ # doesn't handle out of bounds classes so we need to convert
355+ # all negative values to 0 (block_mapping) or bs (block_groups)
356+ block_groups = metadata .block_groups .to (torch .long )
357+ block_mapping = torch .nn .functional .relu (block_groups )
358+ block_mapping = torch .nn .functional .one_hot (block_mapping ,
359+ num_classes = batch_size )
360+ oob_values = block_groups .lt (0 )
361+ block_mapping .masked_fill_ (oob_values .unsqueeze (- 1 ), 0 )
362+ block_groups .masked_fill_ (oob_values , batch_size )
363+ metadata = metadata ._replace (block_groups = block_groups )
350364 block_mapping = block_mapping .to (dtype )
351365 metadata = metadata ._replace (block_mapping = block_mapping ,
352366 attn_bias = attn_bias )
@@ -365,8 +379,9 @@ def _set_block_scales(self, metadata, device):
365379 def _update_metadata (self , attn_metadata , batch_size , seq_len , device ,
366380 dtype ):
367381 if attn_metadata .is_prompt :
368- attn_metadata = self ._set_attn_bias (attn_metadata , batch_size ,
369- seq_len , device , dtype )
382+ meta = attn_metadata
383+ attn_metadata = self ._set_attn_bias (meta , batch_size , seq_len ,
384+ device , dtype )
370385 else :
371386 meta = attn_metadata
372387 attn_metadata = self ._set_block_mapping (meta , batch_size , device ,
@@ -925,11 +940,6 @@ def _prepare_prompt(
925940
926941 block_indices , block_offsets = precompute_indices_and_offsets (
927942 self .block_size , slot_mapping , True )
928- context_lens_tensor = torch .tensor (context_lens ,
929- dtype = torch .long ,
930- device = 'cpu' )
931- context_lens_tensor = context_lens_tensor .to (self .device ,
932- non_blocking = True )
933943 attn_metadata = self .attn_backend .make_metadata (
934944 is_prompt = True ,
935945 block_list = None ,
@@ -941,7 +951,6 @@ def _prepare_prompt(
941951 block_groups = None ,
942952 attn_bias = None ,
943953 seq_lens_tensor = seq_lens_tensor ,
944- context_lens_tensor = context_lens_tensor ,
945954 num_prefills = real_num_seqs ,
946955 num_prefill_tokens = sum_query_len ,
947956 num_decode_tokens = 0 ,
@@ -967,7 +976,6 @@ def _prepare_prompt(
967976 def _prepare_decode (
968977 self ,
969978 seq_group_metadata_list : List [SequenceGroupMetadata ],
970- output = None ,
971979 ) -> PrepareDecodeMetadata :
972980 input_tokens : List [List [int ]] = []
973981 input_positions : List [List [int ]] = []
@@ -998,9 +1006,8 @@ def _prepare_decode(
9981006
9991007 for seq_id in seq_ids :
10001008 seq_data = seq_group_metadata .seq_data [seq_id ]
1001- if output is None :
1002- generation_token = seq_data .get_last_token_id ()
1003- input_tokens .append ([generation_token ])
1009+ generation_token = seq_data .get_last_token_id ()
1010+ input_tokens .append ([generation_token ])
10041011
10051012 seq_len = seq_data .get_len ()
10061013 position = seq_len - 1
@@ -1011,9 +1018,6 @@ def _prepare_decode(
10111018 seq_lens .append (seq_len )
10121019
10131020 block_table = seq_group_metadata .block_tables [seq_id ]
1014- num_fully_occupied_blocks = position // self .block_size
1015- block_table = block_table [:num_fully_occupied_blocks + 1 ]
1016-
10171021 if len (block_table ) == 0 :
10181022 block_number = _PAD_BLOCK_ID
10191023 else :
@@ -1033,14 +1037,9 @@ def _prepare_decode(
10331037 block_table = block_table [- sliding_window_blocks :]
10341038 block_tables .append (block_table )
10351039
1036- if output is None :
1037- input_tokens = torch .tensor (input_tokens ,
1038- dtype = torch .long ,
1039- device = self .device )
1040- else :
1041- real_batch_size = len (seq_group_metadata_list )
1042- input_tokens = output [:real_batch_size ]
1043-
1040+ input_tokens = torch .tensor (input_tokens ,
1041+ dtype = torch .long ,
1042+ device = self .device )
10441043 input_positions = torch .tensor (input_positions ,
10451044 dtype = torch .long ,
10461045 device = self .device )
@@ -1112,7 +1111,6 @@ def _prepare_decode(
11121111 block_groups = block_groups ,
11131112 attn_bias = None ,
11141113 seq_lens_tensor = None ,
1115- context_lens_tensor = None ,
11161114 num_prefills = 0 ,
11171115 num_prefill_tokens = 0 ,
11181116 num_decode_tokens = num_decode_tokens ,
0 commit comments