Skip to content

Commit c9d7cd1

Browse files
[Core]Add Diffusion executor (#865)
Signed-off-by: wzliu <wzliu@connect.hku.hk>
1 parent 4869311 commit c9d7cd1

5 files changed

Lines changed: 305 additions & 196 deletions

File tree

vllm_omni/diffusion/diffusion_engine.py

Lines changed: 19 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,23 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4-
import multiprocessing as mp
54
import os
65
import time
7-
import weakref
8-
from collections.abc import Callable, Iterable
9-
from dataclasses import dataclass
6+
from collections.abc import Iterable
107
from typing import Any
118

129
import PIL.Image
1310
from vllm.logger import init_logger
1411

15-
from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, OmniDiffusionConfig
12+
from vllm_omni.diffusion.data import OmniDiffusionConfig
13+
from vllm_omni.diffusion.executor.abstract import DiffusionExecutor
1614
from vllm_omni.diffusion.registry import (
1715
DiffusionModelRegistry,
1816
get_diffusion_post_process_func,
1917
get_diffusion_pre_process_func,
2018
)
2119
from vllm_omni.diffusion.request import OmniDiffusionRequest
22-
from vllm_omni.diffusion.scheduler import Scheduler, scheduler
2320
from vllm_omni.outputs import OmniRequestOutput
24-
from vllm_omni.utils.platform_utils import get_diffusion_worker_class
2521

2622
logger = init_logger(__name__)
2723

@@ -33,39 +29,6 @@ def supports_image_input(model_class_name: str) -> bool:
3329
return bool(getattr(model_cls, "support_image_input", False))
3430

3531

36-
@dataclass
37-
class BackgroundResources:
38-
"""
39-
Used as a finalizer for clean shutdown.
40-
Create a BackgroundResources instance to encapsulate all background resources
41-
(e.g., the scheduler and worker processes) that need explicit cleanup.
42-
This object holds references to external system resources that are not managed
43-
by Python's garbage collector (like OS processes, message queues, etc.),
44-
so they must be cleaned up manually to avoid resource leaks or zombie processes.
45-
"""
46-
47-
scheduler: Scheduler | None = None
48-
processes: list[mp.Process] | None = None
49-
50-
def __call__(self):
51-
"""Clean up background resources."""
52-
if scheduler is not None:
53-
try:
54-
for _ in range(scheduler.num_workers):
55-
scheduler.mq.enqueue(SHUTDOWN_MESSAGE)
56-
scheduler.close()
57-
except Exception as exc:
58-
logger.warning("Failed to send shutdown signal: %s", exc)
59-
for proc in self.processes:
60-
if not proc.is_alive():
61-
continue
62-
proc.join(30)
63-
if proc.is_alive():
64-
logger.warning("Terminating diffusion worker %s after timeout", proc.name)
65-
proc.terminate()
66-
proc.join(30)
67-
68-
6932
class DiffusionEngine:
7033
"""The diffusion engine for vLLM-Omni diffusion models."""
7134

@@ -80,9 +43,9 @@ def __init__(self, od_config: OmniDiffusionConfig):
8043
self.post_process_func = get_diffusion_post_process_func(od_config)
8144
self.pre_process_func = get_diffusion_pre_process_func(od_config)
8245

83-
self._processes: list[mp.Process] = []
84-
self._closed = False
85-
self._make_client()
46+
executor_class = DiffusionExecutor.get_class(od_config)
47+
self.executor = executor_class(od_config)
48+
8649
try:
8750
self._dummy_run()
8851
except Exception as e:
@@ -200,96 +163,8 @@ def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine":
200163
"""
201164
return DiffusionEngine(config)
202165

203-
def _make_client(self):
204-
# TODO rename it
205-
scheduler.initialize(self.od_config)
206-
207-
# Get the broadcast handle from the initialized scheduler
208-
broadcast_handle = scheduler.get_broadcast_handle()
209-
210-
processes, result_handle = self._launch_workers(
211-
broadcast_handle=broadcast_handle,
212-
)
213-
214-
if result_handle is not None:
215-
scheduler.initialize_result_queue(result_handle)
216-
else:
217-
logger.error("Failed to get result queue handle from workers")
218-
219-
self._processes = processes
220-
221-
self.resources = BackgroundResources(scheduler=scheduler, processes=self._processes)
222-
# Use weakref.finalize instead of __del__ or relying on self.close() at shutdown.
223-
# During interpreter shutdown, global state (e.g., modules, built-ins) may already
224-
# be cleared (set to None), so calling normal cleanup methods can fail with
225-
# AttributeError: 'NoneType' object has no attribute '...'.
226-
# weakref.finalize schedules cleanup *before* such destruction begins,
227-
# ensuring resources are released while the runtime environment is still intact.
228-
self._finalizer = weakref.finalize(self, self.resources)
229-
230-
def _launch_workers(self, broadcast_handle):
231-
od_config = self.od_config
232-
logger.info("Starting server...")
233-
234-
num_gpus = od_config.num_gpus
235-
mp.set_start_method("spawn", force=True)
236-
processes = []
237-
238-
# Get the appropriate worker class for current device
239-
worker_proc = get_diffusion_worker_class()
240-
241-
# Launch all worker processes
242-
scheduler_pipe_readers = []
243-
scheduler_pipe_writers = []
244-
245-
for i in range(num_gpus):
246-
reader, writer = mp.Pipe(duplex=False)
247-
scheduler_pipe_writers.append(writer)
248-
process = mp.Process(
249-
target=worker_proc.worker_main,
250-
args=(
251-
i, # rank
252-
od_config,
253-
writer,
254-
broadcast_handle,
255-
),
256-
name=f"DiffusionWorker-{i}",
257-
daemon=True,
258-
)
259-
scheduler_pipe_readers.append(reader)
260-
process.start()
261-
processes.append(process)
262-
263-
# Wait for all workers to be ready
264-
scheduler_infos = []
265-
result_handle = None
266-
for writer in scheduler_pipe_writers:
267-
writer.close()
268-
269-
for i, reader in enumerate(scheduler_pipe_readers):
270-
try:
271-
data = reader.recv()
272-
except EOFError:
273-
logger.error(f"Rank {i} scheduler is dead. Please check if there are relevant logs.")
274-
processes[i].join()
275-
logger.error(f"Exit code: {processes[i].exitcode}")
276-
raise
277-
278-
if data["status"] != "ready":
279-
raise RuntimeError("Initialization failed. Please see the error messages above.")
280-
281-
if i == 0:
282-
result_handle = data.get("result_handle")
283-
284-
scheduler_infos.append(data)
285-
reader.close()
286-
287-
logger.debug("All workers are ready")
288-
289-
return processes, result_handle
290-
291166
def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
292-
return scheduler.add_req(requests)
167+
return self.executor.add_req(requests)
293168

294169
def start_profile(self, trace_filename: str | None = None) -> None:
295170
"""
@@ -437,7 +312,7 @@ def _dummy_run(self):
437312

438313
def collective_rpc(
439314
self,
440-
method: str | Callable,
315+
method: str,
441316
timeout: float | None = None,
442317
args: tuple = (),
443318
kwargs: dict | None = None,
@@ -446,7 +321,7 @@ def collective_rpc(
446321
"""Call a method on worker processes and get results immediately.
447322
448323
Args:
449-
method: The method name (str) or callable to execute on workers
324+
method: The method name (str) to execute on workers
450325
timeout: Optional timeout in seconds
451326
args: Positional arguments for the method
452327
kwargs: Keyword arguments for the method
@@ -455,59 +330,18 @@ def collective_rpc(
455330
Returns:
456331
Single result if unique_reply_rank is provided, otherwise list of results
457332
"""
458-
if self._closed:
459-
raise RuntimeError("DiffusionEngine is closed.")
460-
461-
deadline = None if timeout is None else time.monotonic() + timeout
462-
kwargs = kwargs or {}
463-
464-
assert isinstance(method, str)
465-
send_method = method
466-
467-
# Prepare RPC request message
468-
rpc_request = {
469-
"type": "rpc",
470-
"method": send_method,
471-
"args": args,
472-
"kwargs": kwargs,
473-
"output_rank": unique_reply_rank,
474-
}
475-
476-
try:
477-
# Broadcast RPC request to all workers via unified message queue
478-
scheduler.mq.enqueue(rpc_request)
479-
480-
# Determine which workers we expect responses from
481-
num_responses = 1 if unique_reply_rank is not None else self.od_config.num_gpus
482-
483-
responses = []
484-
for _ in range(num_responses):
485-
dequeue_timeout = None if deadline is None else (deadline - time.monotonic())
486-
try:
487-
if scheduler.result_mq is None:
488-
raise RuntimeError("Result queue not initialized")
489-
490-
response = scheduler.result_mq.dequeue(timeout=dequeue_timeout)
491-
492-
# Check if response indicates an error
493-
if isinstance(response, dict) and response.get("status") == "error":
494-
raise RuntimeError(
495-
f"Worker failed with error '{response.get('error')}', "
496-
"please check the stack trace above for the root cause"
497-
)
498-
499-
responses.append(response)
500-
except TimeoutError as e:
501-
raise TimeoutError(f"RPC call to {method} timed out.") from e
502-
503-
return responses[0] if unique_reply_rank is not None else responses
504-
505-
except Exception as e:
506-
logger.error(f"RPC call failed: {e}")
507-
raise
333+
assert isinstance(method, str), "Only string method names are supported for now"
334+
return self.executor.collective_rpc(
335+
method=method,
336+
timeout=timeout,
337+
args=args,
338+
kwargs=kwargs,
339+
unique_reply_rank=unique_reply_rank,
340+
)
508341

509342
def close(self) -> None:
510-
self._finalizer()
343+
if hasattr(self, "executor"):
344+
self.executor.shutdown()
511345

512346
def abort(self, request_id: str | Iterable[str]) -> None:
513347
# TODO implement it

vllm_omni/diffusion/executor/__init__.py

Whitespace-only changes.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any
3+
4+
from vllm.utils.import_utils import resolve_obj_by_qualname
5+
6+
from vllm_omni.diffusion.data import OmniDiffusionConfig
7+
from vllm_omni.diffusion.request import OmniDiffusionRequest
8+
9+
10+
class DiffusionExecutor(ABC):
    """Base interface that every Diffusion executor implementation must satisfy.

    Concrete subclasses own the worker lifecycle: they launch workers in
    ``_init_executor``, forward work via ``add_req`` / ``collective_rpc``,
    and tear everything down in ``shutdown``.
    """

    # Subclasses that run workers in separate processes set this to True.
    uses_multiproc: bool = False

    @staticmethod
    def get_class(od_config: OmniDiffusionConfig) -> type["DiffusionExecutor"]:
        """Resolve the executor class selected by ``od_config.distributed_executor_backend``.

        The backend may be a ``DiffusionExecutor`` subclass, one of the known
        keywords ("mp", "ray", "external_launcher"), or a dotted python path
        to a subclass.
        """
        backend = od_config.distributed_executor_backend

        # Case 1: the caller handed us a class object directly.
        if isinstance(backend, type):
            if not issubclass(backend, DiffusionExecutor):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"DiffusionExecutor. Got {backend}."
                )
            return backend

        # Case 2: well-known backend keywords.
        if backend == "ray":
            raise NotImplementedError("ray backend is not yet supported.")
        if backend == "mp":
            # Imported lazily to avoid a circular import with the executor module.
            from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor

            return MultiprocDiffusionExecutor
        if backend == "external_launcher":
            raise NotImplementedError("external_launcher backend is not yet supported.")

        # Case 3: an arbitrary dotted path such as "pkg.mod.MyExecutor".
        if isinstance(backend, str):
            try:
                resolved = resolve_obj_by_qualname(backend)
            except (ImportError, ValueError) as e:
                raise ValueError(
                    f"Failed to load executor backend '{backend}'. "
                    f"Ensure it is a valid python path. Error: {e}"
                ) from e
            if not issubclass(resolved, DiffusionExecutor):
                raise TypeError(
                    f"distributed_executor_backend must be a subclass of DiffusionExecutor. Got {resolved}."
                )
            return resolved

        raise ValueError(f"Unknown distributed executor backend: {backend}")

    def __init__(self, od_config: OmniDiffusionConfig):
        # Keep the config around for subclasses, then let the concrete
        # implementation bring up its workers.
        self.od_config = od_config
        self._init_executor()

    @abstractmethod
    def _init_executor(self) -> None:
        """Initialize the executor (e.g., launch workers, setup IPC)."""
        ...

    @abstractmethod
    def add_req(self, requests: list[OmniDiffusionRequest]):
        """Add requests to the execution queue."""
        ...

    @abstractmethod
    def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        unique_reply_rank: int | None = None,
    ) -> Any:
        """Execute a method on workers."""
        ...

    @abstractmethod
    def check_health(self) -> None:
        """Check if the executor and workers are healthy."""
        ...

    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the executor and release resources."""
        ...

0 commit comments

Comments (0)