Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv3dLayer
Expand Down Expand Up @@ -1247,6 +1246,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.language_model.make_empty_intermediate_tensors
)

from vllm.platforms import current_platform
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is no need to import this lazily, just import it from top level

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is no need to import this lazily, just import it from top level

done.


self.set_forward_context = current_platform.get_forward_context_manager()

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers

Expand Down Expand Up @@ -1316,7 +1319,7 @@ def _process_image_input(
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"]
with set_forward_context(None, self.vllm_config):
with self.set_forward_context(None, self.vllm_config):
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
Expand Down Expand Up @@ -1371,7 +1374,7 @@ def _process_video_input(
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"]
with set_forward_context(None, self.vllm_config):
with self.set_forward_context(None, self.vllm_config):
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual,
Expand Down
9 changes: 9 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,15 @@ def check_max_model_len(cls, max_model_len: int) -> int:
"""
return max_model_len

@classmethod
def get_forward_context_manager(cls):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a type hint here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a type hint here?

done.

"""
Returns forward context manager for the current platform.
"""
from vllm.forward_context import set_forward_context

return set_forward_context


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
Expand Down