diff --git a/examples/offline_inference/text_to_audio/text_to_audio.py b/examples/offline_inference/text_to_audio/text_to_audio.py
index 0b172a0773..ca1a455d14 100644
--- a/examples/offline_inference/text_to_audio/text_to_audio.py
+++ b/examples/offline_inference/text_to_audio/text_to_audio.py
@@ -142,7 +142,7 @@ def main():
     generation_start = time.perf_counter()

     # Generate audio
-    audio = omni.generate(
+    outputs = omni.generate(
         args.prompt,
         negative_prompt=args.negative_prompt,
         generator=generator,
@@ -166,6 +166,21 @@ def main():
     suffix = output_path.suffix or ".wav"
     stem = output_path.stem or "stable_audio_output"

+    # Extract audio from omni.generate() outputs
+    if not outputs:
+        raise ValueError("No output generated from omni.generate()")
+
+    output = outputs[0]
+    if not hasattr(output, "request_output") or not output.request_output:
+        raise ValueError("No request_output found in OmniRequestOutput")
+    request_output = output.request_output[0]
+    if not hasattr(request_output, "multimodal_output"):
+        raise ValueError("No multimodal_output found in request_output")
+
+    audio = request_output.multimodal_output.get("audio")
+    if audio is None:
+        raise ValueError("No audio output found in request_output")
+
     # Handle different output formats
     if isinstance(audio, torch.Tensor):
         audio = audio.cpu().float().numpy()
diff --git a/tests/e2e/offline_inference/test_stable_audio_model.py b/tests/e2e/offline_inference/test_stable_audio_model.py
index 6485625733..729b501e49 100644
--- a/tests/e2e/offline_inference/test_stable_audio_model.py
+++ b/tests/e2e/offline_inference/test_stable_audio_model.py
@@ -44,15 +44,14 @@ def test_stable_audio_model(model_name: str):
     # Extract audio from OmniRequestOutput
     assert outputs is not None
     first_output = outputs[0]
-    assert first_output.final_output_type == "image"  # Generic output type
+    assert first_output.final_output_type == "image"
    assert hasattr(first_output, "request_output") and first_output.request_output
     req_out = first_output.request_output[0]
     assert isinstance(req_out, OmniRequestOutput)
-    assert hasattr(req_out, "images") and len(req_out.images) >= 1
-
-    # For stable audio, the "images" field contains audio numpy arrays
-    audio = req_out.images[0]
+    assert req_out.final_output_type == "audio"
+    assert hasattr(req_out, "multimodal_output") and req_out.multimodal_output
+    audio = req_out.multimodal_output.get("audio")
     assert isinstance(audio, np.ndarray)
     # audio shape: (batch, channels, samples)
     # For stable-audio-open-1.0: sample_rate=44100, so 2 seconds = 88200 samples
diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py
index 731ed425ef..e85ae0d76b 100644
--- a/vllm_omni/diffusion/diffusion_engine.py
+++ b/vllm_omni/diffusion/diffusion_engine.py
@@ -29,6 +29,13 @@ def supports_image_input(model_class_name: str) -> bool:
     return bool(getattr(model_cls, "support_image_input", False))


+def supports_audio_output(model_class_name: str) -> bool:
+    model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name)
+    if model_cls is None:
+        return False
+    return bool(getattr(model_cls, "support_audio_output", False))
+
+
 class DiffusionEngine:
     """The diffusion engine for vLLM-Omni diffusion models."""

@@ -86,14 +93,14 @@ def step(self, requests: list[OmniDiffusionRequest]):
                 return None

             postprocess_start_time = time.time()
-            images = self.post_process_func(output.output) if self.post_process_func is not None else output.output
+            outputs = self.post_process_func(output.output) if self.post_process_func is not None else output.output
             postprocess_time = time.time() - postprocess_start_time
             logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds")

             # Convert to OmniRequestOutput format
-            # Ensure images is a list
-            if not isinstance(images, list):
-                images = [images] if images is not None else []
+            # Ensure outputs is a list
+            if not isinstance(outputs, list):
+                outputs = [outputs] if outputs is not None else []

             # Handle single request or multiple requests
             if len(requests) == 1:
@@ -108,18 +115,30 @@ def step(self, requests: list[OmniDiffusionRequest]):
                 if output.trajectory_timesteps is not None:
                     metrics["trajectory_timesteps"] = output.trajectory_timesteps

-                return OmniRequestOutput.from_diffusion(
-                    request_id=request_id,
-                    images=images,
-                    prompt=prompt,
-                    metrics=metrics,
-                    latents=output.trajectory_latents,
-                )
+                if supports_audio_output(self.od_config.model_class_name):
+                    audio_payload = outputs[0] if len(outputs) == 1 else outputs
+                    return OmniRequestOutput.from_diffusion(
+                        request_id=request_id,
+                        images=[],
+                        prompt=prompt,
+                        metrics=metrics,
+                        latents=output.trajectory_latents,
+                        multimodal_output={"audio": audio_payload},
+                        final_output_type="audio",
+                    )
+                else:
+                    return OmniRequestOutput.from_diffusion(
+                        request_id=request_id,
+                        images=outputs,
+                        prompt=prompt,
+                        metrics=metrics,
+                        latents=output.trajectory_latents,
+                    )
             else:
                 # Multiple requests: return list of OmniRequestOutput
                 # Split images based on num_outputs_per_prompt for each request
                 results = []
-                image_idx = 0
+                output_idx = 0

                 for request in requests:
                     request_id = request.request_id or ""
@@ -129,22 +148,38 @@ def step(self, requests: list[OmniDiffusionRequest]):

                     # Get images for this request
                     num_outputs = request.num_outputs_per_prompt
-                    request_images = images[image_idx : image_idx + num_outputs] if image_idx < len(images) else []
-                    image_idx += num_outputs
+                    request_outputs = (
+                        outputs[output_idx : output_idx + num_outputs] if output_idx < len(outputs) else []
+                    )
+                    output_idx += num_outputs

                     metrics = {}
                     if output.trajectory_timesteps is not None:
                         metrics["trajectory_timesteps"] = output.trajectory_timesteps

-                    results.append(
-                        OmniRequestOutput.from_diffusion(
-                            request_id=request_id,
-                            images=request_images,
-                            prompt=prompt,
-                            metrics=metrics,
-                            latents=output.trajectory_latents,
+                    if supports_audio_output(self.od_config.model_class_name):
+                        audio_payload = request_outputs[0] if len(request_outputs) == 1 else request_outputs
+                        results.append(
+                            OmniRequestOutput.from_diffusion(
+                                request_id=request_id,
+                                images=[],
+                                prompt=prompt,
+                                metrics=metrics,
+                                latents=output.trajectory_latents,
+                                multimodal_output={"audio": audio_payload},
+                                final_output_type="audio",
+                            )
+                        )
+                    else:
+                        results.append(
+                            OmniRequestOutput.from_diffusion(
+                                request_id=request_id,
+                                images=request_outputs,
+                                prompt=prompt,
+                                metrics=metrics,
+                                latents=output.trajectory_latents,
+                            )
                         )
-                    )

                 return results

         except Exception as e:
diff --git a/vllm_omni/diffusion/models/interface.py b/vllm_omni/diffusion/models/interface.py
index 7c4cb33e6a..8e8ecbb776 100644
--- a/vllm_omni/diffusion/models/interface.py
+++ b/vllm_omni/diffusion/models/interface.py
@@ -10,3 +10,8 @@
 @runtime_checkable
 class SupportImageInput(Protocol):
     support_image_input: ClassVar[bool] = True
+
+
+@runtime_checkable
+class SupportAudioOutput(Protocol):
+    support_audio_output: ClassVar[bool] = True
diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py
index 2639f55eff..14516a9bb3 100644
--- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py
+++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py
@@ -27,6 +27,7 @@
 from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
 from vllm_omni.diffusion.distributed.utils import get_local_device
 from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.models.interface import SupportAudioOutput
 from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import StableAudioDiTModel
 from vllm_omni.diffusion.request import OmniDiffusionRequest

@@ -57,7 +58,7 @@ def post_process_func(
     return post_process_func


-class StableAudioPipeline(nn.Module):
+class StableAudioPipeline(nn.Module, SupportAudioOutput):
     """
     Pipeline for text-to-audio generation using Stable Audio Open.

diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py
index 4e0583d77e..24b1682870 100644
--- a/vllm_omni/outputs.py
+++ b/vllm_omni/outputs.py
@@ -57,6 +57,7 @@ class OmniRequestOutput:
     prompt: str | None = None
     latents: torch.Tensor | None = None
     metrics: dict[str, Any] = field(default_factory=dict)
+    multimodal_output: dict[str, Any] = field(default_factory=dict)

     @classmethod
     def from_pipeline(
@@ -91,6 +92,8 @@ def from_diffusion(
         prompt: str | None = None,
         metrics: dict[str, Any] | None = None,
         latents: torch.Tensor | None = None,
+        multimodal_output: dict[str, Any] | None = None,
+        final_output_type: str = "image",
     ) -> "OmniRequestOutput":
         """Create output from diffusion model.

@@ -106,11 +109,12 @@
         """
         return cls(
             request_id=request_id,
-            final_output_type="image",
+            final_output_type=final_output_type,
             images=images,
             prompt=prompt,
             latents=latents,
             metrics=metrics or {},
+            multimodal_output=multimodal_output or {},
             finished=True,
         )

@@ -171,6 +175,7 @@ def __repr__(self) -> str:
             f"prompt={self.prompt!r}",
             f"latents={self.latents}",
             f"metrics={self.metrics}",
+            f"multimodal_output={self.multimodal_output}",
         ]
         return f"OmniRequestOutput({', '.join(parts)})"
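
Note for downstream consumers of this change: audio results no longer ride in `images`; they arrive under `multimodal_output["audio"]` with `final_output_type == "audio"`, while image pipelines keep the old defaults (`final_output_type="image"`, empty `multimodal_output`). Below is a minimal extraction sketch that mirrors the hunk added to `text_to_audio.py` above; the helper name `extract_audio` is hypothetical and not part of this PR.

```python
import numpy as np
import torch

from vllm_omni.outputs import OmniRequestOutput


def extract_audio(outputs) -> np.ndarray:
    """Pull the generated audio array out of omni.generate() results.

    Mirrors the validation added in text_to_audio.py: `outputs` is a list
    whose first element wraps a list of OmniRequestOutput objects, and the
    audio now lives under multimodal_output["audio"] rather than images.
    """
    if not outputs:
        raise ValueError("No output generated from omni.generate()")

    request_output = outputs[0].request_output[0]
    assert isinstance(request_output, OmniRequestOutput)
    assert request_output.final_output_type == "audio"

    audio = request_output.multimodal_output.get("audio")
    if audio is None:
        raise ValueError("No audio output found in request_output")

    # Post-processing may hand back a tensor; normalize to numpy,
    # using the same dtype handling as the example script.
    if isinstance(audio, torch.Tensor):
        audio = audio.cpu().float().numpy()
    return audio  # shape: (batch, channels, samples)
```

Dispatch relies on the class-level `support_audio_output` flag declared by the `SupportAudioOutput` protocol, so a future audio pipeline only needs to inherit the protocol for `supports_audio_output()` to route its outputs into `multimodal_output`.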