Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
1bdef5f
[BugFix] Fix text_to_audio output handling and gated model note
LudovicoYIN Jan 19, 2026
ffc1537
[BugFix] Simplify text_to_audio output handling
LudovicoYIN Jan 19, 2026
b422593
[BugFix] Standardize StableAudio audio output and update example…
LudovicoYIN Jan 19, 2026
a586389
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 20, 2026
c93c4b4
Add SupportOutputType for StableAudio output
LudovicoYIN Jan 20, 2026
d854f3c
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 20, 2026
18da606
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 21, 2026
3bba3dd
Merge branch 'vllm-project:main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 21, 2026
808203b
Add output_type protocol and rename SupportImageInput
LudovicoYIN Jan 21, 2026
73e33c8
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 21, 2026
5be10c9
Update stable audio test for multimodal_output
LudovicoYIN Jan 21, 2026
58e43a1
Add support_audio_output flag and rename SupportImageInput
LudovicoYIN Jan 21, 2026
c433b22
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 21, 2026
1149d50
Fix lint
LudovicoYIN Jan 21, 2026
8fa6b5f
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 22, 2026
60d96df
Replace output_type with support_audio_output helper
LudovicoYIN Jan 22, 2026
ad01f2b
Merge branch 'main' into fix-text-to-audio-output-and-docs
hsliuustc0106 Jan 22, 2026
8c6d35d
Merge branch 'main' into fix-text-to-audio-output-and-docs
hsliuustc0106 Jan 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion examples/offline_inference/text_to_audio/text_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def main():
generation_start = time.perf_counter()

# Generate audio
audio = omni.generate(
outputs = omni.generate(
args.prompt,
negative_prompt=args.negative_prompt,
generator=generator,
Expand All @@ -166,6 +166,21 @@ def main():
suffix = output_path.suffix or ".wav"
stem = output_path.stem or "stable_audio_output"

# Extract audio from omni.generate() outputs
if not outputs:
raise ValueError("No output generated from omni.generate()")

output = outputs[0]
if not hasattr(output, "request_output") or not output.request_output:
raise ValueError("No request_output found in OmniRequestOutput")
request_output = output.request_output[0]
if not hasattr(request_output, "multimodal_output"):
raise ValueError("No multimodal_output found in request_output")

audio = request_output.multimodal_output.get("audio")
if audio is None:
raise ValueError("No audio output found in request_output")

# Handle different output formats
if isinstance(audio, torch.Tensor):
audio = audio.cpu().float().numpy()
Expand Down
76 changes: 54 additions & 22 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,18 @@ def step(self, requests: list[OmniDiffusionRequest]):
return None

postprocess_start_time = time.time()
images = self.post_process_func(output.output) if self.post_process_func is not None else output.output
outputs = self.post_process_func(output.output) if self.post_process_func is not None else output.output
postprocess_time = time.time() - postprocess_start_time
logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds")

# Convert to OmniRequestOutput format
# Ensure images is a list
if not isinstance(images, list):
images = [images] if images is not None else []
# Ensure outputs is a list
if not isinstance(outputs, list):
outputs = [outputs] if outputs is not None else []

model_cls = DiffusionModelRegistry._try_load_model_cls(self.od_config.model_class_name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please define a util function for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I’ve added a small util function like def supports_image_input.

output_type = getattr(model_cls, "output_type", "image") if model_cls is not None else "image"
is_audio_output = output_type == "audio"

# Handle single request or multiple requests
if len(requests) == 1:
Expand All @@ -145,18 +149,30 @@ def step(self, requests: list[OmniDiffusionRequest]):
if output.trajectory_timesteps is not None:
metrics["trajectory_timesteps"] = output.trajectory_timesteps

return OmniRequestOutput.from_diffusion(
request_id=request_id,
images=images,
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
)
if is_audio_output:
audio_payload = outputs[0] if len(outputs) == 1 else outputs
return OmniRequestOutput.from_diffusion(
request_id=request_id,
images=[],
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
multimodal_output={"audio": audio_payload},
final_output_type="audio",
)
else:
return OmniRequestOutput.from_diffusion(
request_id=request_id,
images=outputs,
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
)
else:
# Multiple requests: return list of OmniRequestOutput
# Split images based on num_outputs_per_prompt for each request
results = []
image_idx = 0
output_idx = 0

for request in requests:
request_id = request.request_id or ""
Expand All @@ -166,22 +182,38 @@ def step(self, requests: list[OmniDiffusionRequest]):

# Get images for this request
num_outputs = request.num_outputs_per_prompt
request_images = images[image_idx : image_idx + num_outputs] if image_idx < len(images) else []
image_idx += num_outputs
request_outputs = (
outputs[output_idx : output_idx + num_outputs] if output_idx < len(outputs) else []
)
output_idx += num_outputs

metrics = {}
if output.trajectory_timesteps is not None:
metrics["trajectory_timesteps"] = output.trajectory_timesteps

results.append(
OmniRequestOutput.from_diffusion(
request_id=request_id,
images=request_images,
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
if is_audio_output:
audio_payload = request_outputs[0] if len(request_outputs) == 1 else request_outputs
results.append(
OmniRequestOutput.from_diffusion(
request_id=request_id,
images=[],
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
multimodal_output={"audio": audio_payload},
final_output_type="audio",
)
)
else:
results.append(
OmniRequestOutput.from_diffusion(
request_id=request_id,
images=request_outputs,
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
)
)
)

return results
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import (
Flux2Transformer2DModel,
)
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.interface import SupportInputType
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
Expand Down Expand Up @@ -178,7 +178,7 @@ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
return float(mu)


class Flux2KleinPipeline(nn.Module, SupportImageInput):
class Flux2KleinPipeline(nn.Module, SupportInputType):
"""Flux2 klein pipeline for text-to-image generation."""

support_image_input = True
Expand Down
7 changes: 6 additions & 1 deletion vllm_omni/diffusion/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,10 @@


@runtime_checkable
class SupportInputType(Protocol):
    """Structural protocol for pipelines that declare their supported input type.

    Marked ``@runtime_checkable`` so conformance can be tested with
    ``isinstance``; any class exposing a ``support_image_input`` class
    attribute matches structurally.
    """

    # True marks the pipeline as accepting image inputs; conforming
    # classes may override this to opt out.
    support_image_input: ClassVar[bool] = True


@runtime_checkable
class SupportOutputType(Protocol):
    """Structural protocol for pipelines that declare their output modality.

    Marked ``@runtime_checkable`` so conformance can be tested with
    ``isinstance``; any class exposing an ``output_type`` class attribute
    matches structurally.
    """

    # Output modality identifier; defaults to "image". Pipelines producing
    # other modalities (e.g. audio) override this value.
    output_type: ClassVar[str] = "image"
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.interface import SupportInputType
from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import (
LongCatImageTransformer2DModel,
)
Expand Down Expand Up @@ -197,7 +197,7 @@ def split_quotation(prompt, quote_pairs=None):
return result


class LongCatImageEditPipeline(nn.Module, SupportImageInput):
class LongCatImageEditPipeline(nn.Module, SupportInputType):
def __init__(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.interface import SupportInputType
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift
from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import (
QwenImageTransformer2DModel,
Expand Down Expand Up @@ -197,7 +197,7 @@ def retrieve_latents(

class QwenImageEditPipeline(
nn.Module,
SupportImageInput,
SupportInputType,
):
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.interface import SupportInputType
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import calculate_shift
from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import (
calculate_dimensions,
Expand Down Expand Up @@ -158,7 +158,7 @@ def post_process_func(
return post_process_func


class QwenImageEditPlusPipeline(nn.Module, SupportImageInput):
class QwenImageEditPlusPipeline(nn.Module, SupportInputType):
def __init__(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.interface import SupportInputType
from vllm_omni.diffusion.models.qwen_image.autoencoder_kl_qwenimage import (
AutoencoderKLQwenImage,
)
Expand Down Expand Up @@ -172,7 +172,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")


class QwenImageLayeredPipeline(nn.Module, SupportImageInput):
class QwenImageLayeredPipeline(nn.Module, SupportInputType):
def __init__(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportOutputType
from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import StableAudioDiTModel
from vllm_omni.diffusion.request import OmniDiffusionRequest

Expand Down Expand Up @@ -57,7 +58,7 @@ def post_process_func(
return post_process_func


class StableAudioPipeline(nn.Module):
class StableAudioPipeline(nn.Module, SupportOutputType):
"""
Pipeline for text-to-audio generation using Stable Audio Open.

Expand All @@ -69,6 +70,8 @@ class StableAudioPipeline(nn.Module):
prefix: Weight prefix for loading (default: "")
"""

output_type: str = "audio"

def __init__(
self,
*,
Expand Down
4 changes: 2 additions & 2 deletions vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.interface import SupportInputType
from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
create_transformer_from_config,
Expand Down Expand Up @@ -114,7 +114,7 @@ def pre_process_func(requests: list[OmniDiffusionRequest]) -> list[OmniDiffusion
return pre_process_func


class Wan22I2VPipeline(nn.Module, SupportImageInput):
class Wan22I2VPipeline(nn.Module, SupportInputType):
"""
Wan2.2 Image-to-Video Pipeline.

Expand Down
4 changes: 2 additions & 2 deletions vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.interface import SupportInputType
from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
create_transformer_from_config,
Expand Down Expand Up @@ -104,7 +104,7 @@ def pre_process_func(requests: list[OmniDiffusionRequest]) -> list[OmniDiffusion
return pre_process_func


class Wan22TI2VPipeline(nn.Module, SupportImageInput):
class Wan22TI2VPipeline(nn.Module, SupportInputType):
"""
Wan2.2 Text-Image-to-Video (TI2V) Pipeline.

Expand Down
7 changes: 6 additions & 1 deletion vllm_omni/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class OmniRequestOutput:
prompt: str | None = None
latents: torch.Tensor | None = None
metrics: dict[str, Any] = field(default_factory=dict)
multimodal_output: dict[str, Any] = field(default_factory=dict)

@classmethod
def from_pipeline(
Expand Down Expand Up @@ -88,6 +89,8 @@ def from_diffusion(
prompt: str | None = None,
metrics: dict[str, Any] | None = None,
latents: torch.Tensor | None = None,
multimodal_output: dict[str, Any] | None = None,
final_output_type: str = "image",
) -> "OmniRequestOutput":
"""Create output from diffusion model.

Expand All @@ -103,11 +106,12 @@ def from_diffusion(
"""
return cls(
request_id=request_id,
final_output_type="image",
final_output_type=final_output_type,
images=images,
prompt=prompt,
latents=latents,
metrics=metrics or {},
multimodal_output=multimodal_output or {},
finished=True,
)

Expand Down Expand Up @@ -168,6 +172,7 @@ def __repr__(self) -> str:
f"prompt={self.prompt!r}",
f"latents={self.latents}",
f"metrics={self.metrics}",
f"multimodal_output={self.multimodal_output}",
]

return f"OmniRequestOutput({', '.join(parts)})"