-
Notifications
You must be signed in to change notification settings - Fork 454
RPC support for entrypoints (Omni/AsyncOmni) #355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
dbcc96c
188d681
aec5487
e9be588
2d17bff
c0b7c43
f961a44
111ca23
66850a7
a4877df
31010c8
786cdc9
2aa65e0
11058c4
475ed7f
a6d1e9e
633ae1c
15efb56
dc49a9f
4d79693
cee0055
d1d6490
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| import os | ||
| import sys | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
|
|
||
| # ruff: noqa: E402 | ||
| REPO_ROOT = Path(__file__).resolve().parents[2] | ||
| if str(REPO_ROOT) not in sys.path: | ||
| sys.path.insert(0, str(REPO_ROOT)) | ||
|
|
||
| from vllm_omni import Omni | ||
|
|
||
| os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" | ||
|
|
||
| diffusion_models = ["Tongyi-MAI/Z-Image-Turbo"] | ||
|
|
||
| omni_models = ["Qwen/Qwen2.5-Omni-3B"] | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("model_name", omni_models) | ||
| def test_omni_model(model_name: str): | ||
| m = Omni(model=model_name, init_timeout=3600) | ||
| results = m.collective_rpc( | ||
| method="sleep", | ||
| args=(1,), | ||
| ) | ||
| assert len(results) == 3 | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("model_name", diffusion_models) | ||
| def test_diffusion_model(model_name: str): | ||
| m = Omni(model=model_name) | ||
| results = m.collective_rpc( | ||
| method="sleep", | ||
| args=(1,), | ||
| ) | ||
| assert len(results) == 1 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,8 +14,10 @@ | |
| import queue | ||
| import sys | ||
| import traceback | ||
| import uuid | ||
| from collections.abc import Callable | ||
| from dataclasses import fields | ||
| from typing import Any | ||
| from typing import Any, TypeVar | ||
|
|
||
| from vllm.inputs import TextPrompt | ||
| from vllm.inputs.preprocess import InputPreprocessor | ||
|
|
@@ -49,6 +51,8 @@ | |
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| _R = TypeVar("_R") | ||
|
|
||
|
|
||
| def _build_od_config(engine_args: dict[str, Any], model: str) -> dict[str, Any]: | ||
| """Build OmniDiffusionConfig kwargs from engine args.""" | ||
|
|
@@ -420,6 +424,67 @@ def process_engine_inputs( | |
| stage_list, engine_input_source, prompt, self.requires_multimodal_data | ||
| ) | ||
|
|
||
| def collective_rpc( | ||
| self, | ||
| method: str | Callable[..., _R], | ||
| timeout: float | None = None, | ||
| args: tuple = (), | ||
| kwargs: dict[str, Any] | None = None, | ||
| ) -> list[_R]: | ||
| """Execute an RPC call on all workers via the stage engine. | ||
|
|
||
| Args: | ||
| method: Name of the worker method to execute, or a callable that | ||
| is serialized and sent to all workers to execute. | ||
|
|
||
| If the method is a callable, it should accept an additional | ||
| `self` argument, in addition to the arguments passed in `args` | ||
| and `kwargs`. The `self` argument will be the worker object. | ||
| timeout: Maximum time in seconds to wait for execution. Raises a | ||
| [`TimeoutError`][] on timeout. `None` means wait indefinitely. | ||
| args: Positional arguments to pass to the worker method. | ||
| kwargs: Keyword arguments to pass to the worker method. | ||
|
|
||
| Returns: | ||
| A list containing the results from each worker. | ||
|
|
||
| Note: | ||
| It is recommended to use this API to only pass control messages, | ||
| and set up data-plane communication to pass data. | ||
| """ | ||
| assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc" | ||
|
|
||
| # Submit collective_rpc task to worker | ||
| rpc_id = str(uuid.uuid4()) | ||
| self._in_q.put( | ||
| { | ||
| "type": OmniStageTaskType.COLLECTIVE_RPC, | ||
| "rpc_id": rpc_id, | ||
| "method": method, | ||
| "timeout": timeout, | ||
| "args": args, | ||
| "kwargs": kwargs, | ||
| } | ||
| ) | ||
|
|
||
| # Wait for result from worker | ||
| import time | ||
knlnguyen1802 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| start_time = time.time() | ||
| while True: | ||
| if timeout is not None and (time.time() - start_time) > timeout: | ||
| raise TimeoutError(f"collective_rpc timed out after {timeout} seconds") | ||
|
|
||
| result = self.try_collect() | ||
| if result is not None: | ||
| if result.get("type") == "collective_rpc_result": | ||
| if result.get("rpc_id") == rpc_id: | ||
| if "error" in result: | ||
|
Comment on lines
+521
to
+525
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
While waiting for a matching `collective_rpc_result`, any non-matching messages returned by `try_collect()` are silently discarded, which may drop results intended for other consumers. Useful? React with 👍 / 👎. |
||
| raise RuntimeError(f"collective_rpc failed: {result['error']}") | ||
| return result["result"] | ||
|
|
||
| time.sleep(0.001) # Small sleep to avoid busy waiting | ||
|
|
||
|
|
||
| def _stage_worker( | ||
| model: str, | ||
|
|
@@ -636,6 +701,31 @@ def _stage_worker( | |
| logger.error("Received shutdown signal") | ||
| break | ||
|
|
||
| if task_type == OmniStageTaskType.COLLECTIVE_RPC: | ||
| rpc_id = task.get("rpc_id") | ||
| method = task.get("method") | ||
| timeout = task.get("timeout") | ||
| args = task.get("args") | ||
| kwargs = task.get("kwargs") | ||
| try: | ||
| result = stage_engine.collective_rpc(method, timeout, args, kwargs) | ||
| out_q.put( | ||
| { | ||
| "type": "collective_rpc_result", | ||
| "rpc_id": rpc_id, | ||
| "result": result, | ||
| } | ||
| ) | ||
| except Exception as e: | ||
| out_q.put( | ||
| { | ||
| "type": "collective_rpc_result", | ||
| "rpc_id": rpc_id, | ||
| "error": str(e), | ||
| } | ||
| ) | ||
| continue | ||
|
|
||
| batch_tasks: list[dict[str, Any]] = [task] | ||
| start_time = _time.time() | ||
| if max_batch_size > 1: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.