 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -416,8 +416,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             num_scheduled_tokens_per_req)

         # Do the padding and copy the tensors to the TPU.
-        padded_total_num_scheduled_tokens = _get_padded_number(
-            total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
+        padded_total_num_scheduled_tokens = _get_padded_token_len(
+            total_num_scheduled_tokens)
         self.input_ids = self.input_ids_cpu[:
                                             padded_total_num_scheduled_tokens].to(
                                                 self.device)
@@ -428,23 +428,22 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         slot_mapping = self.slot_mapping_cpu[:
                                              padded_total_num_scheduled_tokens].to(
                                                  self.device)
-        padded_block_table = self.block_table_cpu[:
-                                                  padded_total_num_scheduled_tokens]
-        padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = (
+        block_tables = self.block_table_cpu[:self.max_num_reqs]
+        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
             self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
-        padded_block_table = padded_block_table.to(self.device)
-        query_start_loc = self.query_start_loc_cpu[:
-                                                   padded_total_num_scheduled_tokens
-                                                   + 1].to(self.device)
-        seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(
+        block_tables = block_tables.to(self.device)
+        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
             self.device)
+        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
-            block_tables=padded_block_table,
+            block_tables=block_tables,
             context_lens=seq_lens,
             query_start_loc=query_start_loc,
-            num_seqs=num_reqs,
+            num_seqs=torch.tensor([num_reqs],
+                                  dtype=torch.int32,
+                                  device=self.device),
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
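
Note: with this hunk, the per-request metadata tensors are padded to a fixed self.max_num_reqs rows instead of to the padded token count, and num_seqs is passed as a one-element int32 tensor on the device like the other PallasMetadata fields. A minimal standalone sketch of that shape layout, with placeholder values standing in for the runner's attributes:

import torch

max_num_reqs = 8              # placeholder for self.max_num_reqs (fixed at init)
max_num_blocks_per_req = 4    # placeholder for self.max_num_blocks_per_req
num_reqs = 3                  # requests actually scheduled this step

# Per-request tensors always have max_num_reqs rows; only the first
# num_reqs rows carry real data, the remainder is zero padding.
block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req), dtype=torch.int32)
query_start_loc = torch.zeros(max_num_reqs + 1, dtype=torch.int32)
seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32)
num_seqs = torch.tensor([num_reqs], dtype=torch.int32)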
@@ -693,29 +692,34 @@ def _dummy_run(
                                     dtype=torch.int32,
                                     device=self.device)
             inputs_embeds = None
+        actual_num_reqs = min(num_tokens, self.max_num_reqs)
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
         slot_mapping = torch.zeros(num_tokens,
                                    dtype=torch.int64,
                                    device=self.device)
-        block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]),
-                                   dtype=torch.int32,
-                                   device=self.device)
-        query_lens = [1] * num_tokens
+        block_tables = torch.zeros(
+            (self.max_num_reqs, self.block_table_cpu.shape[1]),
+            dtype=torch.int32,
+            device=self.device)
+        query_lens = [1] * self.max_num_reqs
         query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                     dtype=torch.int32),
                                        dim=0,
                                        dtype=torch.int32).to(self.device)
-        context_lens = torch.ones((num_tokens, ),
+        context_lens = torch.ones((self.max_num_reqs, ),
                                   dtype=torch.int32,
                                   device=self.device)
+        num_seqs = torch.tensor([actual_num_reqs],
+                                dtype=torch.int32,
+                                device=self.device)
         attn_metadata = PallasMetadata(
             slot_mapping=slot_mapping,
             block_tables=block_tables,
             context_lens=context_lens,
             query_start_loc=query_start_loc,
-            num_seqs=num_tokens,
+            num_seqs=num_seqs,
         )

         if self.is_multimodal_model:
@@ -724,9 +728,6 @@ def _dummy_run(
             torch._dynamo.mark_dynamic(input_ids, 0)
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
-        torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)

         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
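
Note: removing the three mark_dynamic calls appears to follow from the padding change above: block_tables, query_start_loc, and context_lens are now sized by the fixed max_num_reqs, so only the token-indexed tensors still vary between compiled graphs. A small illustrative sketch (values are placeholders, not the runner's real configuration):

import torch

num_tokens = 32    # padded token count; still varies across compiled buckets
max_num_reqs = 8   # fixed at init, so tensors sized by it have static shapes

input_ids = torch.zeros(num_tokens, dtype=torch.int32)
block_tables = torch.zeros((max_num_reqs, 4), dtype=torch.int32)

# Only tensors whose leading dim is the token count keep a dynamic dim 0.
torch._dynamo.mark_dynamic(input_ids, 0)
# block_tables now has a static shape and no longer needs to be marked.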
@@ -817,28 +818,6 @@ def forward(
             inputs_embeds: The input embeddings of shape [num_tokens,
                 hidden_size]. It is used for multimodal models.
         """
-        # Skip this in memory profiling at initialization.
-        if kv_caches[0][0].numel() > 0:
-            attn_metadata = get_forward_context().attn_metadata
-            # index_copy_(slot_mapping) only works when the inserted dimension
-            # is 0. However, the KV cache in the Pallas backend has the shape
-            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
-            # work, we need to flatten the first three dimensions and modify
-            # the slot_mapping accordingly.
-            # kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
-            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
-            slot_mapping = attn_metadata.slot_mapping
-            slot_mapping = slot_mapping.flatten()
-            head_indicies = torch.arange(0,
-                                         num_kv_heads,
-                                         device=slot_mapping.device,
-                                         dtype=slot_mapping.dtype)
-            head_indicies *= block_size * num_blocks
-            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-                -1, num_kv_heads)
-            slot_mapping = slot_mapping + head_indicies.view(1, -1)
-            slot_mapping = slot_mapping.flatten()
-            attn_metadata.slot_mapping = slot_mapping

         assert self.model is not None
         hidden_states = self.model(
@@ -866,3 +845,9 @@ def get_input_embeddings(self, *args, **kwargs):

 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
+
+
+def _get_padded_token_len(x: int) -> int:
+    if x <= 16:
+        return 16
+    return 1 << (x - 1).bit_length()
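
For reference, the new helper rounds the scheduled-token count up to the next power of two, with a floor of 16 (presumably the smallest padded size used by the compiled graphs); a few sanity checks, runnable as-is next to the definition above:

assert _get_padded_token_len(1) == 16     # small counts pad up to the 16-token floor
assert _get_padded_token_len(16) == 16
assert _get_padded_token_len(17) == 32    # otherwise round up to the next power of two
assert _get_padded_token_len(100) == 128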