Skip to content

Commit 1c5b276

Browse files
knlnguyen1802, ZJY0516
authored and committed
RPC support for OmniDiffusion (vllm-project#371)
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com> Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com> Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com> Signed-off-by: wangyu31577 <wangyu31577@hundsun.com>
1 parent b82519d commit 1c5b276

3 files changed

Lines changed: 168 additions & 31 deletions

File tree

vllm_omni/diffusion/diffusion_engine.py

Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,9 @@
44
import multiprocessing as mp
55
import time
66
import weakref
7+
from collections.abc import Callable
78
from dataclasses import dataclass
9+
from typing import Any
810

911
from vllm.logger import init_logger
1012

@@ -195,6 +197,77 @@ def _launch_workers(self, broadcast_handle):
195197
def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
    """Forward ``requests`` to the scheduler and return its result.

    Args:
        requests: Diffusion requests passed unchanged to ``scheduler.add_req``.

    Returns:
        Whatever ``scheduler.add_req`` returns for these requests.
    """
    response = scheduler.add_req(requests)
    return response
197199

200+
def collective_rpc(
    self,
    method: str | Callable,
    timeout: float | None = None,
    args: tuple = (),
    kwargs: dict | None = None,
    unique_reply_rank: int | None = None,
) -> Any:
    """Broadcast an RPC request to the workers and collect their replies.

    Args:
        method: Name of the worker method to invoke. The annotation also
            admits a callable, but this implementation only accepts a
            string name and rejects anything else.
        timeout: Overall budget in seconds shared by all reply waits.
            ``None`` means no deadline is applied.
        args: Positional arguments for the method.
        kwargs: Keyword arguments for the method (``None`` means none).
        unique_reply_rank: If set, a single reply is awaited and the
            request carries this value as its ``output_rank``.

    Returns:
        The single reply if ``unique_reply_rank`` is provided, otherwise
        a list with one reply per ``self.od_config.num_gpus``.

    Raises:
        RuntimeError: If the engine is closed, the result queue is not
            initialized, or a reply reports ``status == "error"``.
        TypeError: If ``method`` is not a string.
        TimeoutError: If a reply does not arrive before the deadline.
    """
    if self._closed:
        raise RuntimeError("DiffusionEngine is closed.")

    # Validate with a real exception: an `assert` disappears under `python -O`
    # and would let a non-string method be broadcast.
    if not isinstance(method, str):
        raise TypeError(
            f"collective_rpc only supports string method names, got {type(method).__name__}"
        )

    deadline = None if timeout is None else time.monotonic() + timeout

    # Prepare RPC request message
    rpc_request = {
        "type": "rpc",
        "method": method,
        "args": args,
        "kwargs": kwargs or {},
        "output_rank": unique_reply_rank,
    }

    try:
        # Check the reply channel BEFORE broadcasting, so no request is sent
        # whose replies could never be read back.
        if scheduler.result_mq is None:
            raise RuntimeError("Result queue not initialized")

        # Broadcast RPC request via the unified message queue
        scheduler.mq.enqueue(rpc_request)

        # Determine how many replies to wait for
        num_responses = 1 if unique_reply_rank is not None else self.od_config.num_gpus

        responses = []
        for _ in range(num_responses):
            # Clamp at zero so an already-elapsed deadline never turns into a
            # negative timeout value.
            remaining = None if deadline is None else max(0.0, deadline - time.monotonic())
            try:
                response = scheduler.result_mq.dequeue(timeout=remaining)
            except TimeoutError as e:
                raise TimeoutError(f"RPC call to {method} timed out.") from e

            # Check if response indicates an error
            if isinstance(response, dict) and response.get("status") == "error":
                # NOTE(review): raising here stops reading; any replies still
                # expected from other ranks are not dequeued by this call.
                raise RuntimeError(
                    f"Worker failed with error '{response.get('error')}', "
                    "please check the stack trace above for the root cause"
                )

            responses.append(response)

        return responses[0] if unique_reply_rank is not None else responses

    except Exception as e:
        logger.error(f"RPC call failed: {e}")
        raise
270+
198271
def _dummy_run(self):
199272
"""A dummy run to warm up the model."""
200273
prompt = "dummy run"

vllm_omni/diffusion/scheduler.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ def initialize(self, od_config: OmniDiffusionConfig):
2929
self.od_config = od_config
3030
self.context = zmq.Context() # Standard synchronous context
3131

32-
# Initialize MessageQueue for broadcasting requests
32+
# Initialize single MessageQueue for all message types (generation & RPC)
3333
# Assuming all readers are local for now as per current launch_engine implementation
3434
self.mq = MessageQueue(
3535
n_reader=self.num_workers,
@@ -51,9 +51,20 @@ def get_broadcast_handle(self):
5151
def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
5252
"""Sends a request to the scheduler and waits for the response."""
5353
try:
54-
# Broadcast request to all workers
55-
self.mq.enqueue(requests)
54+
# Prepare RPC request for generation
55+
rpc_request = {
56+
"type": "rpc",
57+
"method": "generate",
58+
"args": (requests,),
59+
"kwargs": {},
60+
"output_rank": 0,
61+
"exec_all_ranks": True,
62+
}
63+
64+
# Broadcast RPC request to all workers
65+
self.mq.enqueue(rpc_request)
5666
# Wait for result from Rank 0 (or whoever sends it)
67+
5768
if self.result_mq is None:
5869
raise RuntimeError("Result queue not initialized")
5970

vllm_omni/diffusion/worker/gpu_worker.py

Lines changed: 81 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414
from vllm_omni.diffusion.cache.selector import get_cache_backend
1515
from vllm_omni.diffusion.data import (
16-
SHUTDOWN_MESSAGE,
1716
DiffusionOutput,
1817
OmniDiffusionConfig,
1918
set_current_omni_diffusion_config,
@@ -107,6 +106,18 @@ def init_device_and_model(self) -> None:
107106
if self.cache_backend is not None:
108107
self.cache_backend.enable(self.pipeline)
109108

109+
def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
    """Run the model on ``requests`` with this worker's own configuration.

    Args:
        requests: List of diffusion requests.

    Returns:
        The value produced by ``execute_model`` for these requests.
    """
    run_model = self.execute_model
    return run_model(requests, self.od_config)
120+
110121
@torch.inference_mode()
111122
def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
112123
"""
@@ -141,7 +152,7 @@ def __init__(
141152
# Inter-process Communication
142153
self.context = zmq.Context(io_threads=2)
143154

144-
# Initialize MessageQueue reader from handle
155+
# Initialize MessageQueue reader from handle (unified for generation & RPC)
145156
self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id)
146157

147158
self.result_mq = None
@@ -173,55 +184,97 @@ def return_result(self, output: DiffusionOutput):
173184
if self.result_mq is not None:
174185
self.result_mq.enqueue(output)
175186

176-
def recv_reqs(self):
187+
def recv_message(self):
    """Fetch the next message from the broadcast queue.

    The queue is read with ``indefinite=True``, i.e. the call blocks
    until a message arrives.

    Returns:
        The object returned by ``self.mq.dequeue``.
    """
    message = self.mq.dequeue(indefinite=True)
    return message
181193

194+
def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]:
    """Execute an RPC request and indicate whether to reply.

    Args:
        rpc_request: Message dict. ``"method"`` is required and is either
            an attribute name looked up on ``self.worker`` or a callable
            invoked with ``self.worker`` as its first argument. Optional
            keys: ``"args"``, ``"kwargs"``, ``"output_rank"`` and
            ``"exec_all_ranks"``.

    Returns:
        ``(result, should_reply)``. ``result`` is the method's return
        value, an ``{"status": "error", ...}`` dict if the call raised, or
        ``None`` when this rank does not execute the request.
        ``should_reply`` is True only when this rank is the reply rank (or
        no reply rank was given) and ``self.result_mq`` is set.

    Raises:
        KeyError: If ``rpc_request`` has no ``"method"`` key.
    """
    method = rpc_request["method"]
    # `or` also covers an explicit None, which would otherwise make the
    # `*args` / `**kwargs` unpacking below fail for a valid call.
    args = rpc_request.get("args") or ()
    kwargs = rpc_request.get("kwargs") or {}
    output_rank = rpc_request.get("output_rank")
    exec_all_ranks = rpc_request.get("exec_all_ranks", False)

    # This rank is addressed when no specific rank was named or it matches.
    is_reply_rank = output_rank is None or output_rank == self.gpu_id
    should_reply = is_reply_rank and self.result_mq is not None

    if not (exec_all_ranks or is_reply_rank):
        return None, False

    try:
        if isinstance(method, str):
            func = getattr(self.worker, method)
            result = func(*args, **kwargs)
        else:
            result = method(self.worker, *args, **kwargs)
        return result, should_reply
    except Exception as e:
        # Convert the failure into an error payload; the caller sends it
        # back only when should_reply is True.
        logger.error(f"Error executing RPC: {e}", exc_info=True)
        return {"status": "error", "error": str(e)}, should_reply
219+
182220
# TODO: queueing, cancellation
183221
def worker_busy_loop(self) -> None:
184222
"""Main busy loop for Multiprocessing Workers"""
185223

186224
logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory")
187225

188226
while self._running:
189-
reqs = None
190-
# 1: receive requests
227+
# Receive unified message (generation request, RPC request, or shutdown)
228+
msg = None
191229
try:
192-
reqs = self.recv_reqs()
230+
msg = self.recv_message()
193231
except Exception as e:
194232
logger.error(
195-
f"Error receiving requests in scheduler event loop: {e}",
233+
f"Error receiving message in worker loop: {e}",
196234
exc_info=True,
197235
)
198236
continue
199237

200-
if reqs == SHUTDOWN_MESSAGE:
201-
logger.info("Worker %s: Received shutdown message", self.gpu_id)
202-
self._running = False
203-
continue
204-
if reqs is None:
238+
if msg is None:
205239
logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id)
206240
continue
207241

208-
# 2: execute, make sure a reply is always sent
209-
try:
210-
output = self.worker.execute_model(reqs, self.od_config)
211-
except Exception as e:
212-
logger.error(
213-
f"Error executing forward in event loop: {e}",
214-
exc_info=True,
215-
)
216-
output = DiffusionOutput(error=str(e))
217-
218-
try:
219-
self.return_result(output)
220-
except zmq.ZMQError as e:
221-
# Reply failed; log and keep loop alive to accept future requests
222-
logger.error(f"ZMQ error sending reply: {e}")
242+
# Route message based on type
243+
if isinstance(msg, dict) and msg.get("type") == "rpc":
244+
# Handle RPC request
245+
try:
246+
result, should_reply = self.execute_rpc(msg)
247+
if should_reply:
248+
self.return_result(result)
249+
except Exception as e:
250+
logger.error(f"Error processing RPC: {e}", exc_info=True)
251+
if self.result_mq is not None:
252+
self.return_result({"status": "error", "error": str(e)})
253+
254+
elif isinstance(msg, dict) and msg.get("type") == "shutdown":
255+
# Handle shutdown message
256+
logger.info("Worker %s: Received shutdown message", self.gpu_id)
257+
self._running = False
223258
continue
224259

260+
else:
261+
# Handle generation request (OmniDiffusionRequest list)
262+
try:
263+
output = self.worker.execute_model(msg, self.od_config)
264+
except Exception as e:
265+
logger.error(
266+
f"Error executing forward in event loop: {e}",
267+
exc_info=True,
268+
)
269+
output = DiffusionOutput(error=str(e))
270+
271+
try:
272+
self.return_result(output)
273+
except zmq.ZMQError as e:
274+
# Reply failed; log and keep loop alive to accept future requests
275+
logger.error(f"ZMQ error sending reply: {e}")
276+
continue
277+
225278
logger.info("event loop terminated.")
226279
try:
227280
self.worker.shutdown()

0 commit comments

Comments
 (0)