Skip to content

Commit 9bd6f00

Browse files
iwzbi authored and with1015 committed
[Bugfix] Fix stable diffusion3 compatibility error (vllm-project#772)
Signed-off-by: iwzbi <wzbi@zju.edu.cn>
1 parent 6edd837 commit 9bd6f00

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

vllm_omni/diffusion/models/sd3/sd3_transformer.py

Lines changed: 6 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -102,8 +102,8 @@ def __init__(
102102
else:
103103
self.to_out = None
104104

105-
self.norm_added_q = RMSNorm(head_dim, eps=eps)
106-
self.norm_added_k = RMSNorm(head_dim, eps=eps)
105+
self.norm_added_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
106+
self.norm_added_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
107107

108108
self.attn = Attention(
109109
num_heads=num_heads,
@@ -341,8 +341,10 @@ def __init__(
341341
self.pooled_projection_dim = model_config.pooled_projection_dim
342342
self.joint_attention_dim = model_config.joint_attention_dim
343343
self.patch_size = model_config.patch_size
344-
self.dual_attention_layers = model_config.dual_attention_layers
345-
self.qk_norm = model_config.qk_norm
344+
self.dual_attention_layers = (
345+
model_config.dual_attention_layers if hasattr(model_config, "dual_attention_layers") else ()
346+
)
347+
self.qk_norm = model_config.qk_norm if hasattr(model_config, "qk_norm") else ""
346348
self.pos_embed_max_size = model_config.pos_embed_max_size
347349

348350
self.pos_embed = PatchEmbed(

0 commit comments

Comments (0)