15 changes: 15 additions & 0 deletions vllm/engine/arg_utils.py
@@ -121,6 +121,8 @@ class EngineArgs:
    otlp_traces_endpoint: Optional[str] = None
    collect_detailed_traces: Optional[str] = None

    model_architecture_override: Optional[str] = None

    def __post_init__(self):
        if self.tokenizer is None:
            self.tokenizer = self.model
@@ -674,6 +676,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            "modules. This involves use of possibly costly and/or blocking "
            "operations and hence might have a performance impact.")

        parser.add_argument(
            '--model-architecture-override',
            type=str,
            default=None,
            help='Override the model architecture to use. This is useful '
            'for loading model implementations provided by plugins.')

        return parser

    @classmethod
@@ -735,6 +745,11 @@ def create_engine_config(self, ) -> EngineConfig:
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            multimodal_config=multimodal_config)
        if self.model_architecture_override:
            model_config.hf_config.architectures = [
                self.model_architecture_override
            ]

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
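A usage sketch (not part of the diff): since the override is a plain EngineArgs field, it should be reachable from both the server CLI and the offline LLM entry point. The model path and architecture name below are hypothetical placeholders.

# Hypothetical usage of the new flag; "MyModelForCausalLM" stands in for
# whatever architecture name a plugin registers. Server equivalent:
#
#   python -m vllm.entrypoints.openai.api_server \
#       --model my-org/my-model \
#       --model-architecture-override MyModelForCausalLM
from vllm import LLM
from vllm.plugins.model_plugin import load_model_plugins

# Only the OpenAI server calls this automatically in this PR, so offline
# users would load plugins by hand before constructing the engine.
load_model_plugins()

# LLM forwards extra keyword arguments to EngineArgs, so the new field
# is usable here as well.
llm = LLM(model="my-org/my-model",
          model_architecture_override="MyModelForCausalLM")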
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -42,6 +42,7 @@
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.plugins.model_plugin import load_model_plugins
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION
@@ -352,6 +353,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

load_model_plugins()
async with build_async_engine_client(args) as async_engine_client:
app = await init_app(async_engine_client, args)

Empty file added vllm/plugins/__init__.py
65 changes: 65 additions & 0 deletions vllm/plugins/model_plugin.py
@@ -0,0 +1,65 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from importlib.metadata import entry_points
from typing import Iterable, List, Optional, Tuple, Type

import torch
from transformers import PretrainedConfig

from vllm import ModelRegistry
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)


class ModelArchitectureBase(torch.nn.Module, ABC):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        # torch.nn.Module must be initialized before subclasses can
        # assign parameters or submodules.
        super().__init__()
        self.config = config
        self.cache_config = cache_config
        self.quant_config = quant_config

    @abstractmethod
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        pass

    @abstractmethod
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        pass


@dataclass
class ModelPlugin:
    architecture_name: str
    implementation_cls: Type[ModelArchitectureBase]


def load_model_plugins():
    # Note: entry_points().select() requires Python 3.10+ (or the
    # importlib_metadata backport).
    for entry_point in entry_points().select(
            group="vllm.model_architectures"):
        logger.debug("Loading model architecture plugin %s", entry_point.name)
        model_architecture_plugin = entry_point.load()
        if not isinstance(model_architecture_plugin, ModelPlugin):
            raise ValueError(
                f"Model architecture plugin must be an instance of "
                f"ModelPlugin, got {model_architecture_plugin}")

        ModelRegistry.register_model(
            model_architecture_plugin.architecture_name,
            model_architecture_plugin.implementation_cls,
        )
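To show how the pieces are meant to fit together, here is a minimal plugin sketch. The package name my_vllm_plugin, the entry-point name my_model, and the class MyModelForCausalLM are all hypothetical; a real implementation would need actual layers and weight mapping.

# my_vllm_plugin/plugin.py: a hypothetical third-party package. It would
# advertise the plugin through an entry point in its pyproject.toml:
#
#   [project.entry-points."vllm.model_architectures"]
#   my_model = "my_vllm_plugin.plugin:plugin"
#
from typing import Iterable, List, Optional, Tuple

import torch

from vllm.attention import AttentionMetadata
from vllm.plugins.model_plugin import ModelArchitectureBase, ModelPlugin
from vllm.sequence import IntermediateTensors


class MyModelForCausalLM(ModelArchitectureBase):

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # A real plugin would map checkpoint tensor names onto its own
        # parameters here.
        for name, tensor in weights:
            pass

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        # A real plugin would compute and return hidden states here.
        raise NotImplementedError


# entry_point.load() must resolve to a ModelPlugin instance, which
# load_model_plugins() then passes to ModelRegistry.register_model().
plugin = ModelPlugin(
    architecture_name="MyModelForCausalLM",
    implementation_cls=MyModelForCausalLM,
)

With the package installed, starting the server with --model-architecture-override MyModelForCausalLM would route model construction to the plugin class, since the override rewrites hf_config.architectures before the registry lookup.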