Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import multiprocessing as mp
import pickle
import time
from typing import Any, Callable

import cloudpickle
from vllm.logger import init_logger

from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, OmniDiffusionConfig
Expand Down Expand Up @@ -151,16 +154,90 @@ def _launch_workers(self, broadcast_handle):
def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
    """Hand ``requests`` to the module-level scheduler and return its result.

    Args:
        requests: The diffusion requests to submit.

    Returns:
        Whatever ``scheduler.add_req`` returns for these requests.
    """
    response = scheduler.add_req(requests)
    return response

def collective_rpc(
self,
method: str | Callable,
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
unique_reply_rank: int | None = None,
) -> Any:
"""Call a method on worker processes and get results immediately.

Args:
method: The method name (str) or callable to execute on workers
timeout: Optional timeout in seconds
args: Positional arguments for the method
kwargs: Keyword arguments for the method
unique_reply_rank: If set, only get reply from this rank

Returns:
Single result if unique_reply_rank is provided, otherwise list of results
"""
if self._closed:
raise RuntimeError("DiffusionEngine is closed.")

deadline = None if timeout is None else time.monotonic() + timeout
kwargs = kwargs or {}

# Prepare the method to send
if isinstance(method, str):
send_method = method
else:
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)

# Prepare RPC request message
rpc_request = {
"type": "rpc",
"method": send_method,
"args": args,
"kwargs": kwargs,
"output_rank": unique_reply_rank,
}

try:
# Broadcast RPC request to all workers via unified message queue
scheduler.mq.enqueue(rpc_request)

# Determine which workers we expect responses from
num_responses = 1 if unique_reply_rank is not None else self.od_config.num_gpus

responses = []
for _ in range(num_responses):
dequeue_timeout = None if deadline is None else (deadline - time.monotonic())
try:
if scheduler.result_mq is None:
raise RuntimeError("Result queue not initialized")

response = scheduler.result_mq.dequeue(timeout=dequeue_timeout)

# Check if response indicates an error
if isinstance(response, dict) and response.get("status") == "error":
raise RuntimeError(
f"Worker failed with error '{response.get('error')}', "
"please check the stack trace above for the root cause"
)

responses.append(response)
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e

return responses[0] if unique_reply_rank is not None else responses

except Exception as e:
logger.error(f"RPC call failed: {e}")
raise

def close(self, *, timeout_s: float = 30.0) -> None:
if self._closed:
return
self._closed = True

# Send shutdown signal to worker processes via broadcast queue
# Send shutdown signal to worker processes via unified broadcast queue
try:
if getattr(scheduler, "mq", None) is not None:
for _ in range(self.od_config.num_gpus or 1):
scheduler.mq.enqueue(SHUTDOWN_MESSAGE)
scheduler.mq.enqueue({"type": "shutdown"})
except Exception as exc: # pragma: no cover - best effort cleanup
logger.warning("Failed to send shutdown signal: %s", exc)

Expand Down
26 changes: 21 additions & 5 deletions vllm_omni/diffusion/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import cloudpickle
import pickle
import zmq
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.logger import init_logger
Expand Down Expand Up @@ -28,7 +30,7 @@ def initialize(self, od_config: OmniDiffusionConfig):
self.od_config = od_config
self.context = zmq.Context() # Standard synchronous context

# Initialize MessageQueue for broadcasting requests
# Initialize single MessageQueue for all message types (generation & RPC)
# Assuming all readers are local for now as per current launch_engine implementation
self.mq = MessageQueue(
n_reader=od_config.num_gpus,
Expand All @@ -44,15 +46,29 @@ def initialize_result_queue(self, handle):
self.result_mq = MessageQueue.create_from_handle(handle, rank=0)
logger.info("SyncScheduler initialized result MessageQueue")

def initialize_rpc_result_queue(self, handle):
    """Deprecated no-op: RPC results share the regular result queue.

    Args:
        handle: Ignored; no separate RPC result queue is created here.
    """
    notice = "RPC results use the unified result queue"
    logger.info(notice)

def get_broadcast_handle(self):
    """Return the handle exported by the broadcast queue ``self.mq``."""
    handle = self.mq.export_handle()
    return handle

def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
"""Sends a request to the scheduler and waits for the response."""
"""Sends a generation request via RPC to worker rank 0 and waits for the response."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why we send to rank 0 rather than broadcasting to all workers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring was mistaken when I first wrote it; I had misunderstood how the Scheduler works. Let me fix it.
It should broadcast to all workers.

try:
# Broadcast request to all workers
self.mq.enqueue(requests)
# Wait for result from Rank 0 (or whoever sends it)
# Prepare RPC request for generation
rpc_request = {
"type": "rpc",
"method": "generate",
"args": (requests,),
"kwargs": {},
"output_rank": 0, # Only rank 0 replies
}

# Broadcast RPC request to all workers
self.mq.enqueue(rpc_request)

# Wait for result from Rank 0
if self.result_mq is None:
raise RuntimeError("Result queue not initialized")

Expand Down
125 changes: 98 additions & 27 deletions vllm_omni/diffusion/worker/gpu_worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import cloudpickle
import multiprocessing as mp
import os
import pickle
import time

import torch
Expand Down Expand Up @@ -91,6 +93,29 @@ def init_device_and_model(self) -> None:
if self.cache_backend is not None:
self.cache_backend.enable(self.pipeline)


def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
    """Run the model on ``requests`` with this worker's own configuration.

    Thin wrapper that forwards to ``execute_model`` together with
    ``self.od_config``.

    Args:
        requests: The diffusion requests to process.

    Returns:
        The ``DiffusionOutput`` produced by ``execute_model``.
    """
    run_model = self.execute_model
    return run_model(requests, self.od_config)

def do_shutdown(self) -> str:
    """Call ``self.shutdown()`` and report completion.

    Returns:
        A confirmation string that includes this worker's ``rank``.
    """
    self.shutdown()
    confirmation = f"Worker {self.rank} shutdown complete"
    return confirmation

@torch.inference_mode()
def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
"""
Expand Down Expand Up @@ -130,7 +155,7 @@ def __init__(
# Inter-process Communication
self.context = zmq.Context(io_threads=2)

# Initialize MessageQueue reader from handle
# Initialize MessageQueue reader from handle (unified for generation & RPC)
self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id)

self.result_mq = None
Expand Down Expand Up @@ -162,55 +187,101 @@ def return_result(self, output: DiffusionOutput):
if self.result_mq is not None:
self.result_mq.enqueue(output)

def recv_reqs(self):
def recv_message(self):
"""
Receive requests from broadcast queue
Receive unified messages (RPC requests, shutdown) from broadcast queue.
Uses indefinite=True to block until a message arrives.
"""
return self.mq.dequeue(indefinite=True)

def execute_rpc(self, rpc_request: dict):
    """Execute one RPC request against the local worker.

    Args:
        rpc_request: Dict with a required ``"method"`` entry (a method name,
            a callable, or cloudpickled bytes of a callable) and optional
            ``"args"``, ``"kwargs"`` and ``"output_rank"`` entries.

    Returns:
        The method's return value; ``None`` when ``output_rank`` is set and
        differs from this process's ``gpu_id`` (the method is not run in
        that case); or ``{"status": "error", "error": <message>}`` if
        anything raised while handling the request.

    NOTE(review): a method that legitimately returns ``None`` is
    indistinguishable from "skipped on this rank" for the caller -- confirm
    callers do not depend on a reply in that case.
    """
    try:
        method = rpc_request["method"]
        # ``.get(key, default)`` only covers a missing key; an explicit None
        # would otherwise break ``*args`` / ``**kwargs`` unpacking below.
        args = rpc_request.get("args") or ()
        kwargs = rpc_request.get("kwargs") or {}
        output_rank = rpc_request.get("output_rank")

        # Only execute if this rank is expected to reply.
        if output_rank is not None and output_rank != self.gpu_id:
            return None

        if isinstance(method, bytes):
            # SECURITY: unpickling executes arbitrary code. This is only
            # acceptable if the request queue is trusted -- TODO confirm
            # that nothing outside the engine process can write to it.
            method = cloudpickle.loads(method)

        if isinstance(method, str):
            # Named method: look it up on the worker object.
            return getattr(self.worker, method)(*args, **kwargs)

        # Callable: the worker is passed as its first argument.
        return method(self.worker, *args, **kwargs)
    except Exception as e:
        # Errors are reported to the caller as a status dict, not raised.
        logger.error("Error executing RPC: %s", e, exc_info=True)
        return {"status": "error", "error": str(e)}

# TODO: queueing, cancellation
def worker_busy_loop(self) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should redesign this, because it has become more complex.

We can address the redesign in a separate follow-up task if you don't have time.

here is a good example of worker_busy_loop: https://github.com/vllm-project/vllm/blob/c02a2705f9ceeb00b5d32453621f997b2ceafbea/vllm/v1/executor/multiproc_executor.py#L806

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. The redesign is WIP and will need a more structured RFC.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just confirming — is this already WIP?

Copy link
Contributor Author

@knlnguyen1802 knlnguyen1802 Dec 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is not WIP, but the redesign you mentioned above is in progress.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZJY0516 It's ready now. Could you take another look? Thanks!

"""Main busy loop for Multiprocessing Workers"""

logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory")

while self._running:
reqs = None
# 1: receive requests
# Receive unified message (generation request, RPC request, or shutdown)
msg = None
try:
reqs = self.recv_reqs()
msg = self.recv_message()
except Exception as e:
logger.error(
f"Error receiving requests in scheduler event loop: {e}",
f"Error receiving message in worker loop: {e}",
exc_info=True,
)
continue

if reqs == SHUTDOWN_MESSAGE:
logger.info("Worker %s: Received shutdown message", self.gpu_id)
self._running = False
continue
if reqs is None:
if msg is None:
logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id)
continue

# 2: execute, make sure a reply is always sent
try:
output = self.worker.execute_model(reqs, self.od_config)
except Exception as e:
logger.error(
f"Error executing forward in event loop: {e}",
exc_info=True,
)
output = DiffusionOutput(error=str(e))

try:
self.return_result(output)
except zmq.ZMQError as e:
# Reply failed; log and keep loop alive to accept future requests
logger.error(f"ZMQ error sending reply: {e}")
# Route message based on type
if isinstance(msg, dict) and msg.get("type") == "rpc":
# Handle RPC request
try:
result = self.execute_rpc(msg)
if result is not None and self.gpu_id == 0:
self.return_result(result)
Comment on lines 245 to 248

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge collective_rpc waits for replies other ranks never send

collective_rpc expects a reply from each worker unless unique_reply_rank is set, but in worker_busy_loop only rank 0 enqueues RPC responses (self.gpu_id == 0 gate) and other ranks drop their results because they lack a result queue. On multi-GPU runs any RPC targeting a non-zero rank or broadcast calls with unique_reply_rank=None will block/time out waiting for responses that are never sent.

Useful? React with 👍 / 👎.

except Exception as e:
logger.error(f"Error processing RPC: {e}", exc_info=True)
if self.gpu_id == 0:
self.return_result({"status": "error", "error": str(e)})

elif isinstance(msg, dict) and msg.get("type") == "shutdown":
# Handle shutdown message
logger.info("Worker %s: Received shutdown message", self.gpu_id)
self._running = False
continue

else:
# Handle generation request (OmniDiffusionRequest list)
try:
output = self.worker.execute_model(msg, self.od_config)
except Exception as e:
logger.error(
f"Error executing forward in event loop: {e}",
exc_info=True,
)
output = DiffusionOutput(error=str(e))

try:
self.return_result(output)
except zmq.ZMQError as e:
# Reply failed; log and keep loop alive to accept future requests
logger.error(f"ZMQ error sending reply: {e}")
continue

logger.info("event loop terminated.")
try:
self.worker.shutdown()
Expand Down
Loading