@@ -1067,6 +1067,34 @@ def test_forward_pass_with_image_encoder_states(self):
10671067 expected_shape = (batch_size , self .model_config ["out_channels" ], num_frames , height , width )
10681068 self .assert_tensor_shape (output_tensor , expected_shape )
10691069
1070+ def test_set_time_embedding_v2_1_toggle (self ):
1071+ """Ensure the helper toggles the internal force flag."""
1072+ model = WanTransformer3DModel (** self .model_config )
1073+ self .assertFalse (model .force_v2_1_time_embedding )
1074+ model .set_time_embedding_v2_1 (True )
1075+ self .assertTrue (model .force_v2_1_time_embedding )
1076+ model .set_time_embedding_v2_1 (False )
1077+ self .assertFalse (model .force_v2_1_time_embedding )
1078+
1079+ def test_forward_time_embedding_override_with_sequence_timesteps (self ):
1080+ """Time embedding override should handle sequence-shaped timesteps without errors."""
1081+ model = WanTransformer3DModel (** self .model_config )
1082+ model .set_time_embedding_v2_1 (True )
1083+
1084+ batch_size , in_channels , num_frames , height , width = 1 , 16 , 4 , 8 , 8
1085+ hidden_states = torch .randn (batch_size , in_channels , num_frames , height , width )
1086+ # Simulate Wan 2.2-style timestep tensor (batch, sequence_length)
1087+ sequence_length = num_frames // self .model_config ["patch_size" ][0 ]
1088+ timestep = torch .randint (0 , 1000 , (batch_size , sequence_length ))
1089+ encoder_hidden_states = torch .randn (batch_size , 77 , self .model_config ["text_dim" ])
1090+
1091+ with torch .no_grad ():
1092+ output = model .forward (hidden_states = hidden_states , timestep = timestep , encoder_hidden_states = encoder_hidden_states )
1093+
1094+ output_tensor = output .sample if hasattr (output , "sample" ) else output
1095+ expected_shape = (batch_size , self .model_config ["out_channels" ], num_frames , height , width )
1096+ self .assert_tensor_shape (output_tensor , expected_shape )
1097+
10701098 def test_3d_patch_embedding (self ):
10711099 """Test 3D patch embedding processing."""
10721100 model = WanTransformer3DModel (** self .model_config )
0 commit comments