@@ -604,7 +604,7 @@ def forward(
604604 txt_mod1 , txt_mod2 = txt_mod_params .chunk (2 , dim = - 1 ) # Each [B, 3*dim]
605605
606606 # Process image stream - norm1 + modulation
607- img_modulated , img_gate1 = self .img_norm1 (hidden_states , img_mod1 )
607+ img_modulated , img_gate1 = self .img_norm1 (hidden_states , img_mod1 , modulate_index )
608608
609609 # Process text stream - norm1 + modulation
610610 txt_modulated , txt_gate1 = self .txt_norm1 (encoder_hidden_states , txt_mod1 )
@@ -632,7 +632,8 @@ def forward(
632632 encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
633633
634634 # Process image stream - norm2 + MLP
635- img_modulated2 , img_gate2 = self .img_norm2 (hidden_states , img_mod2 )
635+ img_modulated2 , img_gate2 = self .img_norm2 (hidden_states , img_mod2 , modulate_index )
636+
636637 img_mlp_output = self .img_mlp (img_modulated2 )
637638 hidden_states = hidden_states + img_gate2 * img_mlp_output
638639
@@ -692,15 +693,13 @@ def __init__(
692693 attention_head_dim : int = 128 ,
693694 num_attention_heads : int = 24 ,
694695 joint_attention_dim : int = 3584 ,
695- guidance_embeds : bool = False , # TODO: this should probably be removed
696+ guidance_embeds : bool = False ,
696697 axes_dims_rope : tuple [int , int , int ] = (16 , 56 , 56 ),
697698 zero_cond_t : bool = False ,
698699 use_additional_t_cond : bool = False ,
699700 use_layer3d_rope : bool = False ,
700701 ):
701702 super ().__init__ ()
702- model_config = od_config .tf_model_config
703- num_layers = model_config .num_layers
704703 self .parallel_config = od_config .parallel_config
705704 self .in_channels = in_channels
706705 self .out_channels = out_channels or in_channels
0 commit comments