Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion examples/offline_inference/image_to_image/image_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,15 @@
--prompt "Edit description" \
--cfg_parallel_size 2 \
--num_inference_steps 50 \
--cfg_scale 4.0 \
--cfg_scale 4.0

Usage (disable torch.compile):
python image_edit.py \
--image input.png \
--prompt "Edit description" \
--enforce_eager \
--num_inference_steps 50 \
--cfg_scale 4.0

For more options, run:
python image_edit.py --help
Expand Down Expand Up @@ -260,6 +268,11 @@ def parse_args() -> argparse.Namespace:
choices=[1, 2],
help="Number of GPUs used for classifier free guidance parallel size.",
)
parser.add_argument(
"--enforce_eager",
action="store_true",
help="Disable torch.compile and force eager execution.",
)
return parser.parse_args()


Expand Down Expand Up @@ -321,6 +334,7 @@ def main():
cache_backend=args.cache_backend,
cache_config=cache_config,
parallel_config=parallel_config,
enforce_eager=args.enforce_eager,
)
print("Pipeline loaded")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def main():
parser.add_argument("--width", type=int, default=1024, help="Output image width")
parser.add_argument("--steps", type=int, default=50, help="Inference steps")
parser.add_argument("--guidance", type=float, default=7.5, help="Guidance scale")
parser.add_argument("--seed", type=int, help="Random seed")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--negative", help="Negative prompt")

args = parser.parse_args()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main():
parser.add_argument("--width", type=int, default=1024, help="Image width")
parser.add_argument("--steps", type=int, default=50, help="Inference steps")
parser.add_argument("--cfg-scale", type=float, default=4.0, help="True CFG scale")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--negative", help="Negative prompt")

args = parser.parse_args()
Expand Down
1 change: 1 addition & 0 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
def _dummy_run(self):
"""A dummy run to warm up the model."""
prompt = "dummy run"
# Note: num_inference_steps=1 causes timestep and temb to be None in the pipeline
num_inference_steps = 1
height = 1024
width = 1024
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
QwenImageTransformer2DModel,
)
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
)
Expand Down Expand Up @@ -274,7 +275,8 @@ def __init__(
self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
self.transformer = QwenImageTransformer2DModel(od_config=od_config)
transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)

self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
QwenImageTransformer2DModel,
)
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
)
Expand Down Expand Up @@ -231,7 +232,8 @@ def __init__(
self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
self.transformer = QwenImageTransformer2DModel(od_config=od_config)
transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)
self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
self.processor = Qwen2VLProcessor.from_pretrained(
model, subfolder="processor", local_files_only=local_files_only
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
QwenImageTransformer2DModel,
)
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
)
Expand Down Expand Up @@ -191,7 +192,9 @@ def __init__(
self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
self.transformer = QwenImageTransformer2DModel(od_config=od_config)

transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)
self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
self.processor = Qwen2VLProcessor.from_pretrained(
model, subfolder="processor", local_files_only=local_files_only
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
QwenImageTransformer2DModel,
)
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
)
Expand Down Expand Up @@ -211,18 +212,8 @@ def __init__(
)
]

use_additional_t_cond = od_config.tf_model_config.use_additional_t_cond
zero_cond_t = od_config.tf_model_config.zero_cond_t
use_layer3d_rope = od_config.tf_model_config.use_layer3d_rope
guidance_embeds = od_config.tf_model_config.guidance_embeds

self.transformer = QwenImageTransformer2DModel(
od_config=od_config,
use_additional_t_cond=use_additional_t_cond,
zero_cond_t=zero_cond_t,
use_layer3d_rope=use_layer3d_rope,
guidance_embeds=guidance_embeds,
)
transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)

# Pipeline configuration & processing parameters
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def forward(
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]

# Process image stream - norm1 + modulation
img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1)
img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1, modulate_index)

# Process text stream - norm1 + modulation
txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1)
Expand Down Expand Up @@ -632,7 +632,8 @@ def forward(
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output

# Process image stream - norm2 + MLP
img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2)
img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2, modulate_index)

img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output

Expand Down Expand Up @@ -692,15 +693,13 @@ def __init__(
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
guidance_embeds: bool = False, # TODO: this should probably be removed
guidance_embeds: bool = False,
axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
zero_cond_t: bool = False,
use_additional_t_cond: bool = False,
use_layer3d_rope: bool = False,
):
super().__init__()
model_config = od_config.tf_model_config
num_layers = model_config.num_layers
self.parallel_config = od_config.parallel_config
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
Expand Down
54 changes: 54 additions & 0 deletions vllm_omni/diffusion/utils/tf_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import inspect
from typing import Any

from vllm_omni.diffusion.data import TransformerConfig


def get_transformer_config_kwargs(
tf_model_config: TransformerConfig, model_class: type[Any] | None = None
) -> dict[str, Any]:
"""
This function extracts parameters from a TransformerConfig instance and filters out internal
diffusers metadata keys (those starting with '_') that should not be passed to model initialization.
Also filters out parameters that are not accepted by the model's __init__ method (e.g., pooled_projection_dim
for QwenImageTransformer2DModel).

This uses inspect.signature to dynamically detect accepted parameters, making it general for any model class.
Similar to how diffusers' @register_to_config decorator works.

Args:
tf_model_config: TransformerConfig instance containing model parameters
model_class: Optional model class to inspect for accepted __init__ parameters.
If None, all non-internal parameters are returned (backward compatibility).

Returns:
dict: Filtered dictionary of parameters suitable for transformer model initialization
"""
# Extract transformer config parameters, filtering out internal diffusers metadata
# TransformerConfig stores params in a 'params' dict, and we need to exclude
# internal keys like '_class_name' and '_diffusers_version'
tf_config_params = tf_model_config.to_dict()

# Filter out internal diffusers metadata keys that start with '_'
filtered_params = {k: v for k, v in tf_config_params.items() if not k.startswith("_")}

# If model_class is provided, use inspect.signature to get accepted parameters
if model_class is not None:
try:
# Get the signature of the model's __init__ method
sig = inspect.signature(model_class.__init__)
# Get all parameter names (excluding 'self' and special parameters)
accepted_params = {
name
for name, param in sig.parameters.items()
if name != "self" and param.kind != inspect.Parameter.VAR_KEYWORD # Exclude **kwargs
}

# Filter to only include parameters that are in the model's signature
filtered_params = {k: v for k, v in filtered_params.items() if k in accepted_params}
except (TypeError, AttributeError):
# If inspection fails, fall back to returning all non-internal params
# This maintains backward compatibility
pass

return filtered_params