
Commit 33355b9

Author: bghira

fix for test

1 parent 24733c6 · commit 33355b9

File tree

2 files changed: +29 -0 lines changed


tests/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ def _filtered_warning(self, msg, *args, **kwargs):
 os.environ.setdefault("DATASETS_VERBOSITY", "error")
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 os.environ.setdefault("TQDM_DISABLE", "1")
+os.environ.setdefault("SIMPLETUNER_FAST_CONFIG_API", "1")
 
 # Register cleanup for test directories
 import atexit
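The added variable is set with os.environ.setdefault, so a value already exported in the developer's shell still takes precedence over the test default. Below is a minimal sketch of how an opt-in flag like this is typically read on the consuming side; the helper name is hypothetical and not SimpleTuner's actual config code:

    # Hypothetical reader for the opt-in flag; SimpleTuner's real config code may differ.
    import os

    def _fast_config_api_enabled() -> bool:
        # tests/__init__.py only fills the value when it is unset, so an explicit
        # "0" exported by the developer still disables the fast path.
        return os.environ.get("SIMPLETUNER_FAST_CONFIG_API", "0") == "1"

    if __name__ == "__main__":
        print("fast config API enabled:", _fast_config_api_enabled())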

tests/test_transformers/test_wan_transformer.py

Lines changed: 28 additions & 0 deletions
@@ -1067,6 +1067,34 @@ def test_forward_pass_with_image_encoder_states(self):
         expected_shape = (batch_size, self.model_config["out_channels"], num_frames, height, width)
         self.assert_tensor_shape(output_tensor, expected_shape)
 
+    def test_set_time_embedding_v2_1_toggle(self):
+        """Ensure the helper toggles the internal force flag."""
+        model = WanTransformer3DModel(**self.model_config)
+        self.assertFalse(model.force_v2_1_time_embedding)
+        model.set_time_embedding_v2_1(True)
+        self.assertTrue(model.force_v2_1_time_embedding)
+        model.set_time_embedding_v2_1(False)
+        self.assertFalse(model.force_v2_1_time_embedding)
+
+    def test_forward_time_embedding_override_with_sequence_timesteps(self):
+        """Time embedding override should handle sequence-shaped timesteps without errors."""
+        model = WanTransformer3DModel(**self.model_config)
+        model.set_time_embedding_v2_1(True)
+
+        batch_size, in_channels, num_frames, height, width = 1, 16, 4, 8, 8
+        hidden_states = torch.randn(batch_size, in_channels, num_frames, height, width)
+        # Simulate Wan 2.2-style timestep tensor (batch, sequence_length)
+        sequence_length = num_frames // self.model_config["patch_size"][0]
+        timestep = torch.randint(0, 1000, (batch_size, sequence_length))
+        encoder_hidden_states = torch.randn(batch_size, 77, self.model_config["text_dim"])
+
+        with torch.no_grad():
+            output = model.forward(hidden_states=hidden_states, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
+
+        output_tensor = output.sample if hasattr(output, "sample") else output
+        expected_shape = (batch_size, self.model_config["out_channels"], num_frames, height, width)
+        self.assert_tensor_shape(output_tensor, expected_shape)
+
     def test_3d_patch_embedding(self):
         """Test 3D patch embedding processing."""
         model = WanTransformer3DModel(**self.model_config)
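The new tests pin down a simple contract: force_v2_1_time_embedding defaults to False, set_time_embedding_v2_1 flips it, and with the flag on the forward pass must accept sequence-shaped (batch, sequence_length) timesteps. Below is a standalone sketch of that toggle contract, assuming the setter is a plain flag assignment (the real WanTransformer3DModel method may do more):

    # Standalone sketch of the toggle contract exercised above; assumes the
    # setter is a plain flag assignment, which the real model may extend.
    class TimeEmbeddingToggleSketch:
        def __init__(self) -> None:
            # Default off, matching the test's first assertFalse.
            self.force_v2_1_time_embedding = False

        def set_time_embedding_v2_1(self, enabled: bool) -> None:
            # The forward pass would consult this flag when it receives
            # (batch, sequence_length) timesteps instead of a (batch,) vector.
            self.force_v2_1_time_embedding = bool(enabled)

    sketch = TimeEmbeddingToggleSketch()
    sketch.set_time_embedding_v2_1(True)
    assert sketch.force_v2_1_time_embedding
    sketch.set_time_embedding_v2_1(False)
    assert not sketch.force_v2_1_time_embedding

The new cases can be run in isolation with pytest's keyword filter, for example: pytest tests/test_transformers/test_wan_transformer.py -k "time_embedding".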
