22import weakref
33from collections import defaultdict
44from dataclasses import dataclass
5- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Type , Union
5+ from typing import (TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Type ,
6+ TypeVar , Union )
67
78import torch
89from torch import nn
3132
3233logger = init_logger (__name__ )
3334
35+ TModelInputForCPU = TypeVar ('TModelInputForCPU' , bound = "ModelInputForCPU" )
3436_PAD_SLOT_ID = - 1
3537
3638
@@ -60,10 +62,10 @@ def as_broadcastable_tensor_dict(
6062
6163 @classmethod
6264 def from_broadcasted_tensor_dict (
63- cls : Type ["ModelInputForCPU" ],
65+ cls : Type [TModelInputForCPU ],
6466 tensor_dict : Dict [str , Any ],
6567 attn_backend : Optional ["AttentionBackend" ] = None
66- ) -> "ModelInputForCPU" :
68+ ) -> TModelInputForCPU :
6769 if attn_backend is not None :
6870 tensor_dict = _init_attn_metadata_from_tensor_dict (
6971 attn_backend , tensor_dict )
@@ -255,11 +257,14 @@ def _prepare_prompt(
255257 slot_mapping .append (_PAD_SLOT_ID )
256258 continue
257259
258- block_number = block_table [i //
259- self .block_size ] # type: ignore
260- block_offset = i % self .block_size # type: ignore
261- slot = block_number * self .block_size + block_offset
262- slot_mapping .append (slot )
260+ # For encoder-only models, the block_table is None,
261+ # and there is no need to initialize the slot_mapping.
262+ if block_table is not None :
263+ block_number = block_table [i //
264+ self .block_size ] # type: ignore
265+ block_offset = i % self .block_size # type: ignore
266+ slot = block_number * self .block_size + block_offset
267+ slot_mapping .append (slot )
263268
264269 if any (input_mrope_positions ):
265270 input_positions = None # type: ignore
@@ -402,10 +407,12 @@ def _prepare_decode(
402407 )
403408
404409
405- class CPUModelRunner (ModelRunnerBase [ModelInputForCPU ]):
406- _model_input_cls : Type [ModelInputForCPUWithSamplingMetadata ] = (
407- ModelInputForCPUWithSamplingMetadata )
408- _builder_cls : Type [ModelInputForCPUBuilder ] = ModelInputForCPUBuilder
410+ class CPUModelRunnerBase (ModelRunnerBase [TModelInputForCPU ]):
411+ """
412+ Helper class for shared methods between CPU model runners.
413+ """
414+ _model_input_cls : Type [TModelInputForCPU ]
415+ _builder_cls : Type [ModelInputForCPUBuilder ]
409416
410417 def __init__ (
411418 self ,
@@ -448,20 +455,11 @@ def __init__(
448455 def load_model (self ) -> None :
449456 self .model = get_model (vllm_config = self .vllm_config )
450457
451- def make_model_input_from_broadcasted_tensor_dict (
452- self ,
453- tensor_dict : Dict [str , Any ],
454- ) -> ModelInputForCPUWithSamplingMetadata :
455- return ModelInputForCPUWithSamplingMetadata .from_broadcasted_tensor_dict ( # noqa: E501
456- tensor_dict ,
457- attn_backend = self .attn_backend ,
458- )
459-
460458 def _prepare_model_input_tensors (
461459 self ,
462460 seq_group_metadata_list : List [SequenceGroupMetadata ],
463461 finished_requests_ids : Optional [List [str ]] = None
464- ) -> ModelInputForCPUWithSamplingMetadata :
462+ ) -> TModelInputForCPU :
465463 """Helper method to prepare the model input based on a given sequence
466464 group. Prepares metadata needed for the base model forward pass but not
467465 metadata for possible additional steps, e.g., sampling.
@@ -473,6 +471,21 @@ def _prepare_model_input_tensors(
473471
474472 return builder .build () # type: ignore
475473
474+
475+ class CPUModelRunner (CPUModelRunnerBase [ModelInputForCPUWithSamplingMetadata ]):
476+ _model_input_cls : Type [ModelInputForCPUWithSamplingMetadata ] = (
477+ ModelInputForCPUWithSamplingMetadata )
478+ _builder_cls : Type [ModelInputForCPUBuilder ] = ModelInputForCPUBuilder
479+
480+ def make_model_input_from_broadcasted_tensor_dict (
481+ self ,
482+ tensor_dict : Dict [str , Any ],
483+ ) -> ModelInputForCPUWithSamplingMetadata :
484+ return ModelInputForCPUWithSamplingMetadata .from_broadcasted_tensor_dict ( # noqa: E501
485+ tensor_dict ,
486+ attn_backend = self .attn_backend ,
487+ )
488+
476489 def prepare_model_input (
477490 self ,
478491 seq_group_metadata_list : List [SequenceGroupMetadata ],
0 commit comments