
Commit 24733c6

Author: bghira

port concept of musubi-tuner wan_force_2_1_time_embedding

1 parent 8a84a9e commit 24733c6

File tree: 4 files changed, +59 -0 lines changed

documentation/quickstart/WAN.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -42,6 +42,10 @@ You'll need:
 - **a realistic minimum** is 16GB or, a single 3090 or V100 GPU
 - **ideally** multiple 4090, A6000, L40S, or better
 
+If you encounter shape mismatches in the time embedding layers when running Wan 2.2 checkpoints, enable the new
+`wan_force_2_1_time_embedding` flag. This forces the transformer to fall back to Wan 2.1 style time embeddings and
+resolves the compatibility issue.
+
 Apple silicon systems do not work super well with Wan 2.1 so far, something like 10 minutes for a single training step can be expected..
 
 ### Prerequisites
```
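For reference, the new option maps onto a plain boolean config entry. A minimal sketch of how it might look in a SimpleTuner-style config mapping; only the `wan_force_2_1_time_embedding` key comes from this commit, the surrounding keys and values are placeholders:

```python
# Illustrative config mapping; only "wan_force_2_1_time_embedding" is taken
# from this commit -- the other keys and values are placeholders.
config = {
    "model_family": "wan",
    "pretrained_model_name_or_path": "<your Wan 2.2 checkpoint>",
    "wan_force_2_1_time_embedding": True,  # fall back to Wan 2.1 time embeddings
}
```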

simpletuner/helpers/models/wan/model.py

Lines changed: 20 additions & 0 deletions

```diff
@@ -92,6 +92,8 @@ class Wan(VideoModelFoundation):
     def __init__(self, config, accelerator):
         super().__init__(config, accelerator)
         self._wan_cached_stage_modules: Dict[str, WanTransformer3DModel] = {}
+        if not hasattr(self.config, "wan_force_2_1_time_embedding"):
+            self.config.wan_force_2_1_time_embedding = False
 
     def requires_conditioning_image_embeds(self) -> bool:
         return self._wan_stage_info() is not None
@@ -122,6 +124,14 @@ def _wan_stage_info(self) -> Optional[Dict[str, object]]:
         flavour = getattr(self.config, "model_flavour", None)
         return self.WAN_STAGE_OVERRIDES.get(flavour)
 
+    def _apply_time_embedding_override(self, transformer: Optional[WanTransformer3DModel]) -> None:
+        if transformer is None:
+            return
+        target = self.unwrap_model(transformer)
+        setter = getattr(target, "set_time_embedding_v2_1", None)
+        if callable(setter):
+            setter(bool(getattr(self.config, "wan_force_2_1_time_embedding", False)))
+
     def _should_load_other_stage(self) -> bool:
         stage_info = self._wan_stage_info()
         if stage_info is None:
@@ -142,13 +152,19 @@ def _get_or_load_wan_stage_module(self, subfolder: str) -> WanTransformer3DModel
         stage.requires_grad_(False)
         stage.to(self.accelerator.device, dtype=self.config.weight_dtype)
         stage.eval()
+        self._apply_time_embedding_override(stage)
         self._wan_cached_stage_modules[subfolder] = stage
         return stage
 
     def unload_model(self):
         super().unload_model()
         self._wan_cached_stage_modules.clear()
 
+    def set_prepared_model(self, model, base_model: bool = False):
+        super().set_prepared_model(model, base_model)
+        if not base_model:
+            self._apply_time_embedding_override(self.model)
+
     def get_group_offload_components(self, pipeline):
         components = dict(super().get_group_offload_components(pipeline))
         transformer_2 = getattr(pipeline, "transformer_2", None)
@@ -183,6 +199,10 @@ def get_pipeline(self, pipeline_type: str = PipelineTypes.TEXT2IMG, load_base_mo
         else:
             pipeline.config.boundary_ratio = None
 
+        self._apply_time_embedding_override(getattr(pipeline, "transformer", None))
+        if getattr(pipeline, "transformer_2", None) is not None:
+            self._apply_time_embedding_override(pipeline.transformer_2)
+
         return pipeline
 
     def tread_init(self):
```
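The override is applied at every point where a transformer can surface (the stage cache, the prepared model, and both pipeline transformers), but always through the same duck-typed helper: look the setter up with `getattr` and only call it if it exists, so wrapped or older transformer classes are skipped silently. A minimal, self-contained sketch of that pattern, with a hypothetical stub standing in for the real transformer class:

```python
# Sketch of the duck-typed override pattern used in _apply_time_embedding_override.
# _StubTransformer is a hypothetical stand-in, not SimpleTuner's real class.
class _StubTransformer:
    def set_time_embedding_v2_1(self, force: bool) -> None:
        self.force_v2_1_time_embedding = bool(force)


def apply_override(module, force: bool) -> None:
    # Only call the setter if the (possibly wrapped) module exposes it.
    setter = getattr(module, "set_time_embedding_v2_1", None)
    if callable(setter):
        setter(force)


stub = _StubTransformer()
apply_override(stub, True)
assert stub.force_v2_1_time_embedding is True
apply_override(object(), True)  # no setter on a plain object: silently skipped
```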

simpletuner/helpers/models/wan/transformer.py

Lines changed: 17 additions & 0 deletions

```diff
@@ -454,6 +454,18 @@ def __init__(
         self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
 
         self.gradient_checkpointing = False
+        self.force_v2_1_time_embedding: bool = False
+
+    def set_time_embedding_v2_1(self, force_2_1_time_embedding: bool) -> None:
+        """
+        Force the Wan transformer to use 2.1-style time embeddings even when running Wan 2.2 checkpoints.
+
+        Args:
+            force_2_1_time_embedding: Whether to override the default time embedding behaviour.
+        """
+        self.force_v2_1_time_embedding = bool(force_2_1_time_embedding)
+        if self.force_v2_1_time_embedding:
+            logger.info("WanTransformer3DModel: Forcing Wan 2.1 style time embedding.")
 
     def set_router(self, router: TREADRouter, routes: List[Dict[str, Any]]):
         """Set the TREAD router and routing configuration."""
@@ -519,6 +531,11 @@ def forward(
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
 
+        if self.force_v2_1_time_embedding and timestep.dim() > 1:
+            # Wan 2.1 uses a single timestep per batch entry. When forcing 2.1 behaviour with Wan 2.2
+            # checkpoints we fall back to the first timestep value which matches the reference implementation.
+            timestep = timestep[..., 0].contiguous()
+
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
             timestep, encoder_hidden_states, encoder_hidden_states_image
         )
```
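The guard in `forward()` collapses an expanded timestep tensor down to the single per-sample value the 2.1-style embedder expects. A self-contained sketch of what `timestep[..., 0]` does, with made-up values:

```python
# Tiny illustration of the fallback: Wan 2.2 can supply an expanded
# per-frame timestep tensor, while Wan 2.1 expects one timestep per
# batch entry. The values below are illustrative.
import torch

timestep = torch.tensor([[999, 999, 999], [500, 500, 500]])  # shape (2, 3)
if timestep.dim() > 1:
    timestep = timestep[..., 0].contiguous()  # keep the first value per entry
print(timestep)  # tensor([999, 500]), shape (2,)
```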

simpletuner/simpletuner_sdk/server/services/field_registry/sections/model.py

Lines changed: 18 additions & 0 deletions

```diff
@@ -506,6 +506,24 @@ def _quant_label(value: str) -> str:
         )
     )
 
+    registry._add_field(
+        ConfigField(
+            name="wan_force_2_1_time_embedding",
+            arg_name="--wan_force_2_1_time_embedding",
+            ui_label="Force Wan 2.1 Time Embedding",
+            field_type=FieldType.CHECKBOX,
+            tab="model",
+            section="model_config",
+            subsection="wan_specific",
+            default_value=False,
+            dependencies=[FieldDependency(field="model_family", operator="equals", value="wan", action="show")],
+            help_text="Use Wan 2.1 style time embeddings even when running Wan 2.2 checkpoints.",
+            tooltip="Enable this if Wan 2.2 checkpoints report shape mismatches in the time embedding layers.",
+            importance=ImportanceLevel.ADVANCED,
+            order=30,
+        )
+    )
+
     # Fused QKV Projections
     registry._add_field(
         ConfigField(
```
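This registry entry is what surfaces the option as a UI checkbox (shown only when `model_family` is `wan`) and as the `--wan_force_2_1_time_embedding` CLI argument. In behaviour the flag is roughly a plain boolean switch; a sketch using stock argparse rather than SimpleTuner's actual argument wiring:

```python
# Rough argparse equivalent of the registered checkbox; a sketch, not
# SimpleTuner's real argument parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--wan_force_2_1_time_embedding",
    action="store_true",
    default=False,
    help="Use Wan 2.1 style time embeddings even when running Wan 2.2 checkpoints.",
)

args = parser.parse_args(["--wan_force_2_1_time_embedding"])
print(args.wan_force_2_1_time_embedding)  # True
```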

0 commit comments
