From 174d0202a340bc7742f7aa568568a3cf0c0f15d2 Mon Sep 17 00:00:00 2001
From: NadavShmayo
Date: Mon, 12 Aug 2024 22:04:08 +0300
Subject: [PATCH 1/4] Add ModelPlugin implementation

---
 vllm/plugins/__init__.py     |  0
 vllm/plugins/model_plugin.py | 67 ++++++++++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+)
 create mode 100644 vllm/plugins/__init__.py
 create mode 100644 vllm/plugins/model_plugin.py

diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/plugins/model_plugin.py b/vllm/plugins/model_plugin.py
new file mode 100644
index 000000000000..1de7696f6d25
--- /dev/null
+++ b/vllm/plugins/model_plugin.py
@@ -0,0 +1,67 @@
+from typing import Optional, Iterable, Tuple, List, Type
+
+from transformers import PretrainedConfig
+from importlib.metadata import entry_points
+
+from vllm import ModelRegistry
+from dataclasses import dataclass
+import torch
+
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from abc import ABC, abstractmethod
+
+from vllm.sequence import IntermediateTensors
+
+
+logger = init_logger(__name__)
+
+
+class ModelArchitectureBase(torch.nn.Module, ABC):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        self.config = config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+
+    @abstractmethod
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        pass
+
+    @abstractmethod
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> torch.Tensor:
+        pass
+
+
+@dataclass
+class ModelPlugin:
+    architecture_name: str
+    implementation_cls: Type[ModelArchitectureBase]
+
+
+def load_model_plugins():
+    for entry_point in entry_points().select(group="vllm.model_architectures"):
+        logger.debug(f"Loading model architecture plugin {entry_point.name}")
+        model_architecture_plugin = entry_point.load()
+        if not isinstance(model_architecture_plugin, ModelPlugin):
+            raise ValueError(
+                f"Model architecture plugin must be an instance of ModelPlugin, got {model_architecture_plugin}"
+            )
+
+        ModelRegistry.register_model(
+            model_architecture_plugin.architecture_name,
+            model_architecture_plugin.implementation_cls,
+        )
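A minimal sketch of how an out-of-tree package could exercise the entry-point contract above. The package name my_vllm_models, the module layout, and the MyLlamaForCausalLM class are all hypothetical (a full sketch of the class appears after patch 4/4); the one hard requirement imposed by load_model_plugins() is that each entry point in the vllm.model_architectures group resolves to a ModelPlugin instance, not a class:

    # setup.py for the hypothetical plugin package
    from setuptools import find_packages, setup

    setup(
        name="my-vllm-models",
        version="0.1.0",
        packages=find_packages(),
        entry_points={
            "vllm.model_architectures": [
                # entry_point.load() must return a ModelPlugin instance.
                "my_llama = my_vllm_models.plugin:plugin",
            ],
        },
    )

    # my_vllm_models/plugin.py
    from vllm.plugins.model_plugin import ModelPlugin

    from my_vllm_models.modeling import MyLlamaForCausalLM

    plugin = ModelPlugin(
        architecture_name="MyLlamaForCausalLM",
        implementation_cls=MyLlamaForCausalLM,
    )

Once such a package is installed, load_model_plugins() discovers the entry point and registers MyLlamaForCausalLM with vLLM's ModelRegistry at startup.
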
version %s", VLLM_VERSION) logger.info("args: %s", args) + load_model_plugins() async with build_async_engine_client(args) as async_engine_client: app = await init_app(async_engine_client, args) From 590fd90cc09f51723382c11a383e940e0fb46504 Mon Sep 17 00:00:00 2001 From: NadavShmayo Date: Mon, 12 Aug 2024 22:08:25 +0300 Subject: [PATCH 3/4] Add model architecture override cli option --- vllm/engine/arg_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4a3fae73824..8280fe8fb212 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -121,6 +121,8 @@ class EngineArgs: otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None + model_architecture_override: Optional[str] = None + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @@ -674,6 +676,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "modules. This involves use of possibly costly and or blocking " "operations and hence might have a performance impact.") + parser.add_argument( + '--model-architecture-override', + type=str, + default=None, + help= + 'Override the model architecture to use - useful for using plugins.' + ) + return parser @classmethod @@ -735,6 +745,11 @@ def create_engine_config(self, ) -> EngineConfig: skip_tokenizer_init=self.skip_tokenizer_init, served_model_name=self.served_model_name, multimodal_config=multimodal_config) + if self.model_architecture_override: + model_config.hf_config.architectures = [ + self.model_architecture_override + ] + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, From edf79896c89263ff6cb0ff6aa69b8b8a693ff487 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Mon, 12 Aug 2024 22:33:30 +0300 Subject: [PATCH 4/4] Fix formatting --- vllm/plugins/model_plugin.py | 40 +++++++++++++++++------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/vllm/plugins/model_plugin.py b/vllm/plugins/model_plugin.py index 1de7696f6d25..d669dfffef82 100644 --- a/vllm/plugins/model_plugin.py +++ b/vllm/plugins/model_plugin.py @@ -1,30 +1,28 @@ -from typing import Optional, Iterable, Tuple, List, Type - -from transformers import PretrainedConfig +from abc import ABC, abstractmethod +from dataclasses import dataclass from importlib.metadata import entry_points +from typing import Iterable, List, Optional, Tuple, Type -from vllm import ModelRegistry -from dataclasses import dataclass import torch +from transformers import PretrainedConfig +from vllm import ModelRegistry from vllm.attention import AttentionMetadata from vllm.config import CacheConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig -from abc import ABC, abstractmethod - from vllm.sequence import IntermediateTensors - logger = init_logger(__name__) class ModelArchitectureBase(torch.nn.Module, ABC): + def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ): self.config = config self.cache_config = cache_config @@ -36,12 +34,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @abstractmethod def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - 
From edf79896c89263ff6cb0ff6aa69b8b8a693ff487 Mon Sep 17 00:00:00 2001
From: Nadav Shmayovits
Date: Mon, 12 Aug 2024 22:33:30 +0300
Subject: [PATCH 4/4] Fix formatting

---
 vllm/plugins/model_plugin.py | 40 +++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 21 deletions(-)

diff --git a/vllm/plugins/model_plugin.py b/vllm/plugins/model_plugin.py
index 1de7696f6d25..d669dfffef82 100644
--- a/vllm/plugins/model_plugin.py
+++ b/vllm/plugins/model_plugin.py
@@ -1,30 +1,28 @@
-from typing import Optional, Iterable, Tuple, List, Type
-
-from transformers import PretrainedConfig
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from importlib.metadata import entry_points
+from typing import Iterable, List, Optional, Tuple, Type
 
-from vllm import ModelRegistry
-from dataclasses import dataclass
 import torch
+from transformers import PretrainedConfig
 
+from vllm import ModelRegistry
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from abc import ABC, abstractmethod
-
 from vllm.sequence import IntermediateTensors
 
-
 logger = init_logger(__name__)
 
 
 class ModelArchitectureBase(torch.nn.Module, ABC):
+
     def __init__(
-        self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+            self,
+            config: PretrainedConfig,
+            cache_config: Optional[CacheConfig] = None,
+            quant_config: Optional[QuantizationConfig] = None,
     ):
         self.config = config
         self.cache_config = cache_config
@@ -36,12 +34,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
     @abstractmethod
     def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
+            self,
+            input_ids: torch.Tensor,
+            positions: torch.Tensor,
+            kv_caches: List[torch.Tensor],
+            attn_metadata: AttentionMetadata,
+            intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         pass
 
@@ -54,12 +52,12 @@ class ModelPlugin:
 
 def load_model_plugins():
     for entry_point in entry_points().select(group="vllm.model_architectures"):
-        logger.debug(f"Loading model architecture plugin {entry_point.name}")
+        logger.debug("Loading model architecture plugin %s", entry_point.name)
         model_architecture_plugin = entry_point.load()
         if not isinstance(model_architecture_plugin, ModelPlugin):
             raise ValueError(
-                f"Model architecture plugin must be an instance of ModelPlugin, got {model_architecture_plugin}"
-            )
+                f"Model architecture plugin must be an instance of "
+                f"ModelPlugin, got {model_architecture_plugin}")
 
         ModelRegistry.register_model(
             model_architecture_plugin.architecture_name,
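
To close the loop, a skeletal sketch of the hypothetical MyLlamaForCausalLM against the ModelArchitectureBase contract. This is a stub, not a working model: the decoder stack is reduced to a single token embedding, and a real implementation would run its layers against kv_caches and attn_metadata. One caveat worth noting: ModelArchitectureBase.__init__ as written never chains up to torch.nn.Module.__init__, so this sketch invokes it explicitly before assigning any submodules.

    # my_vllm_models/modeling.py (hypothetical)
    from typing import Iterable, List, Optional, Tuple

    import torch
    from transformers import PretrainedConfig

    from vllm.attention import AttentionMetadata
    from vllm.config import CacheConfig
    from vllm.model_executor.layers.quantization import QuantizationConfig
    from vllm.plugins.model_plugin import ModelArchitectureBase
    from vllm.sequence import IntermediateTensors


    class MyLlamaForCausalLM(ModelArchitectureBase):

        def __init__(
                self,
                config: PretrainedConfig,
                cache_config: Optional[CacheConfig] = None,
                quant_config: Optional[QuantizationConfig] = None,
        ):
            # The base class skips torch.nn.Module.__init__; without this
            # call, assigning self.embed_tokens below would raise
            # "cannot assign module before Module.__init__() call".
            torch.nn.Module.__init__(self)
            super().__init__(config, cache_config, quant_config)
            self.embed_tokens = torch.nn.Embedding(config.vocab_size,
                                                   config.hidden_size)

        def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
            # Copy checkpoint tensors onto parameters with matching names.
            params = dict(self.named_parameters())
            for name, loaded_weight in weights:
                if name in params:
                    params[name].data.copy_(loaded_weight)

        def forward(
                self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
        ) -> torch.Tensor:
            # A real model would run its decoder stack here; the stub
            # only returns the token embeddings as "hidden states".
            return self.embed_tokens(input_ids)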