|
13 | 13 |
|
14 | 14 | from vllm_omni.diffusion.cache.selector import get_cache_backend |
15 | 15 | from vllm_omni.diffusion.data import ( |
16 | | - SHUTDOWN_MESSAGE, |
17 | 16 | DiffusionOutput, |
18 | 17 | OmniDiffusionConfig, |
19 | 18 | set_current_omni_diffusion_config, |
@@ -107,6 +106,18 @@ def init_device_and_model(self) -> None: |
107 | 106 | if self.cache_backend is not None: |
108 | 107 | self.cache_backend.enable(self.pipeline) |
109 | 108 |
|
| 109 | + def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput: |
| 110 | + """ |
| 111 | + Generate output for the given requests. |
| 112 | +
|
| 113 | + Args: |
| 114 | + requests: List of diffusion requests |
| 115 | +
|
| 116 | + Returns: |
| 117 | + DiffusionOutput with generated results |
| 118 | + """ |
| 119 | + return self.execute_model(requests, self.od_config) |
| 120 | + |
110 | 121 | @torch.inference_mode() |
111 | 122 | def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput: |
112 | 123 | """ |
@@ -141,7 +152,7 @@ def __init__( |
141 | 152 | # Inter-process Communication |
142 | 153 | self.context = zmq.Context(io_threads=2) |
143 | 154 |
|
144 | | - # Initialize MessageQueue reader from handle |
| 155 | + # Initialize MessageQueue reader from handle (unified for generation & RPC) |
145 | 156 | self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id) |
146 | 157 |
|
147 | 158 | self.result_mq = None |
@@ -173,55 +184,97 @@ def return_result(self, output: DiffusionOutput): |
173 | 184 | if self.result_mq is not None: |
174 | 185 | self.result_mq.enqueue(output) |
175 | 186 |
|
176 | | - def recv_reqs(self): |
| 187 | + def recv_message(self): |
177 | 188 | """ |
178 | | - Receive requests from broadcast queue |
| 189 | + Receive unified messages (generation requests, RPC requests, shutdown) from broadcast queue. |
| 190 | + Uses indefinite=True to block until a message arrives. |
179 | 191 | """ |
180 | 192 | return self.mq.dequeue(indefinite=True) |
181 | 193 |
|
| 194 | + def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]: |
| 195 | + """Execute an RPC request and indicate whether to reply.""" |
| 196 | + |
| 197 | + method = rpc_request["method"] |
| 198 | + args = rpc_request.get("args", ()) |
| 199 | + kwargs = rpc_request.get("kwargs", {}) |
| 200 | + output_rank = rpc_request.get("output_rank") |
| 201 | + exec_all_ranks = rpc_request.get("exec_all_ranks", False) |
| 202 | + |
| 203 | + should_execute = exec_all_ranks or output_rank is None or output_rank == self.gpu_id |
| 204 | + should_reply = (output_rank is None or output_rank == self.gpu_id) and self.result_mq is not None |
| 205 | + |
| 206 | + if not should_execute: |
| 207 | + return None, False |
| 208 | + |
| 209 | + try: |
| 210 | + if isinstance(method, str): |
| 211 | + func = getattr(self.worker, method) |
| 212 | + result = func(*args, **kwargs) |
| 213 | + else: |
| 214 | + result = method(self.worker, *args, **kwargs) |
| 215 | + return result, should_reply |
| 216 | + except Exception as e: |
| 217 | + logger.error(f"Error executing RPC: {e}", exc_info=True) |
| 218 | + return {"status": "error", "error": str(e)}, should_reply |
| 219 | + |
182 | 220 | # TODO: queueing, cancellation |
183 | 221 | def worker_busy_loop(self) -> None: |
184 | 222 | """Main busy loop for Multiprocessing Workers""" |
185 | 223 |
|
186 | 224 | logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory") |
187 | 225 |
|
188 | 226 | while self._running: |
189 | | - reqs = None |
190 | | - # 1: receive requests |
| 227 | + # Receive unified message (generation request, RPC request, or shutdown) |
| 228 | + msg = None |
191 | 229 | try: |
192 | | - reqs = self.recv_reqs() |
| 230 | + msg = self.recv_message() |
193 | 231 | except Exception as e: |
194 | 232 | logger.error( |
195 | | - f"Error receiving requests in scheduler event loop: {e}", |
| 233 | + f"Error receiving message in worker loop: {e}", |
196 | 234 | exc_info=True, |
197 | 235 | ) |
198 | 236 | continue |
199 | 237 |
|
200 | | - if reqs == SHUTDOWN_MESSAGE: |
201 | | - logger.info("Worker %s: Received shutdown message", self.gpu_id) |
202 | | - self._running = False |
203 | | - continue |
204 | | - if reqs is None: |
| 238 | + if msg is None: |
205 | 239 | logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id) |
206 | 240 | continue |
207 | 241 |
|
208 | | - # 2: execute, make sure a reply is always sent |
209 | | - try: |
210 | | - output = self.worker.execute_model(reqs, self.od_config) |
211 | | - except Exception as e: |
212 | | - logger.error( |
213 | | - f"Error executing forward in event loop: {e}", |
214 | | - exc_info=True, |
215 | | - ) |
216 | | - output = DiffusionOutput(error=str(e)) |
217 | | - |
218 | | - try: |
219 | | - self.return_result(output) |
220 | | - except zmq.ZMQError as e: |
221 | | - # Reply failed; log and keep loop alive to accept future requests |
222 | | - logger.error(f"ZMQ error sending reply: {e}") |
| 242 | + # Route message based on type |
| 243 | + if isinstance(msg, dict) and msg.get("type") == "rpc": |
| 244 | + # Handle RPC request |
| 245 | + try: |
| 246 | + result, should_reply = self.execute_rpc(msg) |
| 247 | + if should_reply: |
| 248 | + self.return_result(result) |
| 249 | + except Exception as e: |
| 250 | + logger.error(f"Error processing RPC: {e}", exc_info=True) |
| 251 | + if self.result_mq is not None: |
| 252 | + self.return_result({"status": "error", "error": str(e)}) |
| 253 | + |
| 254 | + elif isinstance(msg, dict) and msg.get("type") == "shutdown": |
| 255 | + # Handle shutdown message |
| 256 | + logger.info("Worker %s: Received shutdown message", self.gpu_id) |
| 257 | + self._running = False |
223 | 258 | continue |
224 | 259 |
|
| 260 | + else: |
| 261 | + # Handle generation request (OmniDiffusionRequest list) |
| 262 | + try: |
| 263 | + output = self.worker.execute_model(msg, self.od_config) |
| 264 | + except Exception as e: |
| 265 | + logger.error( |
| 266 | + f"Error executing forward in event loop: {e}", |
| 267 | + exc_info=True, |
| 268 | + ) |
| 269 | + output = DiffusionOutput(error=str(e)) |
| 270 | + |
| 271 | + try: |
| 272 | + self.return_result(output) |
| 273 | + except zmq.ZMQError as e: |
| 274 | + # Reply failed; log and keep loop alive to accept future requests |
| 275 | + logger.error(f"ZMQ error sending reply: {e}") |
| 276 | + continue |
| 277 | + |
225 | 278 | logger.info("event loop terminated.") |
226 | 279 | try: |
227 | 280 | self.worker.shutdown() |
|
0 commit comments