Skip to content

Commit 78cb759

Browse files
authored
Support Wan 2.X T2V/TI2V (#589)
* Rename wan file, add task param * Add support for t2v and ti2v * Fix parallelization for TI2V * Log where the video was saved. * Update README.md * Move img file check inside i2v task * Fix task-specific resize logic * Change pipeline based on i2v/t2v task * Fix ti2v parallel if i2v * Add missing dimension args
1 parent 39d2ee0 commit 78cb759

File tree

2 files changed

+70
-18
lines changed

2 files changed

+70
-18
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ The following open-sourced DiT Models are released with xDiT in day 1.
113113
| [🎬 Mochi-1](https://github.com/xdit-project/mochi-xdit) | ✔️ | ✔️ ||| [Report](https://github.com/xdit-project/mochi-xdit) |
114114
| [🎬 CogVideoX](https://huggingface.co/THUDM/CogVideoX-2b) | ✔️ | ✔️ ||| [Report](./docs/performance/cogvideo.md) |
115115
| [🎬 Latte](https://huggingface.co/maxin-cn/Latte-1) || ✔️ ||| [Report](./docs/performance/latte.md) |
116-
| [🎬 Wan2.X I2V](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) || ✔️ ||| NA |
116+
| [🎬 Wan2.1](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) || ✔️ ||| NA |
117+
| [🎬 Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) || ✔️ ||| NA |
117118
| [🔵 HunyuanDiT-v1.2-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) | ✔️ | ✔️ | ✔️ || [Report](./docs/performance/hunyuandit.md) |
118119
| [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ✔️ || [Report](./docs/performance/flux.md) |
119120
| [🔴 PixArt-Sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS) | ✔️ | ✔️ | ✔️ || [Report](./docs/performance/pixart_alpha_legacy.md) |
@@ -236,7 +237,8 @@ Below is a list of validated diffusers version requirements. If the model is not
236237
| --- | --- |
237238
| [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) | >= 0.35.2 |
238239
| [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) | >= 0.35.2 |
239-
| [Wan2.X I2V](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | >= 0.35.2 |
240+
| [Wan2.1](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) | >= 0.35.2 |
241+
| [Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | >= 0.35.2 |
240242

241243
<h2 id="dev-guide">📚 Develop Guide</h2>
242244

examples/wan_i2v_example.py renamed to examples/wan_example.py

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
minimum_diffusers_version = get_minimum_diffusers_version("wan")
1010
raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Wan.")
1111

12-
from diffusers import WanImageToVideoPipeline
12+
from diffusers import WanImageToVideoPipeline, WanPipeline
1313
from diffusers.utils import export_to_video, load_image
1414
from diffusers.models.modeling_outputs import Transformer2DModelOutput
1515

@@ -28,6 +28,18 @@
2828
)
2929
from xfuser.model_executor.models.transformers.transformer_wan import xFuserWanAttnProcessor
3030

31+
TASK_FPS = {
32+
"i2v": 16,
33+
"t2v": 16,
34+
"ti2v": 24,
35+
}
36+
37+
TASK_FLOW_SHIFT = {
38+
"i2v": 5,
39+
"t2v": 12,
40+
"ti2v": 5,
41+
}
42+
3143
# Wrapper to only wrap the transformer in case it exists, i.e. Wan2.2
3244
def maybe_transformer_2(transformer_2):
3345
if transformer_2 is not None:
@@ -103,6 +115,18 @@ def new_forward(
103115
], dim=1)
104116
hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
105117

118+
if ts_seq_len is not None: # (wan2.2 ti2v)
119+
temb = torch.cat([
120+
temb,
121+
torch.zeros(batch_size, sequence_pad_amount, temb.shape[2], device=temb.device, dtype=temb.dtype)
122+
], dim=1)
123+
timestep_proj = torch.cat([
124+
timestep_proj,
125+
torch.zeros(batch_size, sequence_pad_amount, timestep_proj.shape[2], timestep_proj.shape[3], device=timestep_proj.device, dtype=timestep_proj.dtype)
126+
], dim=1)
127+
temb = torch.chunk(temb, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
128+
timestep_proj = torch.chunk(timestep_proj, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
129+
106130
freqs_cos, freqs_sin = rotary_emb
107131

108132
def get_rotary_emb_chunk(freqs, sequence_pad_amount):
@@ -183,47 +207,71 @@ def get_rotary_emb_chunk(freqs, sequence_pad_amount):
183207

184208
def main():
185209
parser = FlexibleArgumentParser(description="xFuser Arguments")
210+
parser.add_argument(
211+
"--task",
212+
type=str,
213+
required=True,
214+
choices=["i2v", "t2v", "ti2v"],
215+
help="The task to run."
216+
)
186217
args = xFuserArgs.add_cli_args(parser).parse_args()
187218
engine_args = xFuserArgs.from_cli_args(args)
188219
engine_config, input_config = engine_args.create_config()
189220
engine_config.runtime_config.dtype = torch.bfloat16
190221
local_rank = get_world_group().local_rank
191222
assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
192223

193-
if not args.img_file_path:
194-
raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")
195-
196-
pipe = WanImageToVideoPipeline.from_pretrained(
224+
is_i2v_task = args.task == "i2v" or (args.task == "ti2v" and args.img_file_path != None)
225+
task_pipeline = WanImageToVideoPipeline if is_i2v_task else WanPipeline
226+
pipe = task_pipeline.from_pretrained(
197227
pretrained_model_name_or_path=engine_config.model_config.model,
198-
torch_dtype=torch.bfloat16
228+
torch_dtype=torch.bfloat16,
199229
)
200-
pipe.scheduler.config.flow_shift = 5 # Match original implementation
230+
pipe.scheduler.config.flow_shift = TASK_FLOW_SHIFT[args.task]
201231
initialize_runtime_state(pipe, engine_config)
202232
parallelize_transformer(pipe)
203233
pipe = pipe.to(f"cuda:{local_rank}")
204234

205-
image = load_image(args.img_file_path)
235+
if not args.img_file_path and args.task == "i2v":
236+
raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")
206237

207-
max_area = input_config.height * input_config.width
208-
aspect_ratio = image.height / image.width
209-
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
210-
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
211-
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
212-
image = image.resize((width, height))
238+
if is_i2v_task:
239+
image = load_image(args.img_file_path)
240+
max_area = input_config.height * input_config.width
241+
aspect_ratio = image.height / image.width
242+
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
243+
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
244+
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
245+
image = image.resize((width, height))
246+
if is_dp_last_group():
247+
print("Max area is calculated from input height and width values, but the aspect ratio for the output video is retained from the input image.")
248+
print(f"Input image resolution: {image.height}x{image.width}")
249+
print(f"Generating a video with resolution: {height}x{width}")
250+
else: # T2V or TI2V with no image
251+
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
252+
height = input_config.height // mod_value * mod_value
253+
width = input_config.width // mod_value * mod_value
254+
if height != input_config.height or width != input_config.width:
255+
if is_dp_last_group():
256+
print(f"Adjusting height and width to be multiples of {mod_value}. New dimensions: {height}x{width}")
257+
image = None
213258

214259
def run_pipe(input_config, image):
215260
torch.cuda.reset_peak_memory_stats()
216261
torch.cuda.synchronize()
217262
start = time.perf_counter()
263+
optional_kwargs = {}
264+
if image:
265+
optional_kwargs["image"] = image
218266
output = pipe(
219267
height=height,
220268
width=width,
221-
image=image,
222269
prompt=input_config.prompt,
223270
num_inference_steps=input_config.num_inference_steps,
224271
num_frames=input_config.num_frames,
225272
guidance_scale=input_config.guidance_scale,
226273
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
274+
**optional_kwargs,
227275
).frames[0]
228276
end = time.perf_counter()
229277
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
@@ -243,7 +291,9 @@ def run_pipe(input_config, image):
243291

244292
output = run_pipe(input_config, image)
245293
if is_dp_last_group():
246-
export_to_video(output, "i2v_output.mp4", fps=16)
294+
file_name = f"{args.task}_output.mp4"
295+
export_to_video(output, file_name, fps=TASK_FPS[args.task])
296+
print(f"Output video saved to {file_name}")
247297

248298
get_runtime_state().destroy_distributed_env()
249299

0 commit comments

Comments
 (0)