
Commit 2d5faf3 (1 parent: 314bb4e), authored by samithuang

[Bugfix] Fix generation artifacts of Qwen-Image-Edit-2511 and update pipeline DiT param parsing (#776)

Signed-off-by: samithuang <285365963@qq.com>

10 files changed: 89 additions & 23 deletions


examples/offline_inference/image_to_image/image_edit.py
Lines changed: 15 additions & 1 deletion

@@ -55,7 +55,15 @@
         --prompt "Edit description" \
         --cfg_parallel_size 2 \
         --num_inference_steps 50 \
-        --cfg_scale 4.0 \
+        --cfg_scale 4.0
+
+Usage (disable torch.compile):
+    python image_edit.py \
+        --image input.png \
+        --prompt "Edit description" \
+        --enforce_eager \
+        --num_inference_steps 50 \
+        --cfg_scale 4.0

 For more options, run:
     python image_edit.py --help

@@ -260,6 +268,11 @@ def parse_args() -> argparse.Namespace:
         choices=[1, 2],
         help="Number of GPUs used for classifier free guidance parallel size.",
     )
+    parser.add_argument(
+        "--enforce_eager",
+        action="store_true",
+        help="Disable torch.compile and force eager execution.",
+    )
     return parser.parse_args()

@@ -321,6 +334,7 @@ def main():
         cache_backend=args.cache_backend,
         cache_config=cache_config,
         parallel_config=parallel_config,
+        enforce_eager=args.enforce_eager,
     )
     print("Pipeline loaded")
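Note: the new --enforce_eager flag simply forwards a boolean into pipeline
construction. A minimal sketch of the gating pattern it implies is below;
maybe_compile is illustrative, not the actual vllm_omni internals:

    import torch

    def maybe_compile(module: torch.nn.Module, enforce_eager: bool) -> torch.nn.Module:
        # Skip torch.compile when eager execution is forced; compiled graphs
        # can mask or alter numerical artifacts while debugging.
        if enforce_eager:
            return module
        return torch.compile(module)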

examples/online_serving/image_to_image/openai_chat_client.py
Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ def main():
     parser.add_argument("--width", type=int, default=1024, help="Output image width")
     parser.add_argument("--steps", type=int, default=50, help="Inference steps")
     parser.add_argument("--guidance", type=float, default=7.5, help="Guidance scale")
-    parser.add_argument("--seed", type=int, help="Random seed")
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
     parser.add_argument("--negative", help="Negative prompt")

     args = parser.parse_args()

examples/online_serving/text_to_image/openai_chat_client.py
Lines changed: 1 addition & 1 deletion

@@ -100,7 +100,7 @@ def main():
     parser.add_argument("--width", type=int, default=1024, help="Image width")
     parser.add_argument("--steps", type=int, default=50, help="Inference steps")
     parser.add_argument("--cfg-scale", type=float, default=4.0, help="True CFG scale")
-    parser.add_argument("--seed", type=int, default=42, help="Random seed")
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
     parser.add_argument("--negative", help="Negative prompt")

     args = parser.parse_args()

vllm_omni/diffusion/diffusion_engine.py
Lines changed: 1 addition & 0 deletions

@@ -293,6 +293,7 @@ def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
     def _dummy_run(self):
         """A dummy run to warm up the model."""
         prompt = "dummy run"
+        # note that num_inference_steps=1 will cause timestep and temb to be None in the pipeline
         num_inference_steps = 1
         height = 1024
         width = 1024
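Note: the added comment records a real edge case of the warm-up pass: with
num_inference_steps=1, some pipeline branches see timestep/temb as None. A
hedged sketch of the defensive pattern this implies (apply_time_conditioning
is illustrative, not the actual pipeline code):

    def apply_time_conditioning(hidden_states, temb):
        # During the single-step dummy run, temb may be None; skip the
        # conditioning rather than crash.
        if temb is None:
            return hidden_states
        return hidden_states + temb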

vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
Lines changed: 3 additions & 1 deletion

@@ -36,6 +36,7 @@
     QwenImageTransformer2DModel,
 )
 from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
 from vllm_omni.model_executor.model_loader.weight_utils import (
     download_weights_from_hf_specific,
 )

@@ -274,7 +275,8 @@ def __init__(
         self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
             self.device
         )
-        self.transformer = QwenImageTransformer2DModel(od_config=od_config)
+        transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
+        self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)

         self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)

vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
Lines changed: 3 additions & 1 deletion

@@ -38,6 +38,7 @@
     QwenImageTransformer2DModel,
 )
 from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
 from vllm_omni.model_executor.model_loader.weight_utils import (
     download_weights_from_hf_specific,
 )

@@ -231,7 +232,8 @@ def __init__(
         self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
             self.device
         )
-        self.transformer = QwenImageTransformer2DModel(od_config=od_config)
+        transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
+        self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)
         self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
         self.processor = Qwen2VLProcessor.from_pretrained(
             model, subfolder="processor", local_files_only=local_files_only

vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
Lines changed: 4 additions & 1 deletion

@@ -41,6 +41,7 @@
     QwenImageTransformer2DModel,
 )
 from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
 from vllm_omni.model_executor.model_loader.weight_utils import (
     download_weights_from_hf_specific,
 )

@@ -191,7 +192,9 @@ def __init__(
         self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
             self.device
         )
-        self.transformer = QwenImageTransformer2DModel(od_config=od_config)
+
+        transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
+        self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)
         self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
         self.processor = Qwen2VLProcessor.from_pretrained(
             model, subfolder="processor", local_files_only=local_files_only

vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
Lines changed: 3 additions & 12 deletions

@@ -37,6 +37,7 @@
     QwenImageTransformer2DModel,
 )
 from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
 from vllm_omni.model_executor.model_loader.weight_utils import (
     download_weights_from_hf_specific,
 )

@@ -211,18 +212,8 @@ def __init__(
             )
         ]

-        use_additional_t_cond = od_config.tf_model_config.use_additional_t_cond
-        zero_cond_t = od_config.tf_model_config.zero_cond_t
-        use_layer3d_rope = od_config.tf_model_config.use_layer3d_rope
-        guidance_embeds = od_config.tf_model_config.guidance_embeds
-
-        self.transformer = QwenImageTransformer2DModel(
-            od_config=od_config,
-            use_additional_t_cond=use_additional_t_cond,
-            zero_cond_t=zero_cond_t,
-            use_layer3d_rope=use_layer3d_rope,
-            guidance_embeds=guidance_embeds,
-        )
+        transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
+        self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)

         # Pipeline configuration & processing parameters
         self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
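Note: this hunk shows the pattern the new helper replaces everywhere: instead
of hand-copying each tf_model_config field into the constructor, the kwargs
are derived from the config and filtered against the model's signature. The
filtering matters because **-expanding a raw config dict raises TypeError for
any key the constructor does not accept. A toy illustration (Toy and its keys
are made up; pooled_projection_dim is the example cited in the helper's
docstring below):

    class Toy:
        def __init__(self, guidance_embeds: bool = False):
            self.guidance_embeds = guidance_embeds

    raw = {"guidance_embeds": True, "pooled_projection_dim": 768}
    try:
        Toy(**raw)
    except TypeError as e:
        print(e)  # unexpected keyword argument 'pooled_projection_dim'

    filtered = {k: v for k, v in raw.items() if k in {"guidance_embeds"}}
    Toy(**filtered)  # fine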

vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
Lines changed: 4 additions & 5 deletions

@@ -604,7 +604,7 @@ def forward(
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]

         # Process image stream - norm1 + modulation
-        img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1)
+        img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1, modulate_index)

         # Process text stream - norm1 + modulation
         txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1)

@@ -632,7 +632,8 @@ def forward(
         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output

         # Process image stream - norm2 + MLP
-        img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2)
+        img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2, modulate_index)
+
         img_mlp_output = self.img_mlp(img_modulated2)
         hidden_states = hidden_states + img_gate2 * img_mlp_output

@@ -692,15 +693,13 @@ def __init__(
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 3584,
-        guidance_embeds: bool = False,  # TODO: this should probably be removed
+        guidance_embeds: bool = False,
         axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
         zero_cond_t: bool = False,
         use_additional_t_cond: bool = False,
         use_layer3d_rope: bool = False,
     ):
         super().__init__()
-        model_config = od_config.tf_model_config
-        num_layers = model_config.num_layers
         self.parallel_config = od_config.parallel_config
         self.in_channels = in_channels
         self.out_channels = out_channels or in_channels
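Note: the artifact fix itself is the threading of modulate_index into
img_norm1/img_norm2, so the image stream picks the correct modulation
parameters per token. The actual norm implementation is outside this diff; a
minimal sketch of what an index-aware AdaLN-style norm can look like, under
that assumption (IndexedModulatedNorm and its tensor layout are hypothetical):

    import torch
    import torch.nn as nn

    class IndexedModulatedNorm(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

        def forward(self, x, mod, modulate_index=None):
            # x: [B, S, D]; mod: [B, K, 3, D] holding K slices of (shift, scale, gate);
            # modulate_index: [B, S] selects which slice each token uses.
            if modulate_index is None:
                modulate_index = torch.zeros(x.shape[:2], dtype=torch.long, device=x.device)
            idx = modulate_index[:, :, None, None].expand(-1, -1, 3, mod.shape[-1])
            per_token = torch.gather(mod, 1, idx)          # [B, S, 3, D]
            shift, scale, gate = per_token.unbind(dim=2)   # each [B, S, D]
            return self.norm(x) * (1 + scale) + shift, gate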
vllm_omni/diffusion/utils/tf_utils.py (new file)
Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+import inspect
+from typing import Any
+
+from vllm_omni.diffusion.data import TransformerConfig
+
+
+def get_transformer_config_kwargs(
+    tf_model_config: TransformerConfig, model_class: type[Any] | None = None
+) -> dict[str, Any]:
+    """
+    Extract parameters from a TransformerConfig instance, filtering out internal
+    diffusers metadata keys (those starting with '_') that should not be passed to
+    model initialization. Also filters out parameters not accepted by the model's
+    __init__ method (e.g., pooled_projection_dim for QwenImageTransformer2DModel).
+
+    This uses inspect.signature to dynamically detect accepted parameters, making it
+    general for any model class, similar to how diffusers' @register_to_config
+    decorator works.
+
+    Args:
+        tf_model_config: TransformerConfig instance containing model parameters.
+        model_class: Optional model class to inspect for accepted __init__ parameters.
+            If None, all non-internal parameters are returned (backward compatibility).
+
+    Returns:
+        dict: Filtered dictionary of parameters suitable for transformer model
+        initialization.
+    """
+    # Extract transformer config parameters, filtering out internal diffusers
+    # metadata; TransformerConfig stores params in a 'params' dict, and we need
+    # to exclude internal keys like '_class_name' and '_diffusers_version'.
+    tf_config_params = tf_model_config.to_dict()
+
+    # Filter out internal diffusers metadata keys that start with '_'
+    filtered_params = {k: v for k, v in tf_config_params.items() if not k.startswith("_")}
+
+    # If model_class is provided, use inspect.signature to get accepted parameters
+    if model_class is not None:
+        try:
+            # Get the signature of the model's __init__ method
+            sig = inspect.signature(model_class.__init__)
+            # Get all parameter names, excluding 'self' and **kwargs
+            accepted_params = {
+                name
+                for name, param in sig.parameters.items()
+                if name != "self" and param.kind != inspect.Parameter.VAR_KEYWORD
+            }
+
+            # Keep only parameters that appear in the model's signature
+            filtered_params = {k: v for k, v in filtered_params.items() if k in accepted_params}
+        except (TypeError, AttributeError):
+            # If inspection fails, fall back to returning all non-internal params.
+            # This maintains backward compatibility.
+            pass
+
+    return filtered_params
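Usage sketch: DummyConfig stands in for TransformerConfig (the helper only
needs a to_dict() method), and ToyModel stands in for
QwenImageTransformer2DModel; both are made up for illustration:

    class DummyConfig:
        def to_dict(self):
            return {
                "_class_name": "QwenImageTransformer2DModel",  # internal metadata, dropped
                "_diffusers_version": "0.35.0",                # internal metadata, dropped
                "num_layers": 60,                              # not in ToyModel.__init__, dropped
                "guidance_embeds": True,                       # accepted, kept
                "zero_cond_t": True,                           # accepted, kept
            }

    class ToyModel:
        def __init__(self, od_config=None, guidance_embeds=False, zero_cond_t=False):
            self.guidance_embeds = guidance_embeds
            self.zero_cond_t = zero_cond_t

    kwargs = get_transformer_config_kwargs(DummyConfig(), ToyModel)
    print(kwargs)  # {'guidance_embeds': True, 'zero_cond_t': True}
    model = ToyModel(od_config=None, **kwargs)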
