@@ -1,94 +1,42 @@
-import sys
-import time
-
-import torch
-from openai import OpenAI, OpenAIError
-
-from vllm import ModelRegistry
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.utils import get_open_port
-
 from ...utils import VLLM_PATH, RemoteOpenAIServer
 
 chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()
 
 
-class MyOPTForCausalLM(OPTForCausalLM):
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
-        logits.zero_()
-        logits[:, 0] += 1.0
-        return logits
-
-
-def server_function(port: int):
-    # register our dummy model
-    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-
-    sys.argv = ["placeholder.py"] + [
-        "--model",
-        "facebook/opt-125m",
+def run_and_test_dummy_opt_api_server(model, tp=1):
+    # the model is registered through the plugin
+    server_args = [
         "--gpu-memory-utilization",
         "0.10",
         "--dtype",
         "float32",
-        "--api-key",
-        "token-abc123",
-        "--port",
-        str(port),
         "--chat-template",
        str(chatml_jinja_path),
+        "--load-format",
+        "dummy",
+        "-tp",
+        f"{tp}",
     ]
-
-    import runpy
-    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
-
-
-def test_oot_registration_for_api_server():
-    port = get_open_port()
-    ctx = torch.multiprocessing.get_context()
-    server = ctx.Process(target=server_function, args=(port, ))
-    server.start()
-
-    try:
-        client = OpenAI(
-            base_url=f"http://localhost:{port}/v1",
-            api_key="token-abc123",
+    with RemoteOpenAIServer(model, server_args) as server:
+        client = server.get_client()
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant."
+            }, {
+                "role": "user",
+                "content": "Hello!"
+            }],
+            temperature=0,
         )
-        now = time.time()
-        while True:
-            try:
-                completion = client.chat.completions.create(
-                    model="facebook/opt-125m",
-                    messages=[{
-                        "role": "system",
-                        "content": "You are a helpful assistant."
-                    }, {
-                        "role": "user",
-                        "content": "Hello!"
-                    }],
-                    temperature=0,
-                )
-                break
-            except OpenAIError as e:
-                if "Connection error" in str(e):
-                    time.sleep(3)
-                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
-                        msg = "Server did not start in time"
-                        raise RuntimeError(msg) from e
-                else:
-                    raise e
-    finally:
-        server.terminate()
+        generated_text = completion.choices[0].message.content
+        assert generated_text is not None
+        # make sure only the first token is generated
+        rest = generated_text.replace("<s>", "")
+        assert rest == ""
+
 
-    generated_text = completion.choices[0].message.content
-    assert generated_text is not None
-    # make sure only the first token is generated
-    # TODO(youkaichao): Fix the test with plugin
-    rest = generated_text.replace("<s>", "")  # noqa
-    # assert rest == ""
+def test_oot_registration_for_api_server(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path)
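After this change the dummy model is no longer registered in-process: the test relies on a companion plugin package being installed (per the "registered through the plugin" comment) and on a dummy_opt_path fixture, presumably provided by the suite's conftest, pointing at a dummy OPT checkpoint. For context, below is a minimal sketch of what such a plugin could look like. It assumes vLLM's "vllm.general_plugins" setuptools entry-point group, a hypothetical package name vllm_add_dummy_model, and a hypothetical architecture name MyOPTForCausalLM (which the dummy checkpoint's config.json would have to list under "architectures"); the class mirrors the MyOPTForCausalLM removed above, and only ModelRegistry.register_model is confirmed by the diff itself.

    # setup.py of the hypothetical plugin package
    # (assumption: vLLM discovers plugins via the
    # "vllm.general_plugins" entry-point group)
    from setuptools import setup

    setup(
        name="vllm_add_dummy_model",
        version="0.1",
        packages=["vllm_add_dummy_model"],
        entry_points={
            "vllm.general_plugins": [
                "register_dummy_model = vllm_add_dummy_model:register"
            ]
        },
    )

    # vllm_add_dummy_model/__init__.py
    import torch

    from vllm import ModelRegistry
    from vllm.model_executor.models.opt import OPTForCausalLM


    class MyOPTForCausalLM(OPTForCausalLM):

        def compute_logits(self, hidden_states, sampling_metadata) -> torch.Tensor:
            # same dummy behavior as the removed class: always predict token 0
            logits = super().compute_logits(hidden_states, sampling_metadata)
            logits.zero_()
            logits[:, 0] += 1.0
            return logits


    def register():
        # called in each process when vLLM loads its plugins
        ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)

With such a plugin pip-installed, registration happens inside the spawned server process itself, so RemoteOpenAIServer only needs the model path and server args; this is exactly what the removed server_function/runpy machinery used to emulate by hand.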