-
Notifications
You must be signed in to change notification settings - Fork 460
RPC support for OmniDiffusion #371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0835c74
854620c
9eb3db7
86e3319
971add7
98c6344
23e6c3e
e5ae766
2b169ed
7c6d3d4
2a42041
b2f3dde
7550475
c3dff61
b9d74f3
b99dca4
a86979f
add09c2
a372dfc
42a6fc1
515b2bd
e79cb0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,6 @@ | |
|
|
||
| from vllm_omni.diffusion.cache.selector import get_cache_backend | ||
| from vllm_omni.diffusion.data import ( | ||
| SHUTDOWN_MESSAGE, | ||
| DiffusionOutput, | ||
| OmniDiffusionConfig, | ||
| set_current_omni_diffusion_config, | ||
|
|
@@ -107,6 +106,18 @@ def init_device_and_model(self) -> None: | |
| if self.cache_backend is not None: | ||
| self.cache_backend.enable(self.pipeline) | ||
|
|
||
| def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput: | ||
| """ | ||
| Generate output for the given requests. | ||
|
|
||
| Args: | ||
| requests: List of diffusion requests | ||
|
|
||
| Returns: | ||
| DiffusionOutput with generated results | ||
| """ | ||
| return self.execute_model(requests, self.od_config) | ||
|
|
||
| @torch.inference_mode() | ||
| def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput: | ||
| """ | ||
|
|
@@ -141,7 +152,7 @@ def __init__( | |
| # Inter-process Communication | ||
| self.context = zmq.Context(io_threads=2) | ||
|
|
||
| # Initialize MessageQueue reader from handle | ||
| # Initialize MessageQueue reader from handle (unified for generation & RPC) | ||
| self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id) | ||
|
|
||
| self.result_mq = None | ||
|
|
@@ -173,55 +184,97 @@ def return_result(self, output: DiffusionOutput): | |
| if self.result_mq is not None: | ||
| self.result_mq.enqueue(output) | ||
|
|
||
| def recv_reqs(self): | ||
| def recv_message(self): | ||
| """ | ||
| Receive requests from broadcast queue | ||
| Receive unified messages (RPC requests, shutdown) from broadcast queue. | ||
| Uses indefinite=True to block until a message arrives. | ||
| """ | ||
| return self.mq.dequeue(indefinite=True) | ||
|
|
||
| def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]: | ||
| """Execute an RPC request and indicate whether to reply.""" | ||
|
|
||
| method = rpc_request["method"] | ||
| args = rpc_request.get("args", ()) | ||
| kwargs = rpc_request.get("kwargs", {}) | ||
| output_rank = rpc_request.get("output_rank") | ||
| exec_all_ranks = rpc_request.get("exec_all_ranks", False) | ||
|
|
||
| should_execute = exec_all_ranks or output_rank is None or output_rank == self.gpu_id | ||
| should_reply = (output_rank is None or output_rank == self.gpu_id) and self.result_mq is not None | ||
|
|
||
| if not should_execute: | ||
| return None, False | ||
|
|
||
| try: | ||
| if isinstance(method, str): | ||
| func = getattr(self.worker, method) | ||
| result = func(*args, **kwargs) | ||
| else: | ||
| result = method(self.worker, *args, **kwargs) | ||
| return result, should_reply | ||
| except Exception as e: | ||
| logger.error(f"Error executing RPC: {e}", exc_info=True) | ||
| return {"status": "error", "error": str(e)}, should_reply | ||
|
|
||
| # TODO: queueing, cancellation | ||
| def worker_busy_loop(self) -> None: | ||
| """Main busy loop for Multiprocessing Workers""" | ||
|
|
||
| logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory") | ||
|
|
||
| while self._running: | ||
| reqs = None | ||
| # 1: receive requests | ||
| # Receive unified message (generation request, RPC request, or shutdown) | ||
| msg = None | ||
| try: | ||
| reqs = self.recv_reqs() | ||
| msg = self.recv_message() | ||
| except Exception as e: | ||
| logger.error( | ||
| f"Error receiving requests in scheduler event loop: {e}", | ||
| f"Error receiving message in worker loop: {e}", | ||
| exc_info=True, | ||
| ) | ||
| continue | ||
|
|
||
| if reqs == SHUTDOWN_MESSAGE: | ||
| logger.info("Worker %s: Received shutdown message", self.gpu_id) | ||
| self._running = False | ||
| continue | ||
| if reqs is None: | ||
| if msg is None: | ||
| logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id) | ||
| continue | ||
|
|
||
| # 2: execute, make sure a reply is always sent | ||
| try: | ||
| output = self.worker.execute_model(reqs, self.od_config) | ||
| except Exception as e: | ||
| logger.error( | ||
| f"Error executing forward in event loop: {e}", | ||
| exc_info=True, | ||
| ) | ||
| output = DiffusionOutput(error=str(e)) | ||
|
|
||
| try: | ||
| self.return_result(output) | ||
| except zmq.ZMQError as e: | ||
| # Reply failed; log and keep loop alive to accept future requests | ||
| logger.error(f"ZMQ error sending reply: {e}") | ||
| # Route message based on type | ||
| if isinstance(msg, dict) and msg.get("type") == "rpc": | ||
| # Handle RPC request | ||
| try: | ||
| result, should_reply = self.execute_rpc(msg) | ||
| if should_reply: | ||
| self.return_result(result) | ||
|
Comment on lines
245
to
248
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Useful? React with 👍 / 👎. |
||
| except Exception as e: | ||
| logger.error(f"Error processing RPC: {e}", exc_info=True) | ||
| if self.result_mq is not None: | ||
| self.return_result({"status": "error", "error": str(e)}) | ||
|
|
||
| elif isinstance(msg, dict) and msg.get("type") == "shutdown": | ||
| # Handle shutdown message | ||
| logger.info("Worker %s: Received shutdown message", self.gpu_id) | ||
| self._running = False | ||
| continue | ||
|
|
||
| else: | ||
| # Handle generation request (OmniDiffusionRequest list) | ||
| try: | ||
| output = self.worker.execute_model(msg, self.od_config) | ||
| except Exception as e: | ||
| logger.error( | ||
| f"Error executing forward in event loop: {e}", | ||
| exc_info=True, | ||
| ) | ||
| output = DiffusionOutput(error=str(e)) | ||
|
|
||
| try: | ||
| self.return_result(output) | ||
| except zmq.ZMQError as e: | ||
| # Reply failed; log and keep loop alive to accept future requests | ||
| logger.error(f"ZMQ error sending reply: {e}") | ||
| continue | ||
|
|
||
| logger.info("event loop terminated.") | ||
| try: | ||
| self.worker.shutdown() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should redesign this, because it has become more complex.
we can address the redesign in a separate, follow-up task if you don't have time.
here is a good example of
worker_busy_loop: https://github.com/vllm-project/vllm/blob/c02a2705f9ceeb00b5d32453621f997b2ceafbea/vllm/v1/executor/multiproc_executor.py#L806

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed — the redesign is WIP and will need a more structured RFC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just confirming — is this already WIP?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR is not WIP, but the redesign you mentioned above is in progress.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ZJY0516 It's ready now. Could you take another look? Thanks!