diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index d4a3fae73824..8280fe8fb212 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -121,6 +121,8 @@ class EngineArgs:
     otlp_traces_endpoint: Optional[str] = None
     collect_detailed_traces: Optional[str] = None
 
+    model_architecture_override: Optional[str] = None
+
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
@@ -674,6 +676,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "modules. This involves use of possibly costly and or blocking "
             "operations and hence might have a performance impact.")
 
+        parser.add_argument(
+            '--model-architecture-override',
+            type=str,
+            default=None,
+            help='Override the model architecture to use. Useful for '
+            'plugin-provided architectures.')
+
         return parser
 
     @classmethod
@@ -735,6 +745,11 @@ def create_engine_config(self, ) -> EngineConfig:
             skip_tokenizer_init=self.skip_tokenizer_init,
             served_model_name=self.served_model_name,
             multimodal_config=multimodal_config)
+        if self.model_architecture_override:
+            model_config.hf_config.architectures = [
+                self.model_architecture_override
+            ]
+
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 1a0addfedc55..6066cb2464fe 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -42,6 +42,7 @@
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.logger import init_logger
+from vllm.plugins.model_plugin import load_model_plugins
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
 from vllm.version import __version__ as VLLM_VERSION
@@ -352,6 +353,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
+    load_model_plugins()
     async with build_async_engine_client(args) as async_engine_client:
         app = await init_app(async_engine_client, args)
 
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/plugins/model_plugin.py b/vllm/plugins/model_plugin.py
new file mode 100644
index 000000000000..d669dfffef82
--- /dev/null
+++ b/vllm/plugins/model_plugin.py
@@ -0,0 +1,67 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from importlib.metadata import entry_points
+from typing import Iterable, List, Optional, Tuple, Type
+
+import torch
+from transformers import PretrainedConfig
+
+from vllm import ModelRegistry
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.sequence import IntermediateTensors
+
+logger = init_logger(__name__)
+
+
+class ModelArchitectureBase(torch.nn.Module, ABC):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        # nn.Module.__init__ must run before any attribute assignment.
+        super().__init__()
+        self.config = config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+
+    @abstractmethod
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        pass
+
+    @abstractmethod
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> torch.Tensor:
+        pass
+
+
+@dataclass
+class ModelPlugin:
+    architecture_name: str
+    implementation_cls: Type[ModelArchitectureBase]
+
+
+def load_model_plugins():
+    for entry_point in entry_points().select(group="vllm.model_architectures"):
+        logger.debug("Loading model architecture plugin %s", entry_point.name)
+        model_architecture_plugin = entry_point.load()
+        if not isinstance(model_architecture_plugin, ModelPlugin):
+            raise ValueError(
+                "Model architecture plugin must be an instance of "
+                f"ModelPlugin, got {model_architecture_plugin}")
+
+        ModelRegistry.register_model(
+            model_architecture_plugin.architecture_name,
+            model_architecture_plugin.implementation_cls,
+        )
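
For reference, a downstream package would consume this hook by subclassing ModelArchitectureBase and exposing a ModelPlugin instance through the "vllm.model_architectures" entry-point group. Below is a minimal sketch; the module path (my_plugin/arch.py) and the MyCustomForCausalLM class are hypothetical, and the forward body is left as a placeholder since it depends entirely on the architecture being plugged in.

# my_plugin/arch.py -- hypothetical plugin module, not part of this PR.
from typing import Iterable, List, Optional, Tuple

import torch

from vllm.attention import AttentionMetadata
from vllm.plugins.model_plugin import ModelArchitectureBase, ModelPlugin
from vllm.sequence import IntermediateTensors


class MyCustomForCausalLM(ModelArchitectureBase):

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Copy checkpoint tensors into the matching registered parameters.
        params = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if name in params:
                params[name].data.copy_(loaded_weight)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        # Placeholder: a real implementation runs its decoder stack here.
        raise NotImplementedError


# The object the entry point resolves to; load_model_plugins() checks its type.
plugin = ModelPlugin(
    architecture_name="MyCustomForCausalLM",
    implementation_cls=MyCustomForCausalLM,
)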
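
The entry point itself lives in the plugin's packaging metadata. With setuptools, the declaration would look roughly like this (the distribution name "vllm-my-arch-plugin" is illustrative):

# setup.py for the hypothetical plugin package.
from setuptools import setup

setup(
    name="vllm-my-arch-plugin",
    version="0.1.0",
    packages=["my_plugin"],
    entry_points={
        # Group name must match what load_model_plugins() selects.
        "vllm.model_architectures": [
            "MyCustomForCausalLM = my_plugin.arch:plugin",
        ],
    },
)

Once the package is installed, run_server() picks the plugin up at startup, and the new flag routes a checkpoint to it:

    python -m vllm.entrypoints.openai.api_server \
        --model <model> \
        --model-architecture-override MyCustomForCausalLM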