Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/offline_inference/image_to_image/image_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
import torch
from PIL import Image

from vllm_omni.diffusion.data import DiffusionParallelConfig, logger
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
Expand Down Expand Up @@ -428,7 +428,6 @@ def main():

if not outputs:
raise ValueError("No output generated from omni.generate()")
logger.info("Outputs: %s", outputs)

# Extract images from OmniRequestOutput
# omni.generate() returns list[OmniRequestOutput], extract images from request_output[0].images
Expand Down
9 changes: 7 additions & 2 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def supports_image_input(model_class_name: str) -> bool:
return bool(getattr(model_cls, "support_image_input", False))


def image_color_format(model_class_name: str) -> str:
    """Return the PIL color mode (e.g. "RGB", "RGBA") declared by a diffusion model class.

    Looks up the model class in ``DiffusionModelRegistry`` and reads its
    ``color_format`` class attribute, falling back to ``"RGB"`` when the
    class does not declare one.

    Args:
        model_class_name: Registry name of the diffusion model class.

    Returns:
        The model's declared color format, or ``"RGB"`` if the class is
        not found or does not define ``color_format``.
    """
    model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name)
    # Explicit None guard for consistency with supports_audio_output();
    # an unknown model class degrades to the "RGB" default rather than raising.
    if model_cls is None:
        return "RGB"
    return getattr(model_cls, "color_format", "RGB")


def supports_audio_output(model_class_name: str) -> bool:
model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name)
if model_cls is None:
Expand Down Expand Up @@ -315,8 +320,8 @@ def _dummy_run(self):
width = 1024
if supports_image_input(self.od_config.model_class_name):
# Provide a dummy image input if the model supports it

dummy_image = PIL.Image.new("RGB", (width, height), color=(0, 0, 0))
color_format = image_color_format(self.od_config.model_class_name)
dummy_image = PIL.Image.new(color_format, (width, height))
else:
dummy_image = None
prompt: OmniTextPrompt = {"prompt": "dummy run", "multi_modal_data": {"image": dummy_image}}
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/diffusion/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
@runtime_checkable
class SupportImageInput(Protocol):
    """Structural (duck-typed) interface for diffusion models that accept image inputs.

    Classes implementing this protocol set ``support_image_input`` and may
    override ``color_format`` to declare the PIL color mode their image
    inputs use (e.g. ``"RGBA"`` for layered pipelines).
    """

    # Marker flag read via getattr() by supports_image_input().
    support_image_input: ClassVar[bool] = True
    # Default PIL color mode; implementations may override (e.g. "RGBA").
    color_format: ClassVar[str] = "RGB"


@runtime_checkable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def retrieve_latents(


class QwenImageLayeredPipeline(nn.Module, SupportImageInput):
color_format = "RGBA"

def __init__(
self,
*,
Expand Down
Loading