Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os

import pytest
import torch
from vllm.logger import init_logger

logger = init_logger(__name__)


@pytest.fixture(autouse=True)
def clean_gpu_memory_between_tests():
    """Autouse fixture that waits for free GPU memory before each test and
    releases cached GPU memory afterwards.

    Opt-in via ``VLLM_TEST_CLEAN_GPU_MEMORY=1`` so ordinary runs do not pay
    the polling cost.
    """
    if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1":
        yield
        return

    import gc

    from tests.utils import wait_for_gpu_memory_to_clear

    # Wait for GPU memory to be cleared before starting the test.
    num_gpus = torch.cuda.device_count()
    if num_gpus > 0:
        try:
            wait_for_gpu_memory_to_clear(
                devices=list(range(num_gpus)),
                threshold_ratio=0.1,
            )
        except ValueError as e:
            # Best effort: a wait timeout should not fail the test itself.
            logger.info("Failed to clean GPU memory: %s", e)

    yield

    # Clean up GPU memory after the test.  Collect Python garbage FIRST so
    # that tensors freed by the collector return their blocks to the CUDA
    # caching allocator, which empty_cache() can then actually release.
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()
15 changes: 13 additions & 2 deletions tests/single_stage/test_diffusion_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import os
import sys
from pathlib import Path

import pytest
import torch

# ruff: noqa: E402
# Make the repository root importable so that ``tests.utils`` and the
# ``vllm_omni`` package resolve regardless of where pytest is launched from.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from vllm_omni import Omni

# Opt in to the autouse GPU-memory cleanup fixture (see tests/conftest.py).
# NOTE(review): set at import time, so it applies to the whole test session,
# not just this module — confirm that is intended.
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"

# Model checkpoints exercised by the parametrized test below.
models = ["Tongyi-MAI/Z-Image-Turbo", "riverclouds/qwen_image_random"]


@pytest.mark.parametrize("model_name", models)
Expand All @@ -25,4 +36,4 @@ def test_diffusion_model(model_name: str):
# check image size
assert images[0].width == width
assert images[0].height == height
images[0].save("z_image_output.png")
images[0].save("image_output.png")
107 changes: 107 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import time
from contextlib import contextmanager

from vllm.platforms import current_platform

# Select a platform-appropriate context manager that brings the GPU
# management library up and down around a decorated call.  The name ``_nvml``
# is kept for all platforms so the decorator site is platform-agnostic.
if current_platform.is_rocm():
    from amdsmi import (
        amdsmi_get_gpu_vram_usage,
        amdsmi_get_processor_handles,
        amdsmi_init,
        amdsmi_shut_down,
    )

    @contextmanager
    def _nvml():
        # Initialize AMD SMI for the duration of the wrapped call.
        try:
            amdsmi_init()
            yield
        finally:
            amdsmi_shut_down()
elif current_platform.is_cuda():
    from vllm.third_party.pynvml import (
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlInit,
        nvmlShutdown,
    )

    @contextmanager
    def _nvml():
        # Initialize NVML for the duration of the wrapped call.
        try:
            nvmlInit()
            yield
        finally:
            nvmlShutdown()
else:
    # No GPU management library on this platform: a no-op context manager.
    @contextmanager
    def _nvml():
        yield


def get_physical_device_indices(devices):
    """Translate logical device indices into physical ones.

    When ``CUDA_VISIBLE_DEVICES`` is set, logical index ``i`` refers to the
    ``i``-th entry of that list; logical indices with no corresponding entry
    are silently dropped.  When it is unset, the indices are already physical
    and are returned unchanged.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return devices

    physical = [int(token) for token in visible.split(",")]
    return [physical[idx] for idx in devices if 0 <= idx < len(physical)]


@_nvml()
def wait_for_gpu_memory_to_clear(
    *,
    devices: list[int],
    threshold_bytes: int | None = None,
    threshold_ratio: float | None = None,
    timeout_s: float = 120,
) -> None:
    """Poll until every device in ``devices`` has memory usage below a threshold.

    Exactly one of ``threshold_bytes`` (absolute used bytes) or
    ``threshold_ratio`` (used/total fraction) must be given.

    Raises:
        ValueError: if the memory does not clear within ``timeout_s`` seconds.
    """
    assert threshold_bytes is not None or threshold_ratio is not None
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    devices = get_physical_device_indices(devices)

    # The acceptance predicate and its printable form are loop-invariant;
    # build them once instead of re-creating lambdas on every poll iteration.
    if threshold_bytes is not None:
        def is_free(used: float, total: float) -> bool:
            return used <= threshold_bytes / 2**30

        threshold = f"{threshold_bytes / 2**30} GiB"
    else:
        def is_free(used: float, total: float) -> bool:
            return used / total <= threshold_ratio

        threshold = f"{threshold_ratio:.2f}"

    start_time = time.time()
    while True:
        output: dict[int, str] = {}
        output_raw: dict[int, tuple[float, float]] = {}
        for device in devices:
            if current_platform.is_rocm():
                dev_handle = amdsmi_get_processor_handles()[device]
                mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
                # NOTE(review): divisor 2**10 implies amdsmi reports MiB —
                # confirm against the amdsmi documentation.
                gb_used = mem_info["vram_used"] / 2**10
                gb_total = mem_info["vram_total"] / 2**10
            else:
                # pynvml reports bytes; convert to GiB.
                dev_handle = nvmlDeviceGetHandleByIndex(device)
                mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
                gb_used = mem_info.used / 2**30
                gb_total = mem_info.total / 2**30
            output_raw[device] = (gb_used, gb_total)
            output[device] = f"{gb_used:.02f}/{gb_total:.02f}"

        print("gpu memory used/total (GiB): ", end="")
        for k, v in output.items():
            print(f"{k}={v}; ", end="")
        print("")

        dur_s = time.time() - start_time
        if all(is_free(used, total) for used, total in output_raw.values()):
            print(f"Done waiting for free GPU memory on devices {devices=} ({threshold=}) {dur_s=:.02f}")
            break

        if dur_s >= timeout_s:
            raise ValueError(f"Memory of devices {devices=} not free after {dur_s=:.02f} ({threshold=})")

        time.sleep(5)
32 changes: 32 additions & 0 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,32 @@
logger = init_logger(__name__)


@dataclass
class TransformerConfig:
"""Container for raw transformer configuration dictionaries."""

params: dict[str, Any] = field(default_factory=dict)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TransformerConfig":
if not isinstance(data, dict):
raise TypeError(f"Expected transformer config dict, got {type(data)!r}")
return cls(params=dict(data))

def to_dict(self) -> dict[str, Any]:
return dict(self.params)

def get(self, key: str, default: Any | None = None) -> Any:
return self.params.get(key, default)

def __getattr__(self, item: str) -> Any:
params = object.__getattribute__(self, "params")
try:
return params[item]
except KeyError as exc:
raise AttributeError(item) from exc


@dataclass
class OmniDiffusionConfig:
# Model and path configuration (for convenience)
Expand All @@ -23,6 +49,8 @@ class OmniDiffusionConfig:

dtype: torch.dtype = torch.bfloat16

tf_model_config: TransformerConfig = field(default_factory=TransformerConfig)

# Attention
# attention_backend: str = None

Expand Down Expand Up @@ -214,3 +242,7 @@ class AttentionBackendEnum(enum.Enum):

def __str__(self):
return self.name.lower()


# Special message broadcast via scheduler queues to signal worker shutdown.
SHUTDOWN_MESSAGE = {"type": "shutdown"}
35 changes: 34 additions & 1 deletion vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from vllm.logger import init_logger

from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, OmniDiffusionConfig
from vllm_omni.diffusion.registry import get_diffusion_post_process_func
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.scheduler import scheduler
Expand All @@ -27,6 +27,8 @@ def __init__(self, od_config: OmniDiffusionConfig):

self.post_process_func = get_diffusion_post_process_func(od_config)

self._processes: list[mp.Process] = []
self._closed = False
self._make_client()

def step(self, requests: list[OmniDiffusionRequest]):
Expand Down Expand Up @@ -68,6 +70,8 @@ def _make_client(self):
else:
logger.error("Failed to get result queue handle from workers")

self._processes = processes

def _launch_workers(self, broadcast_handle):
od_config = self.od_config
logger.info("Starting server...")
Expand Down Expand Up @@ -131,3 +135,32 @@ def _launch_workers(self, broadcast_handle):

def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
return scheduler.add_req(requests)

    def close(self, *, timeout_s: float = 30.0) -> None:
        """Shut down worker processes and release scheduler resources.

        Idempotent: subsequent calls after the first are no-ops.

        Args:
            timeout_s: seconds to wait for each worker to join, both before
                and after a forced terminate.
        """
        if self._closed:
            return
        self._closed = True

        # Send shutdown signal to worker processes via broadcast queue.
        # One message per worker; presumably each worker consumes exactly one
        # shutdown message — verify against the worker loop.
        try:
            if getattr(scheduler, "mq", None) is not None:
                for _ in range(self.od_config.num_gpus or 1):
                    scheduler.mq.enqueue(SHUTDOWN_MESSAGE)
        except Exception as exc:  # pragma: no cover - best effort cleanup
            logger.warning("Failed to send shutdown signal: %s", exc)

        # Join all worker processes, terminate if they refuse to exit
        for proc in self._processes:
            if not proc.is_alive():
                continue
            proc.join(timeout_s)
            if proc.is_alive():
                logger.warning("Terminating diffusion worker %s after timeout", proc.name)
                proc.terminate()
                proc.join(timeout_s)

        # Tear down scheduler IPC only after the workers are gone.
        scheduler.close()
        self._processes = []

def __del__(self): # pragma: no cover - best effort cleanup
self.close()
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __init__(
self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
self.transformer = QwenImageTransformer2DModel()
self.transformer = QwenImageTransformer2DModel(od_config=od_config)
self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)

self.stage = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.data import OmniDiffusionConfig

logger = init_logger(__name__)

Expand Down Expand Up @@ -486,6 +487,7 @@ class QwenImageTransformer2DModel(nn.Module):
# _repeated_blocks = ["QwenImageTransformerBlock"]
def __init__(
self,
od_config: OmniDiffusionConfig,
patch_size: int = 2,
in_channels: int = 64,
out_channels: Optional[int] = 16,
Expand All @@ -497,7 +499,8 @@ def __init__(
axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
):
super().__init__()
self.config = None
model_config = od_config.tf_model_config
num_layers = model_config.num_layers
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
Expand Down
2 changes: 1 addition & 1 deletion vllm_omni/diffusion/models/z_image/pipeline_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class ZImagePipeline(nn.Module):
def __init__(
self,
*,
od_config: OmniDiffusionConfig = None,
od_config: OmniDiffusionConfig,
prefix: str = "",
):
super().__init__()
Expand Down
16 changes: 15 additions & 1 deletion vllm_omni/diffusion/omni_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_hf_file_to_dict

from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
from vllm_omni.diffusion.request import OmniDiffusionRequest

Expand Down Expand Up @@ -53,6 +53,11 @@ def __init__(self, od_config: OmniDiffusionConfig | None = None, **kwargs):
od_config.model,
)
od_config.model_class_name = config_dict.get("_class_name", None)
tf_config_dict = get_hf_file_to_dict(
"transformer/config.json",
od_config.model,
)
od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict)

self.engine: DiffusionEngine = DiffusionEngine.make_engine(od_config)

Expand Down Expand Up @@ -82,3 +87,12 @@ def generate(

def _run_engine(self, requests: list[OmniDiffusionRequest]):
return self.engine.step(requests)

def close(self) -> None:
self.engine.close()

    def __del__(self):  # pragma: no cover - best effort cleanup
        """Best-effort cleanup on garbage collection; must never raise."""
        try:
            self.close()
        except Exception:
            # Destructor ordering during interpreter shutdown is unreliable;
            # ignore any failure from the best-effort close.
            pass
6 changes: 5 additions & 1 deletion vllm_omni/diffusion/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def __new__(cls, *args, **kwargs):
return cls._instance

def initialize(self, od_config: OmniDiffusionConfig):
if hasattr(self, "context") and not self.context.closed:
existing_context = getattr(self, "context", None)
if existing_context is not None and not existing_context.closed:
logger.warning("SyncSchedulerClient is already initialized. Re-initializing.")
self.close()

Expand Down Expand Up @@ -65,6 +66,9 @@ def close(self):
"""Closes the socket and terminates the context."""
if hasattr(self, "context"):
self.context.term()
self.context = None
self.mq = None
self.result_mq = None


# Singleton instance for easy access
Expand Down
Loading