Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/getting_started/installation/npu/npu.inc.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export DEVICE1=/dev/davinci1
export IMAGE=quay.io/ascend/vllm-ascend:v0.11.0rc2
docker run --rm \
--name vllm-omni-npu \
--shm-size=1g \
--device $DEVICE0 \
--device $DEVICE1 \
--device /dev/davinci_manager \
Expand Down
11 changes: 9 additions & 2 deletions examples/offline_inference/qwen_image/gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type
from vllm_omni.utils.platform_utils import detect_device_type, is_npu

ASPECT_RATIOS: dict[str, tuple[int, int]] = {
"1:1": (1328, 1328),
Expand Down Expand Up @@ -59,7 +59,14 @@ def parse_args() -> argparse.Namespace:

@lru_cache(maxsize=1)
def get_omni(model_name: str) -> Omni:
    """Build and cache a single :class:`Omni` engine for *model_name*.

    ``maxsize=1`` means only the most recently requested model is kept
    alive; asking for a different model drops the previous engine.

    On NPU devices, VAE slicing and tiling are enabled to reduce peak
    memory usage during image decoding (no effect on other devices).
    """
    # Query the platform once; both VAE knobs share the same condition.
    on_npu = is_npu()
    return Omni(
        model=model_name,
        vae_use_slicing=on_npu,
        vae_use_tiling=on_npu,
    )


def build_demo(args: argparse.Namespace) -> gr.Blocks:
Expand Down
12 changes: 10 additions & 2 deletions examples/offline_inference/qwen_image/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type
from vllm_omni.utils.platform_utils import detect_device_type, is_npu


def parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -49,7 +49,15 @@ def main():
device = detect_device_type()
generator = torch.Generator(device=device).manual_seed(args.seed)

omni = Omni(model=args.model)
# Enable VAE memory optimizations on NPU
vae_use_slicing = is_npu()
vae_use_tiling = is_npu()

omni = Omni(
model=args.model,
vae_use_slicing=vae_use_slicing,
vae_use_tiling=vae_use_tiling,
)
images = omni.generate(
args.prompt,
height=args.height,
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class OmniDiffusionConfig:
vae_cpu_offload: bool = True
pin_cpu_memory: bool = True

# VAE memory optimization parameters
vae_use_slicing: bool = False
vae_use_tiling: bool = False

# STA (Sliding Tile Attention) parameters
mask_strategy_file_path: str | None = None
# STA_mode: STA_Mode = STA_Mode.STA_INFERENCE
Expand Down
7 changes: 7 additions & 0 deletions vllm_omni/diffusion/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def initialize_model(
module = importlib.import_module(module_name)
model_class = getattr(module, cls_name)
model = model_class(od_config=od_config, prefix=mod_relname)

# Configure VAE memory optimization settings from config
if od_config.vae_use_slicing:
model.vae.use_slicing = True
if od_config.vae_use_tiling:
model.vae.use_tiling = True

return model
else:
raise ValueError(f"Model class {od_config.model_class_name} not found in diffusion model registry.")
Expand Down
10 changes: 8 additions & 2 deletions vllm_omni/diffusion/worker/npu/npu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import torch
import zmq
from vllm.config import VllmConfig, set_current_vllm_config
from transformers import PretrainedConfig
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.parallel_state import (
init_distributed_environment,
Expand Down Expand Up @@ -53,7 +54,12 @@ def init_device_and_model(self) -> None:
torch.npu.set_device(device)

# hack
vllm_config = VllmConfig()
# set hf_config to a fake stand-in so arbitrary attribute lookups do not raise
class FakePretrainedConfig(PretrainedConfig):
def __getattr__(self, name):
return "fake"

vllm_config = VllmConfig(model_config=ModelConfig(hf_config=FakePretrainedConfig()))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Instantiate ModelConfig without required model

NPU worker initialization now builds ModelConfig(hf_config=FakePretrainedConfig()) without supplying the mandatory model path. ModelConfig’s constructor/post-init requires a model identifier and will raise a TypeError before set_current_vllm_config runs, so the NPU worker process crashes during startup and diffusion on NPU never initializes.

Useful? React with 👍 / 👎.

vllm_config.parallel_config.tensor_parallel_size = self.od_config.num_gpus
set_current_vllm_config(vllm_config)

Expand Down
Loading