Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions tests/e2e/test_rpc_collective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import sys
from pathlib import Path

import pytest

# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))

from vllm_omni import Omni

os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"

diffusion_models = ["Tongyi-MAI/Z-Image-Turbo"]

omni_models = ["Qwen/Qwen2.5-Omni-3B"]


@pytest.mark.parametrize("model_name", omni_models)
def test_omni_model(model_name: str):
m = Omni(model=model_name, init_timeout=3600)
results = m.collective_rpc(
method="sleep",
args=(1,),
)
assert len(results) == 3


@pytest.mark.parametrize("model_name", diffusion_models)
def test_diffusion_model(model_name: str):
m = Omni(model=model_name)
results = m.collective_rpc(
method="sleep",
args=(1,),
)
assert len(results) == 1
3 changes: 1 addition & 2 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@ def collective_rpc(
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
unique_reply_rank: int | None = None,
) -> Any:
"""Call a method on worker processes and get results immediately.

Expand All @@ -287,7 +286,6 @@ def collective_rpc(
timeout: Optional timeout in seconds
args: Positional arguments for the method
kwargs: Keyword arguments for the method
unique_reply_rank: If set, only get reply from this rank

Returns:
Single result if unique_reply_rank is provided, otherwise list of results
Expand All @@ -297,6 +295,7 @@ def collective_rpc(

deadline = None if timeout is None else time.monotonic() + timeout
kwargs = kwargs or {}
unique_reply_rank = kwargs.pop("unique_reply_rank", None)

assert isinstance(method, str)
send_method = method
Expand Down
22 changes: 21 additions & 1 deletion vllm_omni/entrypoints/omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict
from pprint import pformat
from typing import Any
from typing import Any, TypeVar

from omegaconf import OmegaConf
from tqdm.auto import tqdm
Expand Down Expand Up @@ -42,6 +42,8 @@

logger = init_logger(__name__)

_R = TypeVar("_R")


def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg):
"""Weak reference cleanup function for OmniBase instances."""
Expand Down Expand Up @@ -688,6 +690,24 @@ def _run_generation(
except Exception as e:
logger.exception(f"[{self._name}] Failed to build/log summary: {e}")

def collective_rpc(
    self,
    method: str | Callable[..., _R],
    timeout: float | None = None,
    args: tuple = (),
    kwargs: dict[str, Any] | None = None,
) -> list[_R]:
    """Broadcast an RPC to every stage and collect one result per stage.

    Stages are called sequentially, so ``timeout`` bounds each stage's
    call individually, not the total wall-clock time of this method.

    Args:
        method: Worker method name, or a callable sent to workers to execute.
        timeout: Per-stage timeout in seconds; ``None`` waits indefinitely.
        args: Positional arguments forwarded to the worker method.
        kwargs: Keyword arguments forwarded to the worker method.

    Returns:
        A list with one entry per stage, in ``self.stage_list`` order.
    """
    # One result per stage, preserving stage order for the caller.
    return [
        stage.collective_rpc(
            method=method,
            args=args,
            timeout=timeout,
            kwargs=kwargs,
        )
        for stage in self.stage_list
    ]

@property
def _name(self) -> str:
return "Orchestrator"
16 changes: 16 additions & 0 deletions vllm_omni/entrypoints/omni_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
from collections.abc import Callable
from dataclasses import fields
from typing import Any

from vllm.logger import init_logger
from vllm.transformers_utils.config import get_hf_file_to_dict
Expand Down Expand Up @@ -115,6 +117,20 @@ def generate(
def _run_engine(self, requests: list[OmniDiffusionRequest]):
return self.engine.step(requests)

def collective_rpc(
    self,
    method: str | Callable,
    timeout: float | None = None,
    args: tuple = (),
    kwargs: dict | None = None,
) -> Any:
    """Forward an RPC call to the underlying diffusion engine's workers.

    Args:
        method: Worker method name, or a callable sent to workers to execute.
        timeout: Optional timeout in seconds; ``None`` waits indefinitely.
        args: Positional arguments forwarded to the worker method.
        kwargs: Keyword arguments forwarded to the worker method.

    Returns:
        Whatever the engine's collective_rpc returns for the workers.
    """
    # Thin delegation: the engine owns the worker processes.
    return self.engine.collective_rpc(
        method,
        timeout=timeout,
        args=args,
        kwargs=kwargs,
    )

def close(self) -> None:
self.engine.close()

Expand Down
92 changes: 91 additions & 1 deletion vllm_omni/entrypoints/omni_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import queue
import sys
import traceback
import uuid
from collections.abc import Callable
from dataclasses import fields
from typing import Any
from typing import Any, TypeVar

from vllm.inputs import TextPrompt
from vllm.inputs.preprocess import InputPreprocessor
Expand Down Expand Up @@ -49,6 +51,8 @@

logger = init_logger(__name__)

_R = TypeVar("_R")


def _build_od_config(engine_args: dict[str, Any], model: str) -> dict[str, Any]:
"""Build OmniDiffusionConfig kwargs from engine args."""
Expand Down Expand Up @@ -420,6 +424,67 @@ def process_engine_inputs(
stage_list, engine_input_source, prompt, self.requires_multimodal_data
)

def collective_rpc(
    self,
    method: str | Callable[..., _R],
    timeout: float | None = None,
    args: tuple = (),
    kwargs: dict[str, Any] | None = None,
) -> list[_R]:
    """Execute an RPC call on all workers via the stage engine.

    Args:
        method: Name of the worker method to execute, or a callable that
            is serialized and sent to all workers to execute.

            If the method is a callable, it should accept an additional
            `self` argument, in addition to the arguments passed in `args`
            and `kwargs`. The `self` argument will be the worker object.
        timeout: Maximum time in seconds to wait for execution. Raises a
            [`TimeoutError`][] on timeout. `None` means wait indefinitely.
        args: Positional arguments to pass to the worker method.
        kwargs: Keyword arguments to pass to the worker method.

    Returns:
        A list containing the results from each worker.

    Note:
        It is recommended to use this API to only pass control messages,
        and set up data-plane communication to pass data.
    """
    import time

    assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc"

    # Tag the request so the reply can be matched even when other
    # messages are interleaved on the output queue.
    rpc_id = str(uuid.uuid4())
    self._in_q.put(
        {
            "type": OmniStageTaskType.COLLECTIVE_RPC,
            "rpc_id": rpc_id,
            "method": method,
            "timeout": timeout,
            "args": args,
            "kwargs": kwargs,
        }
    )

    # Monotonic deadline: immune to wall-clock adjustments (time.time isn't).
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        if deadline is not None and time.monotonic() > deadline:
            raise TimeoutError(f"collective_rpc timed out after {timeout} seconds")

        result = self.try_collect()
        if result is None:
            time.sleep(0.001)  # Small sleep to avoid busy waiting
            continue

        if result.get("type") == "collective_rpc_result" and result.get("rpc_id") == rpc_id:
            # Worker reports failures via an "error" key rather than raising.
            if "error" in result:
                raise RuntimeError(f"collective_rpc failed: {result['error']}")
            return result["result"]

        # Not our reply (e.g. an in-flight generation output or another
        # RPC's result): requeue it instead of silently dropping it so
        # other consumers of the output queue still receive it.
        # NOTE(review): requeueing can reorder queue items relative to
        # arrival — confirm downstream consumers tolerate that.
        self._out_q.put(result)
        time.sleep(0.001)  # Small sleep to avoid busy waiting


def _stage_worker(
model: str,
Expand Down Expand Up @@ -636,6 +701,31 @@ def _stage_worker(
logger.error("Received shutdown signal")
break

if task_type == OmniStageTaskType.COLLECTIVE_RPC:
rpc_id = task.get("rpc_id")
method = task.get("method")
timeout = task.get("timeout")
args = task.get("args")
kwargs = task.get("kwargs")
try:
result = stage_engine.collective_rpc(method, timeout, args, kwargs)
out_q.put(
{
"type": "collective_rpc_result",
"rpc_id": rpc_id,
"result": result,
}
)
except Exception as e:
out_q.put(
{
"type": "collective_rpc_result",
"rpc_id": rpc_id,
"error": str(e),
}
)
continue

batch_tasks: list[dict[str, Any]] = [task]
start_time = _time.time()
if max_batch_size > 1:
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/entrypoints/stage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

class OmniStageTaskType(enum.Enum):
GENERATE = "generate"
COLLECTIVE_RPC = "collective_rpc"
ABORT = "abort"
SHUTDOWN = "shutdown"

Expand Down