Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions vllm_omni/diffusion/models/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,10 +557,10 @@ def diffuse(
def forward(
self,
req: OmniDiffusionRequest,
prompt: str | list[str] = None,
prompt_2: str | list[str] = None,
negative_prompt: str | list[str] = None,
negative_prompt_2: str | list[str] = None,
prompt: str | list[str] | None = None,
prompt_2: str | list[str] | None = None,
negative_prompt: str | list[str] | None = None,
negative_prompt_2: str | list[str] | None = None,
true_cfg_scale: float = 1.0,
height: int | None = None,
width: int | None = None,
Expand All @@ -581,16 +581,27 @@ def forward(
max_sequence_length: int = 512,
):
"""Forward pass for flux."""
prompt = req.prompt if req.prompt is not None else prompt
negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt
height = req.height or self.default_sample_size * self.vae_scale_factor
width = req.width or self.default_sample_size * self.vae_scale_factor
sigmas = req.sigmas or sigmas
num_inference_steps = req.num_inference_steps or num_inference_steps
generator = req.generator or generator
req_num_outputs = getattr(req, "num_outputs_per_prompt", None)
if req_num_outputs and req_num_outputs > 0:
num_images_per_prompt = req_num_outputs
# TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "")
# TODO: May be some data formatting operations on the API side. Hack for now.
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt
if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts):
negative_prompt = None
elif req.prompts:
negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts]

height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor
width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor
num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
sigmas = req.sampling_params.sigmas or sigmas
guidance_scale = (
req.sampling_params.guidance_scale if req.sampling_params.guidance_scale is not None else guidance_scale
)
generator = req.sampling_params.generator or generator
num_images_per_prompt = (
req.sampling_params.num_outputs_per_prompt
if req.sampling_params.num_outputs_per_prompt > 0
else num_images_per_prompt
)

# 1. Check inputs. Raise error if not correct
self.check_inputs(
Expand Down