Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/getting_started/installation/npu/npu.inc.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export DEVICE1=/dev/davinci1
export IMAGE=quay.io/ascend/vllm-ascend:v0.11.0rc2
docker run --rm \
--name vllm-omni-npu \
--shm-size=1g \
--device $DEVICE0 \
--device $DEVICE1 \
--device /dev/davinci_manager \
Expand Down
11 changes: 9 additions & 2 deletions examples/offline_inference/qwen_image/gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type
from vllm_omni.utils.platform_utils import detect_device_type, is_npu

ASPECT_RATIOS: dict[str, tuple[int, int]] = {
"1:1": (1328, 1328),
Expand Down Expand Up @@ -59,7 +59,14 @@ def parse_args() -> argparse.Namespace:

@lru_cache(maxsize=1)
def get_omni(model_name: str) -> Omni:
return Omni(model=model_name)
# Enable VAE memory optimizations on NPU
vae_use_slicing = is_npu()
vae_use_tiling = is_npu()
return Omni(
model=model_name,
vae_use_slicing=vae_use_slicing,
vae_use_tiling=vae_use_tiling,
)


def build_demo(args: argparse.Namespace) -> gr.Blocks:
Expand Down
12 changes: 10 additions & 2 deletions examples/offline_inference/qwen_image/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.utils.platform_utils import detect_device_type
from vllm_omni.utils.platform_utils import detect_device_type, is_npu


def parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -49,7 +49,15 @@ def main():
device = detect_device_type()
generator = torch.Generator(device=device).manual_seed(args.seed)

omni = Omni(model=args.model)
# Enable VAE memory optimizations on NPU
vae_use_slicing = is_npu()
vae_use_tiling = is_npu()

omni = Omni(
model=args.model,
vae_use_slicing=vae_use_slicing,
vae_use_tiling=vae_use_tiling,
)
images = omni.generate(
args.prompt,
height=args.height,
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class OmniDiffusionConfig:
vae_cpu_offload: bool = True
pin_cpu_memory: bool = True

# VAE memory optimization parameters
vae_use_slicing: bool = False
vae_use_tiling: bool = False

# STA (Sliding Tile Attention) parameters
mask_strategy_file_path: str | None = None
# STA_mode: STA_Mode = STA_Mode.STA_INFERENCE
Expand Down
5 changes: 5 additions & 0 deletions vllm_omni/diffusion/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def initialize_model(
module = importlib.import_module(module_name)
model_class = getattr(module, cls_name)
model = model_class(od_config=od_config, prefix=mod_relname)

# Configure VAE memory optimization settings from config
model.vae.use_slicing = od_config.vae_use_slicing
model.vae.use_tiling = od_config.vae_use_tiling

return model
else:
raise ValueError(f"Model class {od_config.model_class_name} not found in diffusion model registry.")
Expand Down
10 changes: 8 additions & 2 deletions vllm_omni/diffusion/worker/npu/npu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import torch
import zmq
from vllm.config import VllmConfig, set_current_vllm_config
from transformers import PretrainedConfig
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.parallel_state import (
init_distributed_environment,
Expand Down Expand Up @@ -53,7 +54,12 @@ def init_device_and_model(self) -> None:
torch.npu.set_device(device)

# hack
vllm_config = VllmConfig()
        # set hf_config to a fake one to avoid getattr errors
class _FakePretrainedConfig(PretrainedConfig):
def __getattr__(self, name):
return "fake"

vllm_config = VllmConfig(model_config=ModelConfig(hf_config=_FakePretrainedConfig()))
vllm_config.parallel_config.tensor_parallel_size = self.od_config.num_gpus
set_current_vllm_config(vllm_config)

Expand Down