11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
4- import multiprocessing as mp
54import os
65import time
7- import weakref
8- from collections .abc import Callable , Iterable
9- from dataclasses import dataclass
6+ from collections .abc import Iterable
107from typing import Any
118
129import PIL .Image
1310from vllm .logger import init_logger
1411
15- from vllm_omni .diffusion .data import SHUTDOWN_MESSAGE , OmniDiffusionConfig
12+ from vllm_omni .diffusion .data import OmniDiffusionConfig
13+ from vllm_omni .diffusion .executor .abstract import DiffusionExecutor
1614from vllm_omni .diffusion .registry import (
1715 DiffusionModelRegistry ,
1816 get_diffusion_post_process_func ,
1917 get_diffusion_pre_process_func ,
2018)
2119from vllm_omni .diffusion .request import OmniDiffusionRequest
22- from vllm_omni .diffusion .scheduler import Scheduler , scheduler
2320from vllm_omni .outputs import OmniRequestOutput
24- from vllm_omni .utils .platform_utils import get_diffusion_worker_class
2521
2622logger = init_logger (__name__ )
2723
@@ -33,39 +29,6 @@ def supports_image_input(model_class_name: str) -> bool:
3329 return bool (getattr (model_cls , "support_image_input" , False ))
3430
3531
36- @dataclass
37- class BackgroundResources :
38- """
39- Used as a finalizer for clean shutdown.
40- Create a BackgroundResources instance to encapsulate all background resources
41- (e.g., the scheduler and worker processes) that need explicit cleanup.
42- This object holds references to external system resources that are not managed
43- by Python's garbage collector (like OS processes, message queues, etc.),
44- so they must be cleaned up manually to avoid resource leaks or zombie processes.
45- """
46-
47- scheduler : Scheduler | None = None
48- processes : list [mp .Process ] | None = None
49-
50- def __call__ (self ):
51- """Clean up background resources."""
52- if scheduler is not None :
53- try :
54- for _ in range (scheduler .num_workers ):
55- scheduler .mq .enqueue (SHUTDOWN_MESSAGE )
56- scheduler .close ()
57- except Exception as exc :
58- logger .warning ("Failed to send shutdown signal: %s" , exc )
59- for proc in self .processes :
60- if not proc .is_alive ():
61- continue
62- proc .join (30 )
63- if proc .is_alive ():
64- logger .warning ("Terminating diffusion worker %s after timeout" , proc .name )
65- proc .terminate ()
66- proc .join (30 )
67-
68-
6932class DiffusionEngine :
7033 """The diffusion engine for vLLM-Omni diffusion models."""
7134
@@ -80,9 +43,9 @@ def __init__(self, od_config: OmniDiffusionConfig):
8043 self .post_process_func = get_diffusion_post_process_func (od_config )
8144 self .pre_process_func = get_diffusion_pre_process_func (od_config )
8245
83- self . _processes : list [ mp . Process ] = []
84- self ._closed = False
85- self . _make_client ()
46+ executor_class = DiffusionExecutor . get_class ( od_config )
47+ self .executor = executor_class ( od_config )
48+
8649 try :
8750 self ._dummy_run ()
8851 except Exception as e :
@@ -200,96 +163,8 @@ def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine":
200163 """
201164 return DiffusionEngine (config )
202165
203- def _make_client (self ):
204- # TODO rename it
205- scheduler .initialize (self .od_config )
206-
207- # Get the broadcast handle from the initialized scheduler
208- broadcast_handle = scheduler .get_broadcast_handle ()
209-
210- processes , result_handle = self ._launch_workers (
211- broadcast_handle = broadcast_handle ,
212- )
213-
214- if result_handle is not None :
215- scheduler .initialize_result_queue (result_handle )
216- else :
217- logger .error ("Failed to get result queue handle from workers" )
218-
219- self ._processes = processes
220-
221- self .resources = BackgroundResources (scheduler = scheduler , processes = self ._processes )
222- # Use weakref.finalize instead of __del__ or relying on self.close() at shutdown.
223- # During interpreter shutdown, global state (e.g., modules, built-ins) may already
224- # be cleared (set to None), so calling normal cleanup methods can fail with
225- # AttributeError: 'NoneType' object has no attribute '...'.
226- # weakref.finalize schedules cleanup *before* such destruction begins,
227- # ensuring resources are released while the runtime environment is still intact.
228- self ._finalizer = weakref .finalize (self , self .resources )
229-
230- def _launch_workers (self , broadcast_handle ):
231- od_config = self .od_config
232- logger .info ("Starting server..." )
233-
234- num_gpus = od_config .num_gpus
235- mp .set_start_method ("spawn" , force = True )
236- processes = []
237-
238- # Get the appropriate worker class for current device
239- worker_proc = get_diffusion_worker_class ()
240-
241- # Launch all worker processes
242- scheduler_pipe_readers = []
243- scheduler_pipe_writers = []
244-
245- for i in range (num_gpus ):
246- reader , writer = mp .Pipe (duplex = False )
247- scheduler_pipe_writers .append (writer )
248- process = mp .Process (
249- target = worker_proc .worker_main ,
250- args = (
251- i , # rank
252- od_config ,
253- writer ,
254- broadcast_handle ,
255- ),
256- name = f"DiffusionWorker-{ i } " ,
257- daemon = True ,
258- )
259- scheduler_pipe_readers .append (reader )
260- process .start ()
261- processes .append (process )
262-
263- # Wait for all workers to be ready
264- scheduler_infos = []
265- result_handle = None
266- for writer in scheduler_pipe_writers :
267- writer .close ()
268-
269- for i , reader in enumerate (scheduler_pipe_readers ):
270- try :
271- data = reader .recv ()
272- except EOFError :
273- logger .error (f"Rank { i } scheduler is dead. Please check if there are relevant logs." )
274- processes [i ].join ()
275- logger .error (f"Exit code: { processes [i ].exitcode } " )
276- raise
277-
278- if data ["status" ] != "ready" :
279- raise RuntimeError ("Initialization failed. Please see the error messages above." )
280-
281- if i == 0 :
282- result_handle = data .get ("result_handle" )
283-
284- scheduler_infos .append (data )
285- reader .close ()
286-
287- logger .debug ("All workers are ready" )
288-
289- return processes , result_handle
290-
291166 def add_req_and_wait_for_response (self , requests : list [OmniDiffusionRequest ]):
292- return scheduler .add_req (requests )
167+ return self . executor .add_req (requests )
293168
294169 def start_profile (self , trace_filename : str | None = None ) -> None :
295170 """
@@ -437,7 +312,7 @@ def _dummy_run(self):
437312
438313 def collective_rpc (
439314 self ,
440- method : str | Callable ,
315+ method : str ,
441316 timeout : float | None = None ,
442317 args : tuple = (),
443318 kwargs : dict | None = None ,
@@ -446,7 +321,7 @@ def collective_rpc(
446321 """Call a method on worker processes and get results immediately.
447322
448323 Args:
449- method: The method name (str) or callable to execute on workers
324+ method: The method name (str) to execute on workers
450325 timeout: Optional timeout in seconds
451326 args: Positional arguments for the method
452327 kwargs: Keyword arguments for the method
@@ -455,59 +330,18 @@ def collective_rpc(
455330 Returns:
456331 Single result if unique_reply_rank is provided, otherwise list of results
457332 """
458- if self ._closed :
459- raise RuntimeError ("DiffusionEngine is closed." )
460-
461- deadline = None if timeout is None else time .monotonic () + timeout
462- kwargs = kwargs or {}
463-
464- assert isinstance (method , str )
465- send_method = method
466-
467- # Prepare RPC request message
468- rpc_request = {
469- "type" : "rpc" ,
470- "method" : send_method ,
471- "args" : args ,
472- "kwargs" : kwargs ,
473- "output_rank" : unique_reply_rank ,
474- }
475-
476- try :
477- # Broadcast RPC request to all workers via unified message queue
478- scheduler .mq .enqueue (rpc_request )
479-
480- # Determine which workers we expect responses from
481- num_responses = 1 if unique_reply_rank is not None else self .od_config .num_gpus
482-
483- responses = []
484- for _ in range (num_responses ):
485- dequeue_timeout = None if deadline is None else (deadline - time .monotonic ())
486- try :
487- if scheduler .result_mq is None :
488- raise RuntimeError ("Result queue not initialized" )
489-
490- response = scheduler .result_mq .dequeue (timeout = dequeue_timeout )
491-
492- # Check if response indicates an error
493- if isinstance (response , dict ) and response .get ("status" ) == "error" :
494- raise RuntimeError (
495- f"Worker failed with error '{ response .get ('error' )} ', "
496- "please check the stack trace above for the root cause"
497- )
498-
499- responses .append (response )
500- except TimeoutError as e :
501- raise TimeoutError (f"RPC call to { method } timed out." ) from e
502-
503- return responses [0 ] if unique_reply_rank is not None else responses
504-
505- except Exception as e :
506- logger .error (f"RPC call failed: { e } " )
507- raise
333+ assert isinstance (method , str ), "Only string method names are supported for now"
334+ return self .executor .collective_rpc (
335+ method = method ,
336+ timeout = timeout ,
337+ args = args ,
338+ kwargs = kwargs ,
339+ unique_reply_rank = unique_reply_rank ,
340+ )
508341
509342 def close (self ) -> None :
510- self ._finalizer ()
343+ if hasattr (self , "executor" ):
344+ self .executor .shutdown ()
511345
512346 def abort (self , request_id : str | Iterable [str ]) -> None :
513347 # TODO implement it
0 commit comments