From 27780a135d9c87fb69996e202574d76df7bfec32 Mon Sep 17 00:00:00 2001
From: shen-shanshan <467638484@qq.com>
Date: Tue, 25 Nov 2025 08:51:40 +0000
Subject: [PATCH 1/4] make forward context manager pluggable

Signed-off-by: shen-shanshan <467638484@qq.com>
---
 vllm/model_executor/models/qwen2_5_vl.py | 9 ++++++++-
 vllm/platforms/interface.py              | 9 +++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 8c707c2561af..13d0dbf2be05 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -51,7 +51,6 @@
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.conv import Conv3dLayer
@@ -1316,6 +1315,10 @@ def _process_image_input(
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"]
+
+            from vllm.platforms import current_platform
+
+            set_forward_context = current_platform.get_forward_context_manager()
             with set_forward_context(None, self.vllm_config):
                 if self.use_data_parallel:
                     return run_dp_sharded_mrope_vision_model(
@@ -1371,6 +1374,10 @@ def _process_video_input(
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
+
+            from vllm.platforms import current_platform
+
+            set_forward_context = current_platform.get_forward_context_manager()
             with set_forward_context(None, self.vllm_config):
                 if self.use_data_parallel:
                     return run_dp_sharded_mrope_vision_model(
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 1e6b53021f88..9e544832dc84 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -653,6 +653,15 @@ def check_max_model_len(cls, max_model_len: int) -> int:
         """
         return max_model_len
 
+    @classmethod
+    def get_forward_context_manager(cls):
+        """
+        Returns the forward context manager for the current platform.
+        """
+        from vllm.forward_context import set_forward_context
+
+        return set_forward_context
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED

From f4fe99bc1643a4098822807ad258c782322d2ff4 Mon Sep 17 00:00:00 2001
From: shen-shanshan <467638484@qq.com>
Date: Tue, 25 Nov 2025 08:52:22 +0000
Subject: [PATCH 2/4] update

Signed-off-by: shen-shanshan <467638484@qq.com>
---
 vllm/model_executor/models/qwen2_5_vl.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 13d0dbf2be05..acfa7abdce8c 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -1314,10 +1314,9 @@ def _process_image_input(
         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
-            pixel_values = image_input["pixel_values"]
-
             from vllm.platforms import current_platform
 
+            pixel_values = image_input["pixel_values"]
             set_forward_context = current_platform.get_forward_context_manager()
             with set_forward_context(None, self.vllm_config):
                 if self.use_data_parallel:
@@ -1373,10 +1372,9 @@ def _process_video_input(
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
-            pixel_values_videos = video_input["pixel_values_videos"]
-
             from vllm.platforms import current_platform
 
+            pixel_values_videos = video_input["pixel_values_videos"]
             set_forward_context = current_platform.get_forward_context_manager()
             with set_forward_context(None, self.vllm_config):
                 if self.use_data_parallel:

From 69f01ee106ab41c7d27c744ab91b172cf0f1c6e6 Mon Sep 17 00:00:00 2001
From: shen-shanshan <467638484@qq.com>
Date: Tue, 25 Nov 2025 09:09:32 +0000
Subject: [PATCH 3/4] minor fix

Signed-off-by: shen-shanshan <467638484@qq.com>
---
 vllm/model_executor/models/qwen2_5_vl.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index acfa7abdce8c..5ab087ec1746 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -1246,6 +1246,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.language_model.make_empty_intermediate_tensors
         )
 
+        from vllm.platforms import current_platform
+
+        self.set_forward_context = current_platform.get_forward_context_manager()
+
     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         self.language_model.model.aux_hidden_state_layers = layers
 
@@ -1314,11 +1318,8 @@ def _process_image_input(
         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
-            from vllm.platforms import current_platform
-
             pixel_values = image_input["pixel_values"]
-            set_forward_context = current_platform.get_forward_context_manager()
-            with set_forward_context(None, self.vllm_config):
+            with self.set_forward_context(None, self.vllm_config):
                 if self.use_data_parallel:
                     return run_dp_sharded_mrope_vision_model(
                         self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
                     )
@@ -1372,11 +1373,8 @@ def _process_video_input(
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
-            from vllm.platforms import current_platform
-
             pixel_values_videos = video_input["pixel_values_videos"]
-            set_forward_context = current_platform.get_forward_context_manager()
-            with set_forward_context(None, self.vllm_config):
+            with self.set_forward_context(None, self.vllm_config):
                 if self.use_data_parallel:
                     return run_dp_sharded_mrope_vision_model(
                         self.visual,

From aa2947c94008b0bdc9f0206f957d949b4f02894c Mon Sep 17 00:00:00 2001
From: shen-shanshan <467638484@qq.com>
Date: Wed, 26 Nov 2025 02:07:29 +0000
Subject: [PATCH 4/4] minor fix

Signed-off-by: shen-shanshan <467638484@qq.com>
---
 vllm/model_executor/models/qwen2_5_vl.py | 7 +------
 vllm/platforms/interface.py              | 3 ++-
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 5ab087ec1746..983d8100545c 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -80,6 +80,7 @@
 )
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -352,8 +353,6 @@ def __init__(
             )
         )
         # On ROCm with FLASH_ATTN backend, upstream flash_attn is used
-        from vllm.platforms import current_platform
-
         if (
             current_platform.is_rocm()
             and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
@@ -418,8 +417,6 @@ def forward(
             )
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
-            from vllm.platforms import current_platform
-
             # Never remove the next contiguous logic
             # Without it, hallucinations occur with the backend
             if current_platform.is_rocm():
@@ -1246,8 +1243,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.language_model.make_empty_intermediate_tensors
         )
 
-        from vllm.platforms import current_platform
-
         self.set_forward_context = current_platform.get_forward_context_manager()
 
     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 9e544832dc84..4066f2eb8b3b 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -6,6 +6,7 @@
 import platform
 import random
 import sys
+from collections.abc import Callable
 from datetime import timedelta
 from typing import TYPE_CHECKING, Any, NamedTuple
 
@@ -654,7 +655,7 @@ def check_max_model_len(cls, max_model_len: int) -> int:
         return max_model_len
 
     @classmethod
-    def get_forward_context_manager(cls):
+    def get_forward_context_manager(cls) -> Callable:
         """
         Returns the forward context manager for the current platform.
         """
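
---

Note for reviewers: below is a minimal sketch of how an out-of-tree platform could consume the new hook. CustomPlatform and custom_forward_context are hypothetical names used for illustration only; the sole pieces taken from this series are Platform.get_forward_context_manager() and the set_forward_context(None, vllm_config) call pattern used in qwen2_5_vl.py.

    # Hypothetical out-of-tree platform plugin; a sketch, not part of this series.
    from contextlib import contextmanager

    from vllm.platforms.interface import Platform


    @contextmanager
    def custom_forward_context(attn_metadata, vllm_config, **kwargs):
        # Platform-specific setup/teardown could go here, delegating to
        # vLLM's stock implementation for the shared bookkeeping.
        from vllm.forward_context import set_forward_context

        with set_forward_context(attn_metadata, vllm_config, **kwargs):
            yield


    class CustomPlatform(Platform):
        @classmethod
        def get_forward_context_manager(cls):
            # Must return a callable usable as `with ctx(None, vllm_config): ...`,
            # mirroring the default implementation on the Platform base class.
            return custom_forward_context

Model code such as Qwen2_5_VL then picks the replacement up unchanged through current_platform.get_forward_context_manager(), which is the point of making the context manager pluggable.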