 from torch import nn
 from torch.nn.init import trunc_normal_
 from transformers import PretrainedConfig
+from typing_extensions import NotRequired
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
 }
 
 
+class MiniCPMVImageInput(TypedDict):
+    """Input mapper input with auxiliary data for computing image bounds."""
+    image: Image.Image
+
+    # Image bound token ids as 0-dim scalar tensors.
+    im_start_id: torch.Tensor
+    im_end_id: torch.Tensor
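+    # Slice token ids are only present for tokenizers that support image
+    # slicing, hence NotRequired.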
+    slice_start_id: NotRequired[torch.Tensor]
+    slice_end_id: NotRequired[torch.Tensor]
+
+
 class MiniCPMVImagePixelInputs(TypedDict):
     pixel_values: List[torch.Tensor]
     """
@@ -88,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
     """
 
 
-MiniCPMVImageInputs = MiniCPMVImagePixelInputs
-
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
 
@@ -234,6 +245,25 @@ def forward(self, x: torch.Tensor,
         return x
 
 
+def _build_image_input(ctx: InputContext,
+                       image: Image.Image) -> MiniCPMVImageInput:
+    tokenizer = cached_get_tokenizer(
+        ctx.model_config.tokenizer,
+        trust_remote_code=ctx.model_config.trust_remote_code)
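+    # Tokenizers that support image slicing also expose slice token ids;
+    # fall back to the plain image bound tokens when they are absent.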
+    if hasattr(tokenizer, "slice_start_id"):
+        return MiniCPMVImageInput(
+            image=image,
+            im_start_id=torch.tensor(tokenizer.im_start_id),
+            im_end_id=torch.tensor(tokenizer.im_end_id),
+            slice_start_id=torch.tensor(tokenizer.slice_start_id),
+            slice_end_id=torch.tensor(tokenizer.slice_end_id))
+    else:
+        return MiniCPMVImageInput(image=image,
+                                  im_start_id=torch.tensor(
+                                      tokenizer.im_start_id),
+                                  im_end_id=torch.tensor(tokenizer.im_end_id))
+
+
 def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
     version_float = getattr(config, "version", None)
 
@@ -257,10 +287,13 @@ def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
     return SequenceData.from_token_counts((0, seq_len))
 
 
-def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
+def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
+                             num_images: int):
     width = height = hf_config.image_size
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
+    image = _build_image_input(ctx,
+                               image=Image.new("RGB", (width, height),
+                                               color=0))
+    return {"image": [image] if num_images == 1 else [image] * num_images}
 
 
 def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
@@ -269,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
     num_images = mm_counts["image"]
 
     seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
-    mm_data = dummy_image_for_minicpmv(hf_config, num_images)
+    mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)
 
     return seq_data, mm_data
 
@@ -280,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs
     model_config = ctx.model_config
     version = get_version_by_config(model_config.hf_config)
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
     image_processor = cached_get_image_processor(model_config.tokenizer)
 
     def get_placeholder(image_size: Tuple[int, int], num_image: int):
@@ -317,6 +351,10 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
     new_prompt = "".join(new_prompt_chunks)
     new_token_ids = tokenizer.encode(new_prompt)
 
+    multi_modal_data["image"] = [
+        _build_image_input(ctx, image) for image in images
+    ]
+
     llm_inputs = LLMInputs(
         prompt_token_ids=new_token_ids,
         prompt=new_prompt,
@@ -325,6 +363,32 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
     return llm_inputs
 
 
+def input_mapper_for_minicpmv(ctx: InputContext, data: object):
+    model_config = ctx.model_config
+
+    image_processor = cached_get_image_processor(
+        model_config.model, trust_remote_code=model_config.trust_remote_code)
+    if image_processor is None:
+        raise RuntimeError("No HuggingFace processor is available "
+                           "to process the image object")
+
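+    # The input processor wraps every image in a MiniCPMVImageInput dict,
+    # so the mapper expects a list of those dicts rather than raw images.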
+    if not isinstance(data, list):
+        raise ValueError(
+            f"Image input must be a list of MiniCPMVImageInput, got {data}")
+    batch_data = image_processor \
+        .preprocess([img["image"] for img in data], return_tensors="pt") \
+        .data
+
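+    # Attach the image bound token ids alongside the processed pixel values;
+    # all images in a request share the same special token ids.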
+    if len(data) > 0:
+        batch_data["im_start_id"] = data[0]["im_start_id"]
+        batch_data["im_end_id"] = data[0]["im_end_id"]
+        if "slice_start_id" in data[0]:
+            batch_data["slice_start_id"] = data[0]["slice_start_id"]
+            batch_data["slice_end_id"] = data[0]["slice_end_id"]
+
+    return MultiModalInputs(batch_data)
+
+
 class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
@@ -365,7 +429,7 @@ def __init__(
     def get_embedding(
         self,
        input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImageInputs],
+        image_inputs: Optional[MiniCPMVImagePixelInputs],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
         if hasattr(self.config, "scale_emb"):
@@ -393,14 +457,20 @@ def get_embedding(
 
         return vlm_embedding, vision_hidden_states
 
-    def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
-        tokenizer = cached_get_tokenizer(self.config._name_or_path,
-                                         trust_remote_code=True)
-        start_cond = input_ids == tokenizer.im_start_id
-        end_cond = input_ids == tokenizer.im_end_id
-        if hasattr(tokenizer, "slice_start_id"):
-            start_cond |= (input_ids == tokenizer.slice_start_id)
-            end_cond |= (input_ids == tokenizer.slice_end_id)
+    def _get_image_bounds(
+            self,
+            input_ids: torch.Tensor,
+            im_start_id: torch.Tensor,
+            im_end_id: torch.Tensor,
+            slice_start_id: Optional[torch.Tensor] = None,
+            slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # All the images in the batch should share the same special image
+        # bound token ids.
+        start_cond = input_ids == im_start_id[0]
+        end_cond = input_ids == im_end_id[0]
+        if slice_start_id is not None:
+            start_cond |= (input_ids == slice_start_id[0])
+            end_cond |= (input_ids == slice_end_id[0])
 
         image_start_tokens, = torch.where(start_cond)
         image_start_tokens += 1
@@ -419,7 +489,7 @@ def _parse_and_validate_inputs(
         self,
         input_ids: torch.Tensor,
         **kwargs: object,
-    ) -> Optional[MiniCPMVImageInputs]:
+    ) -> Optional[MiniCPMVImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", [])
         tgt_sizes = kwargs.pop("tgt_sizes", [])
 
@@ -456,8 +526,17 @@ def _parse_and_validate_inputs(
         if len(pixel_values_flat) == 0:
             return None
 
-        return MiniCPMVImageInputs(
-            image_bounds=self._get_image_bounds(input_ids),
+        im_start_id = kwargs.pop("im_start_id", None)
+        im_end_id = kwargs.pop("im_end_id", None)
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        slice_end_id = kwargs.pop("slice_end_id", None)
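+        # Without the start token id the image bounds cannot be located,
+        # so treat the input as carrying no image data.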
+        if im_start_id is None:
+            return None
+
+        return MiniCPMVImagePixelInputs(
+            image_bounds=self._get_image_bounds(input_ids, im_start_id,
+                                                im_end_id, slice_start_id,
+                                                slice_end_id),
             pixel_values=pixel_values_flat,
             tgt_sizes=torch.stack(tgt_sizes_flat),
         )
@@ -564,8 +643,8 @@ def get_vision_embedding(
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         raise NotImplementedError
 
     def is_default_weight_loading(self, name: str) -> bool:
@@ -654,8 +733,8 @@ def get_vision_embedding(
         res.append(self.resampler(vision_embedding, tgt_size))
         return torch.vstack(res)
 
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]
 
         return self.get_vision_embedding(pixel_values)
@@ -713,8 +792,8 @@ def get_vision_embedding(
         vision_embedding = self.resampler(vision_embedding, tgt_sizes)
         return vision_embedding
 
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]
         tgt_sizes = data["tgt_sizes"]
 
@@ -807,8 +886,8 @@ def get_vision_embedding(
         ).last_hidden_state
         return vision_embedding
 
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]
         tgt_sizes = data["tgt_sizes"]
 
@@ -851,7 +930,7 @@ def is_default_weight_loading(self, name: str) -> bool:
 }
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)