 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+from vllm.v1.worker.utils import (gather_mm_placeholders,
+                                  sanity_check_mm_encoder_outputs,
+                                  scatter_mm_placeholders)

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention import AttentionMaskBuilder
@@ -362,6 +368,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.encoder_cache.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -374,6 +381,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             if req_index is not None:
                 removed_req_indices.append(req_index)

+        # Free the cached encoder outputs.
+        for req_id, input_id in scheduler_output.free_encoder_input_ids:
+            encoder_outputs = self.encoder_cache.get(req_id)
+            if encoder_outputs is not None:
+                encoder_outputs.pop(input_id, None)
+                if not encoder_outputs:
+                    self.encoder_cache.pop(req_id, None)
+
         # Remove the unscheduled requests from the persistent batch.
         # NOTE(woosuk): The unscheduled requests are either preempted requests
         # or running requests that are not scheduled in this step. We remove
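For context, self.encoder_cache is a two-level mapping from request id to multimodal input id to the cached encoder output, and the loop added above mirrors the scheduler's free_encoder_input_ids messages. A minimal sketch of the same bookkeeping on plain dictionaries (illustrative only, not part of this patch; names are hypothetical):

    import torch

    # encoder_cache[req_id][input_id] -> cached encoder output
    encoder_cache: dict[str, dict[int, torch.Tensor]] = {
        "req-0": {0: torch.zeros(4, 8), 1: torch.zeros(4, 8)},
    }

    def free_encoder_input(req_id: str, input_id: int) -> None:
        outputs = encoder_cache.get(req_id)
        if outputs is not None:
            outputs.pop(input_id, None)
            if not outputs:  # drop the request entry once all inputs are freed
                encoder_cache.pop(req_id, None)

    free_encoder_input("req-0", 0)
    free_encoder_input("req-0", 1)
    assert "req-0" not in encoder_cache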
@@ -415,6 +430,43 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 lora_request=new_req_data.lora_request,
             )

+            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+            if self.uses_mrope:
+                image_grid_thw = []
+                video_grid_thw = []
+                second_per_grid_ts = []
+                audio_feature_lengths = []
+                use_audio_in_video = False
+                for mm_input in self.requests[req_id].mm_inputs:
+                    if mm_input.get("image_grid_thw") is not None:
+                        image_grid_thw.extend(
+                            mm_input["image_grid_thw"].tolist())
+                    if mm_input.get("video_grid_thw") is not None:
+                        video_grid_thw.extend(
+                            mm_input["video_grid_thw"].tolist())
+                    if mm_input.get("second_per_grid_ts") is not None:
+                        second_per_grid_ts.extend(
+                            mm_input["second_per_grid_ts"])
+                    if mm_input.get("audio_feature_lengths") is not None:
+                        audio_feature_lengths.extend(
+                            mm_input["audio_feature_lengths"])
+                    if mm_input.get("use_audio_in_video") is True:
+                        use_audio_in_video = True
+
+                hf_config = self.model_config.hf_config
+
+                self.requests[req_id].mrope_positions, \
+                    self.requests[req_id].mrope_position_delta = \
+                    MRotaryEmbedding.get_input_positions_tensor(
+                        self.requests[req_id].prompt_token_ids,
+                        hf_config=hf_config,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        second_per_grid_ts=second_per_grid_ts,
+                        audio_feature_lengths=audio_feature_lengths,
+                        use_audio_in_video=use_audio_in_video,
+                    )
+
             req_ids_to_add.append(req_id)

         # Update the states of the running/resumed requests.
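The block above precomputes M-RoPE positions once per new request from the prompt and its multimodal grid metadata. For a text-only prompt the result degenerates to the ordinary 1-D positions broadcast over the three (temporal, height, width) axes with a zero position delta; a minimal standalone sketch under that assumption (not the actual MRotaryEmbedding helper):

    import torch

    def text_only_mrope_positions(num_tokens: int) -> tuple[torch.Tensor, int]:
        # Shape (3, num_tokens): the same index on the t/h/w axes for plain text.
        positions = torch.arange(num_tokens).unsqueeze(0).expand(3, -1)
        mrope_position_delta = 0
        return positions, mrope_position_delta

    pos, delta = text_only_mrope_positions(6)
    print(pos.shape)   # torch.Size([3, 6])
    print(pos[0][:3])  # tensor([0, 1, 2])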
@@ -535,6 +587,166 @@ def _make_attention_mask(self, seq_lens, query_lens, position,
         else:
             return None

+    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
+        mrope_pos_ptr = 0
+        for index, req_id in enumerate(self.input_batch.req_ids):
+            req = self.requests[req_id]
+            assert req.mrope_positions is not None
+
+            num_computed_tokens = \
+                self.input_batch.num_computed_tokens_cpu[index]
+            num_scheduled_tokens = \
+                scheduler_output.num_scheduled_tokens[req_id]
+            num_prompt_tokens = len(req.prompt_token_ids)
+
+            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+                prompt_part_len = max(0,
+                                      num_prompt_tokens - num_computed_tokens)
+                completion_part_len = max(
+                    0, num_scheduled_tokens - prompt_part_len)
+            else:
+                prompt_part_len = num_scheduled_tokens
+                completion_part_len = 0
+
+            assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+            if prompt_part_len > 0:
+                # prompt's mrope_positions are pre-computed
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + prompt_part_len
+                src_start = num_computed_tokens
+                src_end = num_computed_tokens + prompt_part_len
+
+                self.mrope_positions_cpu[:, dst_start:dst_end] = \
+                    req.mrope_positions[:, src_start:src_end]
+
+                mrope_pos_ptr += prompt_part_len
+
+            if completion_part_len > 0:
+                # compute completion's mrope_positions on-the-fly
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + completion_part_len
+
+                self.mrope_positions_cpu[:, dst_start:dst_end] = \
+                    MRotaryEmbedding.get_next_input_positions_tensor(
+                        req.mrope_position_delta,
+                        context_len=num_computed_tokens +
+                        prompt_part_len,
+                        seq_len=num_computed_tokens +
+                        prompt_part_len +
+                        completion_part_len,
+                    )
+
+                mrope_pos_ptr += completion_part_len
+
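The prompt/completion split above is what lets chunked prefill and decode share one code path: positions that fall inside the prompt are copied from the precomputed tensor, while positions past the prompt are generated on the fly from mrope_position_delta. A worked example of the split arithmetic with hypothetical sizes:

    # Hypothetical sizes for one request in the batch.
    num_prompt_tokens = 10
    num_computed_tokens = 8    # prompt tokens already processed in earlier steps
    num_scheduled_tokens = 5   # this step: last 2 prompt tokens + 3 decode tokens

    if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
        prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
        completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
    else:
        prompt_part_len = num_scheduled_tokens
        completion_part_len = 0

    assert (prompt_part_len, completion_part_len) == (2, 3)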
+    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
+        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+        if not scheduled_encoder_inputs:
+            return
+
+        # Batch the multi-modal inputs.
+        mm_inputs = list[MultiModalKwargs]()
+        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
+        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+            req_state = self.requests[req_id]
+
+            for mm_input_id in encoder_input_ids:
+                mm_inputs.append(req_state.mm_inputs[mm_input_id])
+                req_ids_pos.append(
+                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
+
+        # Batch mm inputs as much as we can: if a request in the batch has
+        # multiple modalities or a different modality than the previous one,
+        # we process it separately to preserve item order.
+        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+        # in the same batch while still being able to benefit from batching
+        # multimodal inputs. The proper solution should be reordering the
+        # encoder outputs.
+        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
+
+        encoder_outputs = []
+        for grouped_mm_inputs in grouped_mm_inputs_list:
+            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
+            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
+                                                           device=self.device)
+
+            # Run the encoder.
+            # `curr_group_outputs` is either of the following:
+            # 1. A tensor of shape (num_items, feature_size, hidden_size)
+            # in case feature_size is fixed across all multimodal items.
+            # 2. A list or tuple (length: num_items) of tensors, each of shape
+            # (feature_size, hidden_size) in case the feature size is dynamic
+            # depending on the input multimodal items.
+            curr_group_outputs = self.model.get_multimodal_embeddings(
+                **batched_mm_inputs)
+
+            sanity_check_mm_encoder_outputs(
+                curr_group_outputs,
+                expected_num_items=len(grouped_mm_inputs),
+            )
+
+            for output in curr_group_outputs:
+                encoder_outputs.append(output)
+
+        # Cache the encoder outputs.
+        for (req_id, input_id, pos_info), output in zip(
+                req_ids_pos,
+                encoder_outputs,
+        ):
+            if req_id not in self.encoder_cache:
+                self.encoder_cache[req_id] = {}
+
+            self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
+                output,
+                is_embed=pos_info.is_embed,
+            )
+
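group_mm_inputs_by_modality lets the runner batch encoder calls without reordering items: consecutive inputs that share a modality form one batch, and a modality change starts a new group. A toy stand-in (not vllm's actual helper) that illustrates the intended grouping:

    from itertools import groupby

    def group_by_modality(modalities: list[str]) -> list[list[str]]:
        # Group consecutive items only, so the original item order is preserved.
        return [list(group) for _, group in groupby(modalities)]

    print(group_by_modality(["image", "image", "video", "image"]))
    # [['image', 'image'], ['video'], ['image']]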
+    def _gather_mm_embeddings(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> list[torch.Tensor]:
+        mm_embeds: list[torch.Tensor] = []
+        for req_id in self.input_batch.req_ids:
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
+            req_state = self.requests[req_id]
+            num_computed_tokens = req_state.num_computed_tokens
+            mm_positions = req_state.mm_positions
+            for i, pos_info in enumerate(mm_positions):
+                start_pos = pos_info.offset
+                num_encoder_tokens = pos_info.length
+
+                # The encoder output is needed if the two ranges overlap:
+                # [num_computed_tokens,
+                #  num_computed_tokens + num_scheduled_tokens) and
+                # [start_pos, start_pos + num_encoder_tokens)
+                if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                    # The encoder output is not needed in this step.
+                    break
+                if start_pos + num_encoder_tokens <= num_computed_tokens:
+                    # The encoder output is already processed and stored
+                    # in the decoder's KV cache.
+                    continue
+
+                start_idx = max(num_computed_tokens - start_pos, 0)
+                end_idx = min(
+                    num_computed_tokens - start_pos + num_scheduled_tokens,
+                    num_encoder_tokens)
+                assert start_idx < end_idx
+                assert req_id in self.encoder_cache
+                assert i in self.encoder_cache[req_id]
+                encoder_output = self.encoder_cache[req_id][i]
+
+                if (is_embed := pos_info.is_embed) is not None:
+                    is_embed = is_embed[start_idx:end_idx]
+
+                mm_embeds_item = gather_mm_placeholders(
+                    encoder_output[start_idx:end_idx],
+                    is_embed=is_embed,
+                )
+                mm_embeds.append(mm_embeds_item)
+        return mm_embeds
+
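The slicing in _gather_mm_embeddings intersects the window of tokens scheduled this step with each placeholder range, so only the still-needed slice of an encoder output is gathered. A worked example with hypothetical numbers:

    # Hypothetical placeholder and scheduling window.
    start_pos = 20             # placeholder offset in the prompt
    num_encoder_tokens = 16    # placeholder length
    num_computed_tokens = 24   # tokens already in the KV cache
    num_scheduled_tokens = 8   # tokens processed this step

    # Intersect [num_computed_tokens, num_computed_tokens + num_scheduled_tokens)
    # with [start_pos, start_pos + num_encoder_tokens), in encoder-output coords.
    start_idx = max(num_computed_tokens - start_pos, 0)
    end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
                  num_encoder_tokens)
    assert (start_idx, end_idx) == (4, 12)  # encoder_output[4:12] is needed now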
     def _process_reqs(
         self,
         scheduler_output: "SchedulerOutput",
@@ -594,6 +806,17 @@ def _process_reqs(
                arange,
                out=positions_np)

+        # Calculate M-RoPE positions.
+        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+        if self.uses_mrope:
+            self._calc_mrope_positions(scheduler_output)
+
+        if self.uses_mrope:
+            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
+                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+                non_blocking=True)
+
         self.positions[:total_num_scheduled_tokens].copy_(
             self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
         positions = self.positions[:num_input_tokens]
@@ -706,6 +929,43 @@ def _process_reqs(
         input_ids = self.input_ids[:padded_batch_size]
         positions = self.positions[:padded_batch_size]

+        # Prepare multimodal embeddings and M-RoPE positions for multimodal models.
+        num_input_tokens = total_num_scheduled_tokens
+        # _prepare_inputs may reorder the batch, so we must gather the
+        # multimodal outputs after that to ensure the correct order.
+        if self.is_multimodal_model:
+            # Run the multimodal encoder if any.
+            self._execute_mm_encoder(scheduler_output)
+            mm_embeds = self._gather_mm_embeddings(scheduler_output)
+        else:
+            mm_embeds = []
+
+        if self.is_multimodal_model:
+            # NOTE(woosuk): To unify token ids and soft tokens (vision
+            # embeddings), we always use embeddings (rather than token ids)
+            # as input to the multimodal model, even when the input is text.
+            input_ids = self.input_ids[:num_input_tokens]
+            if mm_embeds:
+                inputs_embeds = self.model.get_input_embeddings(
+                    input_ids, mm_embeds)
+            else:
+                inputs_embeds = self.model.get_input_embeddings(input_ids)
+            # TODO(woosuk): Avoid the copy. Optimize.
+            self.inputs_embeds[:num_input_tokens].copy_(inputs_embeds)
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            input_ids = None
+        else:
+            # For text-only models, we use token ids as input.
+            # While it is possible to use embeddings as input just like the
+            # multimodal models, it is not desirable for performance since
+            # then the embedding layer is not included in the CUDA graph.
+            input_ids = self.input_ids[:num_input_tokens]
+            inputs_embeds = None
+        if self.uses_mrope:
+            positions = self.mrope_positions[:, :num_input_tokens]
+        else:
+            positions = self.positions[:num_input_tokens]
+
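When mm_embeds is non-empty, get_input_embeddings replaces the placeholder token embeddings with the encoder outputs so the decoder sees one continuous embedding sequence. A simplified sketch of that merge, assuming placeholder positions are given as a boolean mask (this is not the model's actual API, just the idea):

    import torch

    def merge_mm_embeddings(text_embeds: torch.Tensor,
                            mm_embeds: torch.Tensor,
                            is_mm_token: torch.Tensor) -> torch.Tensor:
        # text_embeds: (num_tokens, hidden); mm_embeds: (num_mm_tokens, hidden);
        # is_mm_token: (num_tokens,) bool mask marking placeholder positions.
        merged = text_embeds.clone()
        merged[is_mm_token] = mm_embeds
        return merged

    text = torch.zeros(6, 4)
    vision = torch.ones(2, 4)
    mask = torch.tensor([False, True, True, False, False, False])
    print(merge_mm_embeddings(text, vision, mask)[1:3])  # rows taken from `vision`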
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
@@ -722,7 +982,7 @@ def _process_reqs(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
+                    inputs_embeds=inputs_embeds,
                     **model_kwargs,
                 )
             else:
@@ -731,7 +991,7 @@ def _process_reqs(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
+                    inputs_embeds=inputs_embeds,
                     **model_kwargs,
                 )

@@ -1214,8 +1474,11 @@ def _dummy_run(
         return hidden_states

     def profile_run(self) -> None:
-        # Profile with multimodal encoder & encoder cache.
-        self._profile_multimodal()
+        # FIXME: Profile with the multimodal encoder & encoder cache.
+        # The current _profile_multimodal() relies on the PyTorch SDPA backend,
+        # which does not support window/full attention to reduce memcpy
+        # operations and can therefore run out of memory, so skip it for now.
+        # self._profile_multimodal()

         # For profile, have maximum num_reqs and that collectively have
         # maximum num_tokens.