diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 09641aaff306..d7d6419d643b 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -43,11 +43,12 @@ # pydantic needs the TypedDict from typing_extensions from typing_extensions import Required, TypedDict +from vllm import envs from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.models import SupportsMultiModal from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict -from vllm.multimodal.utils import MediaConnector +from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -806,7 +807,9 @@ def __init__(self, tracker: MultiModalItemTracker) -> None: self._tracker = tracker multimodal_config = self._tracker.model_config.multimodal_config media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) - self._connector = MediaConnector( + + self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load( + envs.VLLM_MEDIA_CONNECTOR, media_io_kwargs=media_io_kwargs, allowed_local_media_path=tracker.allowed_local_media_path, allowed_media_domains=tracker.allowed_media_domains, @@ -891,7 +894,8 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: self._tracker = tracker multimodal_config = self._tracker.model_config.multimodal_config media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) - self._connector = MediaConnector( + self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load( + envs.VLLM_MEDIA_CONNECTOR, media_io_kwargs=media_io_kwargs, allowed_local_media_path=tracker.allowed_local_media_path, allowed_media_domains=tracker.allowed_media_domains, diff --git a/vllm/envs.py b/vllm/envs.py index 81f189ada9a6..dc919e774196 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -70,6 +70,7 @@ VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_VIDEO_LOADER_BACKEND: str = "opencv" + VLLM_MEDIA_CONNECTOR: str = "http" VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_TARGET_DEVICE: str = "cuda" VLLM_MAIN_CUDA_VERSION: str = "12.8" @@ -738,6 +739,14 @@ def get_vllm_port() -> int | None: "VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv( "VLLM_VIDEO_LOADER_BACKEND", "opencv" ), + # Media connector implementation. + # - "http": Default connector that supports fetching media via HTTP. + # + # Custom implementations can be registered + # via `@MEDIA_CONNECTOR_REGISTRY.register("my_custom_media_connector")` and + # imported at runtime. + # If a non-existing backend is used, an AssertionError will be thrown. + "VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"), # [DEPRECATED] Cache size (in GiB per process) for multimodal input cache # Default is 4 GiB per API process + 4 GiB per engine core process "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 7f259dad08f9..3fad11a2cb4d 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -20,6 +20,7 @@ from vllm.connections import HTTPConnection, global_http_connection from vllm.logger import init_logger from vllm.utils.jsontree import json_map_leaves +from vllm.utils.registry import ExtensionManager from .audio import AudioMediaIO from .base import MediaIO @@ -46,7 +47,10 @@ _M = TypeVar("_M") +MEDIA_CONNECTOR_REGISTRY = ExtensionManager() + +@MEDIA_CONNECTOR_REGISTRY.register("http") class MediaConnector: def __init__( self, diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 666ef275a924..369c5e6cb4d1 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -14,6 +14,7 @@ from vllm import envs from vllm.logger import init_logger +from vllm.utils.registry import ExtensionManager from .base import MediaIO from .image import ImageMediaIO @@ -63,25 +64,7 @@ def load_bytes( raise NotImplementedError -class VideoLoaderRegistry: - def __init__(self) -> None: - self.name2class: dict[str, type] = {} - - def register(self, name: str): - def wrap(cls_to_register): - self.name2class[name] = cls_to_register - return cls_to_register - - return wrap - - @staticmethod - def load(cls_name: str) -> VideoLoader: - cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name) - assert cls is not None, f"VideoLoader class {cls_name} not found" - return cls() - - -VIDEO_LOADER_REGISTRY = VideoLoaderRegistry() +VIDEO_LOADER_REGISTRY = ExtensionManager() @VIDEO_LOADER_REGISTRY.register("opencv") diff --git a/vllm/utils/registry.py b/vllm/utils/registry.py new file mode 100644 index 000000000000..ac9b859159ea --- /dev/null +++ b/vllm/utils/registry.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + + +class ExtensionManager: + """ + A registry for managing pluggable extension classes. + + This class provides a simple mechanism to register and instantiate + extension classes by name. It is commonly used to implement plugin + systems where different implementations can be swapped at runtime. + + Examples: + Basic usage with a registry instance: + + >>> FOO_REGISTRY = ExtensionManager() + >>> @FOO_REGISTRY.register("my_foo_impl") + ... class MyFooImpl(Foo): + ... def __init__(self, value): + ... self.value = value + >>> foo_impl = FOO_REGISTRY.load("my_foo_impl", value=123) + + """ + + def __init__(self) -> None: + """ + Initialize an empty extension registry. + """ + self.name2class: dict[str, type] = {} + + def register(self, name: str): + """ + Decorator to register a class with the given name. + """ + + def wrap(cls_to_register): + self.name2class[name] = cls_to_register + return cls_to_register + + return wrap + + def load(self, cls_name: str, *args, **kwargs) -> Any: + """ + Instantiate and return a registered extension class by name. + """ + cls = self.name2class.get(cls_name) + assert cls is not None, f"Extension class {cls_name} not found" + return cls(*args, **kwargs)