Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
1bdef5f
[BugFix] Fix text_to_audio output handling and gated model note
LudovicoYIN Jan 19, 2026
ffc1537
[BugFix] Simplify text_to_audio output handling
LudovicoYIN Jan 19, 2026
b422593
[BugFix] Standardize StableAudio audio output and update example…
LudovicoYIN Jan 19, 2026
a586389
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 20, 2026
c93c4b4
Add SupportOutputType for StableAudio output
LudovicoYIN Jan 20, 2026
d854f3c
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 20, 2026
18da606
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 21, 2026
3bba3dd
Merge branch 'vllm-project:main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 21, 2026
808203b
Add output_type protocol and rename SupportImageInput
LudovicoYIN Jan 21, 2026
73e33c8
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 21, 2026
5be10c9
Update stable audio test for multimodal_output
LudovicoYIN Jan 21, 2026
58e43a1
Add support_audio_output flag and rename SupportImageInput
LudovicoYIN Jan 21, 2026
c433b22
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 21, 2026
1149d50
Fix lint
LudovicoYIN Jan 21, 2026
8fa6b5f
Merge branch 'main' into fix-text-to-audio-output-and-docs
LudovicoYIN Jan 22, 2026
60d96df
Replace output_type with support_audio_output helper
LudovicoYIN Jan 22, 2026
ad01f2b
Merge branch 'main' into fix-text-to-audio-output-and-docs
hsliuustc0106 Jan 22, 2026
8c6d35d
Merge branch 'main' into fix-text-to-audio-output-and-docs
hsliuustc0106 Jan 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion examples/offline_inference/text_to_audio/text_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def main():
generation_start = time.perf_counter()

# Generate audio
audio = omni.generate(
outputs = omni.generate(
args.prompt,
negative_prompt=args.negative_prompt,
generator=generator,
Expand All @@ -166,6 +166,21 @@ def main():
suffix = output_path.suffix or ".wav"
stem = output_path.stem or "stable_audio_output"

# Extract audio from omni.generate() outputs
if not outputs:
raise ValueError("No output generated from omni.generate()")

output = outputs[0]
if not hasattr(output, "request_output") or not output.request_output:
raise ValueError("No request_output found in OmniRequestOutput")
request_output = output.request_output[0]
if not hasattr(request_output, "multimodal_output"):
raise ValueError("No multimodal_output found in request_output")

audio = request_output.multimodal_output.get("audio")
if audio is None:
raise ValueError("No audio output found in request_output")

# Handle different output formats
if isinstance(audio, torch.Tensor):
audio = audio.cpu().float().numpy()
Expand Down
60 changes: 45 additions & 15 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def step(self, requests: list[OmniDiffusionRequest]):
if not isinstance(images, list):
images = [images] if images is not None else []

model_cls = DiffusionModelRegistry._try_load_model_cls(self.od_config.model_class_name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please define a util function for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I’ve added a small util function like def supports_image_input.

output_type = getattr(model_cls, "output_type", "image") if model_cls is not None else "image"
is_audio_output = output_type == "audio"

# Handle single request or multiple requests
if len(requests) == 1:
# Single request: return single OmniRequestOutput
Expand All @@ -145,13 +149,25 @@ def step(self, requests: list[OmniDiffusionRequest]):
if output.trajectory_timesteps is not None:
metrics["trajectory_timesteps"] = output.trajectory_timesteps

return OmniRequestOutput.from_diffusion(
request_id=request_id,
images=images,
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
)
if is_audio_output:
audio_payload = images[0] if len(images) == 1 else images
return OmniRequestOutput.from_diffusion(
request_id=request_id,
images=[],
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
multimodal_output={"audio": audio_payload},
final_output_type="audio",
)
else:
return OmniRequestOutput.from_diffusion(
request_id=request_id,
images=images,
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
)
else:
# Multiple requests: return list of OmniRequestOutput
# Split images based on num_outputs_per_prompt for each request
Expand All @@ -173,15 +189,29 @@ def step(self, requests: list[OmniDiffusionRequest]):
if output.trajectory_timesteps is not None:
metrics["trajectory_timesteps"] = output.trajectory_timesteps

results.append(
OmniRequestOutput.from_diffusion(
request_id=request_id,
images=request_images,
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
if is_audio_output:
audio_payload = request_images[0] if len(request_images) == 1 else request_images
results.append(
OmniRequestOutput.from_diffusion(
request_id=request_id,
images=[],
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
multimodal_output={"audio": audio_payload},
final_output_type="audio",
)
)
else:
results.append(
OmniRequestOutput.from_diffusion(
request_id=request_id,
images=request_images,
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
)
)
)

return results
except Exception as e:
Expand Down
5 changes: 5 additions & 0 deletions vllm_omni/diffusion/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@
@runtime_checkable
class SupportImageInput(Protocol):
    """Structural marker protocol for diffusion models that accept image inputs.

    Because of ``@runtime_checkable``, ``isinstance(model, SupportImageInput)``
    succeeds for any class that defines the ``support_image_input`` attribute.
    """

    # Class-level capability flag; conforming models set this to True.
    support_image_input: ClassVar[bool] = True


@runtime_checkable
class SupportOutputType(Protocol):
    """Structural marker protocol for models that declare their output modality.

    Because of ``@runtime_checkable``, ``isinstance(model, SupportOutputType)``
    succeeds for any class that defines an ``output_type`` attribute.
    Defaults to ``"image"``; implementations may override it (e.g. with
    ``"audio"``) to signal a different output modality.
    """

    # Modality of the model's generated output; "image" unless overridden.
    output_type: ClassVar[str] = "image"
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportOutputType
from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import StableAudioDiTModel
from vllm_omni.diffusion.request import OmniDiffusionRequest

Expand Down Expand Up @@ -57,7 +58,7 @@ def post_process_func(
return post_process_func


class StableAudioPipeline(nn.Module):
class StableAudioPipeline(nn.Module, SupportOutputType):
"""
Pipeline for text-to-audio generation using Stable Audio Open.

Expand All @@ -69,6 +70,8 @@ class StableAudioPipeline(nn.Module):
prefix: Weight prefix for loading (default: "")
"""

output_type: str = "audio"

def __init__(
self,
*,
Expand Down
7 changes: 6 additions & 1 deletion vllm_omni/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class OmniRequestOutput:
prompt: str | None = None
latents: torch.Tensor | None = None
metrics: dict[str, Any] = field(default_factory=dict)
multimodal_output: dict[str, Any] = field(default_factory=dict)

@classmethod
def from_pipeline(
Expand Down Expand Up @@ -88,6 +89,8 @@ def from_diffusion(
prompt: str | None = None,
metrics: dict[str, Any] | None = None,
latents: torch.Tensor | None = None,
multimodal_output: dict[str, Any] | None = None,
final_output_type: str = "image",
) -> "OmniRequestOutput":
"""Create output from diffusion model.

Expand All @@ -103,11 +106,12 @@ def from_diffusion(
"""
return cls(
request_id=request_id,
final_output_type="image",
final_output_type=final_output_type,
images=images,
prompt=prompt,
latents=latents,
metrics=metrics or {},
multimodal_output=multimodal_output or {},
finished=True,
)

Expand Down Expand Up @@ -168,6 +172,7 @@ def __repr__(self) -> str:
f"prompt={self.prompt!r}",
f"latents={self.latents}",
f"metrics={self.metrics}",
f"multimodal_output={self.multimodal_output}",
]

return f"OmniRequestOutput({', '.join(parts)})"