Commit c725e97

youkaichao authored and LeiWang1999 committed
[misc][plugin] add plugin system implementation (vllm-project#7426)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent 5da5fe9 commit c725e97

13 files changed: +161 -101 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 4 additions & 0 deletions

@@ -77,11 +77,13 @@ steps:
   - pytest -v -s core

 - label: Entrypoints Test # 20min
+  working_dir: "/vllm-workspace/tests"
   fast_check: true
   mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   commands:
+  - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s entrypoints/llm
   - pytest -v -s entrypoints/openai

@@ -154,6 +156,7 @@ steps:
   - vllm/
   - tests/models
   commands:
+  - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s models -m \"not vlm\"

 - label: Vision Language Models Test # 42min

@@ -289,6 +292,7 @@ steps:
   - pytest -v -s distributed/test_chunked_prefill_distributed.py
   - pytest -v -s distributed/test_multimodal_broadcast.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
+  - pytest -v -s distributed/test_distributed_oot.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

requirements-common.txt

Lines changed: 1 addition & 0 deletions

@@ -23,4 +23,5 @@ pyzmq
 librosa  # Required for audio processing
 soundfile  # Required for audio processing
 gguf == 0.9.1
+importlib_metadata
 compressed-tensors == 0.5.0

tests/conftest.py

Lines changed: 26 additions & 0 deletions

@@ -1,7 +1,9 @@
 import contextlib
 import gc
+import json
 import os
 import sys
+import tempfile
 from collections import UserList
 from enum import Enum
 from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,

@@ -11,6 +13,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from huggingface_hub import snapshot_download
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                           AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,

@@ -757,3 +760,26 @@ def num_gpus_available():
     in current process."""

     return cuda_device_count_stateless()
+
+
+temp_dir = tempfile.gettempdir()
+_dummy_path = os.path.join(temp_dir, "dummy_opt")
+
+
+@pytest.fixture
+def dummy_opt_path():
+    json_path = os.path.join(_dummy_path, "config.json")
+    if not os.path.exists(_dummy_path):
+        snapshot_download(repo_id="facebook/opt-125m",
+                          local_dir=_dummy_path,
+                          ignore_patterns=[
+                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
+                              "*.msgpack"
+                          ])
+    assert os.path.exists(json_path)
+    with open(json_path, "r") as f:
+        config = json.load(f)
+    config["architectures"] = ["MyOPTForCausalLM"]
+    with open(json_path, "w") as f:
+        json.dump(config, f)
+    return _dummy_path

tests/distributed/test_distributed_oot.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+from ..entrypoints.openai.test_oot_registration import (
+    run_and_test_dummy_opt_api_server)
+
+
+def test_distributed_oot(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path, tp=2)

tests/entrypoints/openai/test_oot_registration.py

Lines changed: 27 additions & 79 deletions

@@ -1,94 +1,42 @@
-import sys
-import time
-
-import torch
-from openai import OpenAI, OpenAIError
-
-from vllm import ModelRegistry
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.utils import get_open_port
-
 from ...utils import VLLM_PATH, RemoteOpenAIServer

 chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()


-class MyOPTForCausalLM(OPTForCausalLM):
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
-        logits.zero_()
-        logits[:, 0] += 1.0
-        return logits
-
-
-def server_function(port: int):
-    # register our dummy model
-    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-
-    sys.argv = ["placeholder.py"] + [
-        "--model",
-        "facebook/opt-125m",
+def run_and_test_dummy_opt_api_server(model, tp=1):
+    # the model is registered through the plugin
+    server_args = [
         "--gpu-memory-utilization",
         "0.10",
         "--dtype",
         "float32",
-        "--api-key",
-        "token-abc123",
-        "--port",
-        str(port),
         "--chat-template",
         str(chatml_jinja_path),
+        "--load-format",
+        "dummy",
+        "-tp",
+        f"{tp}",
     ]
-
-    import runpy
-    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
-
-
-def test_oot_registration_for_api_server():
-    port = get_open_port()
-    ctx = torch.multiprocessing.get_context()
-    server = ctx.Process(target=server_function, args=(port, ))
-    server.start()
-
-    try:
-        client = OpenAI(
-            base_url=f"http://localhost:{port}/v1",
-            api_key="token-abc123",
+    with RemoteOpenAIServer(model, server_args) as server:
+        client = server.get_client()
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant."
+            }, {
+                "role": "user",
+                "content": "Hello!"
+            }],
+            temperature=0,
         )
-        now = time.time()
-        while True:
-            try:
-                completion = client.chat.completions.create(
-                    model="facebook/opt-125m",
-                    messages=[{
-                        "role": "system",
-                        "content": "You are a helpful assistant."
-                    }, {
-                        "role": "user",
-                        "content": "Hello!"
-                    }],
-                    temperature=0,
-                )
-                break
-            except OpenAIError as e:
-                if "Connection error" in str(e):
-                    time.sleep(3)
-                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
-                        msg = "Server did not start in time"
-                        raise RuntimeError(msg) from e
-                else:
-                    raise e
-    finally:
-        server.terminate()
+        generated_text = completion.choices[0].message.content
+        assert generated_text is not None
+        # make sure only the first token is generated
+        rest = generated_text.replace("<s>", "")
+        assert rest == ""
+

-    generated_text = completion.choices[0].message.content
-    assert generated_text is not None
-    # make sure only the first token is generated
-    # TODO(youkaichao): Fix the test with plugin
-    rest = generated_text.replace("<s>", "")  # noqa
-    # assert rest == ""
+def test_oot_registration_for_api_server(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path)

tests/models/test_oot_registration.py

Lines changed: 15 additions & 20 deletions

@@ -1,32 +1,27 @@
-from typing import Optional
+import os

-import torch
+import pytest

-from vllm import LLM, ModelRegistry, SamplingParams
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm import LLM, SamplingParams

+# NOTE: the order of the tests is important
+# the first test does not load any plugins
+# the second test loads the plugin
+# they share the same process, so the plugin is loaded for the second test

-class MyOPTForCausalLM(OPTForCausalLM):

-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
-        logits.zero_()
-        logits[:, 0] += 1.0
-        return logits
+def test_plugin(dummy_opt_path):
+    os.environ["VLLM_PLUGINS"] = ""
+    with pytest.raises(Exception) as excinfo:
+        LLM(model=dummy_opt_path, load_format="dummy")
+    assert "are not supported for now" in str(excinfo.value)


-def test_oot_registration():
-    # register our dummy model
-    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
+def test_oot_registration(dummy_opt_path):
+    os.environ["VLLM_PLUGINS"] = "register_dummy_model"
     prompts = ["Hello, my name is", "The text does not matter"]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="facebook/opt-125m")
+    llm = LLM(model=dummy_opt_path, load_format="dummy")
     first_token = llm.get_tokenizer().decode(0)
     outputs = llm.generate(prompts, sampling_params)


tests/plugins/vllm_add_dummy_model/setup.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+from setuptools import setup
+
+setup(name='vllm_add_dummy_model',
+      version='0.1',
+      packages=['vllm_add_dummy_model'],
+      entry_points={
+          'vllm.general_plugins':
+          ["register_dummy_model = vllm_add_dummy_model:register"]
+      })
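
The 'vllm.general_plugins' entry-point group declared here is the group vLLM's plugin loader is expected to scan when the engine starts (see the vllm/engine/llm_engine.py change below). After the package is installed (the pipeline runs `pip install -e ./plugins/vllm_add_dummy_model`), the entry point should be discoverable through importlib_metadata, which this commit adds to requirements-common.txt. A quick, illustrative sanity check, not part of the diff:

from importlib_metadata import entry_points

# Discover everything registered under the group used by setup.py above.
plugins = entry_points(group="vllm.general_plugins")
print([ep.name for ep in plugins])  # expected to include "register_dummy_model"

# Loading an entry point returns the referenced object, here the register()
# function exported by the vllm_add_dummy_model package.
for ep in plugins:
    if ep.name == "register_dummy_model":
        register = ep.load()
        register()  # adds MyOPTForCausalLM to vLLM's ModelRegistry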

tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+from typing import Optional
+
+import torch
+
+from vllm import ModelRegistry
+from vllm.model_executor.models.opt import OPTForCausalLM
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+
+class MyOPTForCausalLM(OPTForCausalLM):
+
+    def compute_logits(
+            self, hidden_states: torch.Tensor,
+            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
+        # this dummy model always predicts the first token
+        logits = super().compute_logits(hidden_states, sampling_metadata)
+        if logits is not None:
+            logits.zero_()
+            logits[:, 0] += 1.0
+        return logits
+
+
+def register():
+    # register our dummy model
+    if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
+        ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
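
Taken together with the conftest.py fixture above, the plugin is exercised by installing the package, opting in through VLLM_PLUGINS, and pointing vLLM at a model directory whose config.json lists "MyOPTForCausalLM" as its architecture. A minimal sketch mirroring tests/models/test_oot_registration.py; the /tmp/dummy_opt path is illustrative:

import os

from vllm import LLM, SamplingParams

# Opt in to exactly one plugin; must be set before the engine is created.
os.environ["VLLM_PLUGINS"] = "register_dummy_model"

# load_format="dummy" skips loading real weights; the plugin's compute_logits
# override makes the model always predict token id 0.
llm = LLM(model="/tmp/dummy_opt", load_format="dummy")
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)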

vllm/engine/llm_engine.py

Lines changed: 3 additions & 0 deletions

@@ -227,6 +227,9 @@ def __init__(
         )
         # TODO(woosuk): Print more configs in debug mode.

+        from vllm.plugins import load_general_plugins
+        load_general_plugins()
+
         self.model_config = model_config
         self.cache_config = cache_config
         self.lora_config = lora_config
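
The engine now calls load_general_plugins() from vllm.plugins, a module introduced by this commit but not shown in the excerpt above. Based on the 'vllm.general_plugins' entry-point group used by the dummy plugin's setup.py and the VLLM_PLUGINS variable added to vllm/envs.py below, a plausible sketch of such a loader (an illustration, not the committed source) is:

import logging

from importlib_metadata import entry_points

import vllm.envs as envs

logger = logging.getLogger(__name__)


def load_general_plugins():
    # None means "no restriction": load every discovered plugin.
    allowed = envs.VLLM_PLUGINS
    for plugin in entry_points(group="vllm.general_plugins"):
        if allowed is not None and plugin.name not in allowed:
            continue
        try:
            func = plugin.load()  # e.g. vllm_add_dummy_model:register
            func()
        except Exception:
            logger.exception("Failed to load plugin %s", plugin.name)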

vllm/envs.py

Lines changed: 9 additions & 1 deletion

@@ -1,6 +1,6 @@
 import os
 import tempfile
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 if TYPE_CHECKING:
     VLLM_HOST_IP: str = ""

@@ -55,6 +55,7 @@
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
+    VLLM_PLUGINS: Optional[List[str]] = None


 def get_default_cache_root():

@@ -362,6 +363,13 @@ def get_default_config_root():
     lambda:
     (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
      ("1", "true")),
+
+    # a list of plugin names to load, separated by commas.
+    # if this is not set, it means all plugins will be loaded
+    # if this is set to an empty string, no plugins will be loaded
+    "VLLM_PLUGINS":
+    lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
+        "VLLM_PLUGINS"].split(","),
 }

 # end-env-vars-definition
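
The parsing rule above has a useful edge case: leaving VLLM_PLUGINS unset yields None (every discovered plugin is loaded), while setting it to an empty string yields [""], which matches no entry-point name and therefore disables all plugins; tests/models/test_oot_registration.py relies on exactly this. A small illustration (the parse helper merely mirrors the lambda; it is not part of the diff):

def parse(value):
    # mirrors the lambda registered for "VLLM_PLUGINS" above
    return None if value is None else value.split(",")


assert parse(None) is None                        # unset: load all plugins
assert parse("") == [""]                          # empty string: no name matches
assert parse("register_dummy_model") == ["register_dummy_model"]
assert parse("a,b") == ["a", "b"]                 # comma-separated allow-list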
