6 changes: 4 additions & 2 deletions README.md
@@ -113,7 +113,8 @@ The following open-sourced DiT Models are released with xDiT in day 1.
| [🎬 Mochi-1](https://github.com/xdit-project/mochi-xdit) | ✔️ | ✔️ ||| [Report](https://github.com/xdit-project/mochi-xdit) |
| [🎬 CogVideoX](https://huggingface.co/THUDM/CogVideoX-2b) | ✔️ | ✔️ ||| [Report](./docs/performance/cogvideo.md) |
| [🎬 Latte](https://huggingface.co/maxin-cn/Latte-1) || ✔️ ||| [Report](./docs/performance/latte.md) |
| [🎬 Wan2.X I2V](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) || ✔️ ||| NA |
| [🎬 Wan2.1](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) || ✔️ ||| NA |
| [🎬 Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) || ✔️ ||| NA |
| [🔵 HunyuanDiT-v1.2-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) | ✔️ | ✔️ | ✔️ || [Report](./docs/performance/hunyuandit.md) |
| [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ✔️ || [Report](./docs/performance/flux.md) |
| [🔴 PixArt-Sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS) | ✔️ | ✔️ | ✔️ || [Report](./docs/performance/pixart_alpha_legacy.md) |
@@ -236,7 +237,8 @@ Below is a list of validated diffusers version requirements. If the model is not
| --- | --- |
| [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) | >= 0.35.2 |
| [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) | >= 0.35.2 |
| [Wan2.X I2V](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | >= 0.35.2 |
| [Wan2.1](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) | >= 0.35.2 |
| [Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | >= 0.35.2 |
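
Not part of this PR, but for readers checking their own environment against the table above, a minimal sketch of a version guard; the example script in this diff uses its own helper (`get_minimum_diffusers_version`), so the snippet below only illustrates the general pattern with standard-library and `packaging` calls:

```python
# Sketch only: verify the installed diffusers version before using the Wan pipelines.
# The ">= 0.35.2" floor comes from the table above; everything else is illustrative.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("diffusers"))
if installed < Version("0.35.2"):
    raise ImportError(f"diffusers>=0.35.2 is required for Wan pipelines, found {installed}")
```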

<h2 id="dev-guide">📚 Develop Guide</h2>

82 changes: 66 additions & 16 deletions examples/wan_i2v_example.py → examples/wan_example.py
@@ -9,7 +9,7 @@
minimum_diffusers_version = get_minimum_diffusers_version("wan")
raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Wan.")

from diffusers import WanImageToVideoPipeline
from diffusers import WanImageToVideoPipeline, WanPipeline
from diffusers.utils import export_to_video, load_image
from diffusers.models.modeling_outputs import Transformer2DModelOutput

@@ -28,6 +28,18 @@
)
from xfuser.model_executor.models.transformers.transformer_wan import xFuserWanAttnProcessor

TASK_FPS = {
"i2v": 16,
"t2v": 16,
"ti2v": 24,
}

TASK_FLOW_SHIFT = {
"i2v": 5,
"t2v": 12,
"ti2v": 5,
}

# Wrap transformer_2 only when it exists, i.e. for Wan2.2, which ships a second transformer
def maybe_transformer_2(transformer_2):
if transformer_2 is not None:
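
The wrapper body is truncated in this view; as a rough sketch of the pattern the comment describes (the `wrap` callable below is a hypothetical stand-in for the real parallelization applied to the main transformer), the second transformer is only touched when the checkpoint actually provides one:

```python
# Sketch only: Wan2.2 checkpoints ship a second transformer, Wan2.1 does not.
# `wrap` stands in for whatever wrapping is applied to the main transformer.
def maybe_wrap_transformer_2(transformer_2, wrap):
    if transformer_2 is None:
        return None  # single-transformer checkpoints (e.g. Wan2.1)
    return wrap(transformer_2)
```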
@@ -103,6 +115,18 @@ def new_forward(
], dim=1)
hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]

if ts_seq_len is not None: # (wan2.2 ti2v)
temb = torch.cat([
temb,
torch.zeros(batch_size, sequence_pad_amount, temb.shape[2], device=temb.device, dtype=temb.dtype)
], dim=1)
timestep_proj = torch.cat([
timestep_proj,
torch.zeros(batch_size, sequence_pad_amount, timestep_proj.shape[2], timestep_proj.shape[3], device=timestep_proj.device, dtype=timestep_proj.dtype)
], dim=1)
temb = torch.chunk(temb, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
timestep_proj = torch.chunk(timestep_proj, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
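
To make the pad-then-chunk step above easier to follow in isolation, here is a self-contained sketch with a hardcoded world size, rank, and tensor shape (all illustrative); the real code obtains these from `get_sequence_parallel_world_size()` / `get_sequence_parallel_rank()` and computes `sequence_pad_amount` elsewhere in the script:

```python
# Standalone sketch of the pad-then-chunk pattern used for sequence parallelism above.
import torch

world_size, rank = 4, 1
temb = torch.randn(2, 30, 1536)  # (batch, seq_len, dim); 30 is not divisible by world_size

sequence_pad_amount = (-temb.shape[1]) % world_size  # pad up to a multiple of world_size
temb = torch.cat(
    [temb, torch.zeros(temb.shape[0], sequence_pad_amount, temb.shape[2], dtype=temb.dtype)],
    dim=1,
)
local_temb = torch.chunk(temb, world_size, dim=1)[rank]  # each rank keeps its own slice
print(local_temb.shape)  # torch.Size([2, 8, 1536])
```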

freqs_cos, freqs_sin = rotary_emb

def get_rotary_emb_chunk(freqs, sequence_pad_amount):
@@ -183,47 +207,71 @@ def get_rotary_emb_chunk(freqs, sequence_pad_amount):

def main():
parser = FlexibleArgumentParser(description="xFuser Arguments")
parser.add_argument(
"--task",
type=str,
required=True,
choices=["i2v", "t2v", "ti2v"],
help="The task to run."
)
args = xFuserArgs.add_cli_args(parser).parse_args()
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()
engine_config.runtime_config.dtype = torch.bfloat16
local_rank = get_world_group().local_rank
assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."

if not args.img_file_path:
raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")

pipe = WanImageToVideoPipeline.from_pretrained(
is_i2v_task = args.task == "i2v" or (args.task == "ti2v" and args.img_file_path is not None)
task_pipeline = WanImageToVideoPipeline if is_i2v_task else WanPipeline
pipe = task_pipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
torch_dtype=torch.bfloat16
torch_dtype=torch.bfloat16,
)
pipe.scheduler.config.flow_shift = 5 # Match original implementation
pipe.scheduler.config.flow_shift = TASK_FLOW_SHIFT[args.task]
initialize_runtime_state(pipe, engine_config)
parallelize_transformer(pipe)
pipe = pipe.to(f"cuda:{local_rank}")

image = load_image(args.img_file_path)
if not args.img_file_path and args.task == "i2v":
raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")

max_area = input_config.height * input_config.width
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
if is_i2v_task:
image = load_image(args.img_file_path)
max_area = input_config.height * input_config.width
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
if is_dp_last_group():
print("Max area is calculated from input height and width values, but the aspect ratio for the output video is retained from the input image.")
print(f"Input image resolution: {image.height}x{image.width}")
print(f"Generating a video with resolution: {height}x{width}")
else: # T2V or TI2V with no image
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = input_config.height // mod_value * mod_value
width = input_config.width // mod_value * mod_value
if height != input_config.height or width != input_config.width:
if is_dp_last_group():
print(f"Adjusting height and width to be multiples of {mod_value}. New dimensions: {height}x{width}")
image = None
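
As a worked sketch of the i2v resizing math above (all numbers are illustrative assumptions, including the `mod_value` of 16), the input aspect ratio is preserved while the pixel area is fitted under the requested height times width, and both dimensions are snapped down to multiples of `mod_value`:

```python
# Worked sketch of the aspect-ratio-preserving resize above; numbers are illustrative.
# mod_value would normally be vae_scale_factor_spatial * patch_size[1] (assumed 8 * 2 = 16 here).
import numpy as np

height_cfg, width_cfg = 720, 1280  # requested --height / --width
img_h, img_w = 480, 832            # a hypothetical landscape input image
mod_value = 16

max_area = height_cfg * width_cfg  # 921600 pixels
aspect_ratio = img_h / img_w       # ~0.577
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print(height, width)               # 720 1264: roughly the input aspect ratio, area <= max_area
```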

def run_pipe(input_config, image):
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()
optional_kwargs = {}
if image is not None:
optional_kwargs["image"] = image
output = pipe(
height=height,
width=width,
image=image,
prompt=input_config.prompt,
num_inference_steps=input_config.num_inference_steps,
num_frames=input_config.num_frames,
guidance_scale=input_config.guidance_scale,
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
**optional_kwargs,
).frames[0]
end = time.perf_counter()
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
@@ -243,7 +291,9 @@ def run_pipe(input_config, image):

output = run_pipe(input_config, image)
if is_dp_last_group():
export_to_video(output, "i2v_output.mp4", fps=16)
file_name = f"{args.task}_output.mp4"
export_to_video(output, file_name, fps=TASK_FPS[args.task])
print(f"Output video saved to {file_name}")

get_runtime_state().destroy_distributed_env()
