README.md (6 changes: 4 additions & 2 deletions)
@@ -113,7 +113,8 @@ The following open-sourced DiT Models are released with xDiT in day 1.
| [🎬 Mochi-1](https://github.com/xdit-project/mochi-xdit) | ✔️ | ✔️ ||| [Report](https://github.com/xdit-project/mochi-xdit) |
| [🎬 CogVideoX](https://huggingface.co/THUDM/CogVideoX-2b) | ✔️ | ✔️ ||| [Report](./docs/performance/cogvideo.md) |
| [🎬 Latte](https://huggingface.co/maxin-cn/Latte-1) || ✔️ ||| [Report](./docs/performance/latte.md) |
| [🎬 Wan2.X I2V](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) || ✔️ ||| NA |
| [🎬 Wan2.1](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) || ✔️ ||| NA |
| [🎬 Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) || ✔️ ||| NA |
| [🔵 HunyuanDiT-v1.2-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) | ✔️ | ✔️ | ✔️ || [Report](./docs/performance/hunyuandit.md) |
| [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ✔️ || [Report](./docs/performance/flux.md) |
| [🔴 PixArt-Sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS) | ✔️ | ✔️ | ✔️ || [Report](./docs/performance/pixart_alpha_legacy.md) |
@@ -236,7 +237,8 @@ Below is a list of validated diffusers version requirements. If the model is not
| --- | --- |
| [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) | >= 0.35.2 |
| [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) | >= 0.35.2 |
| [Wan2.X I2V](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | >= 0.35.2 |
| [Wan2.1](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers) | >= 0.35.2 |
| [Wan2.2](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) | >= 0.35.2 |

<h2 id="dev-guide">📚 Develop Guide</h2>

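A quick way to verify the floor in the validated-versions table above before launching a Wan example is to compare the installed `diffusers` version against the minimum. The sketch below is illustrative only, assuming `packaging` is available; the example script itself relies on xFuser's `get_minimum_diffusers_version` helper, as shown in the diff that follows.

```python
# Illustrative version check mirroring the ">= 0.35.2" floor from the table above.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("diffusers"))
minimum = Version("0.35.2")  # minimum validated for Wan2.1 / Wan2.2 per the table
if installed < minimum:
    raise ImportError(f"diffusers {installed} found, but >= {minimum} is required for Wan models.")
```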
examples/wan_i2v_example.py → examples/wan_example.py (78 changes: 62 additions & 16 deletions)
@@ -9,7 +9,7 @@
minimum_diffusers_version = get_minimum_diffusers_version("wan")
raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Wan.")

from diffusers import WanImageToVideoPipeline
from diffusers import WanImageToVideoPipeline, WanPipeline
from diffusers.utils import export_to_video, load_image
from diffusers.models.modeling_outputs import Transformer2DModelOutput

@@ -28,6 +28,23 @@
)
from xfuser.model_executor.models.transformers.transformer_wan import xFuserWanAttnProcessor

TASK_PIPELINE = {
"i2v": WanImageToVideoPipeline,
"t2v": WanPipeline,
"ti2v": WanPipeline,
}
TASK_FPS = {
"i2v": 16,
"t2v": 16,
"ti2v": 24,
}

TASK_FLOW_SHIFT = {
"i2v": 5,
"t2v": 12,
"ti2v": 5,
}

# Wrapper to only wrap the transformer in case it exists, i.e. Wan2.2
def maybe_transformer_2(transformer_2):
if transformer_2 is not None:
@@ -103,6 +120,10 @@ def new_forward(
], dim=1)
hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]

if ts_seq_len is not None: # (wan2.2 ti2v)
temb = torch.chunk(temb, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
timestep_proj = torch.chunk(timestep_proj, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

freqs_cos, freqs_sin = rotary_emb

def get_rotary_emb_chunk(freqs, sequence_pad_amount):
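For context on the lines added in the hunk above: when `ts_seq_len` is set (the Wan2.2 ti2v case, per the inline comment), `temb` and `timestep_proj` appear to carry a per-token sequence dimension, so each sequence-parallel rank keeps only its slice of them along dim=1, just as it already does for `hidden_states`. A standalone sketch of that `torch.chunk` pattern follows; the shapes and the `world_size`/`rank` values are illustrative stand-ins for xFuser's `get_sequence_parallel_world_size()` / `get_sequence_parallel_rank()`.

```python
import torch

# Sketch of the sequence-parallel split used above: each rank keeps its own
# contiguous chunk along the sequence dimension (dim=1). Shapes are assumed
# purely for illustration.
world_size, rank = 4, 1           # stand-ins for the xFuser SP helpers
temb = torch.randn(2, 1024, 512)  # (batch, seq_len, dim)

local_temb = torch.chunk(temb, world_size, dim=1)[rank]
print(local_temb.shape)           # torch.Size([2, 256, 512])
```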
@@ -183,47 +204,70 @@ def get_rotary_emb_chunk(freqs, sequence_pad_amount):

def main():
parser = FlexibleArgumentParser(description="xFuser Arguments")
parser.add_argument(
"--task",
type=str,
required=True,
choices=["i2v", "t2v", "ti2v"],
help="The task to run."
)
args = xFuserArgs.add_cli_args(parser).parse_args()
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()
engine_config.runtime_config.dtype = torch.bfloat16
local_rank = get_world_group().local_rank
assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."

if not args.img_file_path:
raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")

pipe = WanImageToVideoPipeline.from_pretrained(
pipe = TASK_PIPELINE[args.task].from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
torch_dtype=torch.bfloat16
torch_dtype=torch.bfloat16,
)
pipe.scheduler.config.flow_shift = 5 # Match original implementation
pipe.scheduler.config.flow_shift = TASK_FLOW_SHIFT[args.task]
initialize_runtime_state(pipe, engine_config)
parallelize_transformer(pipe)
pipe = pipe.to(f"cuda:{local_rank}")

image = load_image(args.img_file_path)
if not args.img_file_path and args.task == "i2v":
raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")

max_area = input_config.height * input_config.width
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
is_i2v_task = args.task == "i2v" or (args.task == "ti2v" and args.img_file_path is not None)
if is_i2v_task:
image = load_image(args.img_file_path)
max_area = input_config.height * input_config.width
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
if is_dp_last_group():
print("Max area is calculated from input height and width values, but the aspect ratio for the output video is retained from the input image.")
print(f"Input image resolution: {image.height}x{image.width}")
print(f"Generating a video with resolution: {height}x{width}")
else: # T2V or TI2V with no image
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = input_config.height // mod_value * mod_value
width = input_config.width // mod_value * mod_value
if height != input_config.height or width != input_config.width:
if is_dp_last_group():
print(f"Adjusting height and width to be multiples of {mod_value}. New dimensions: {height}x{width}")
image = None

def run_pipe(input_config, image):
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()
optional_kwargs = {}
if image:
optional_kwargs["image"] = image
output = pipe(
height=height,
width=width,
image=image,
prompt=input_config.prompt,
num_inference_steps=input_config.num_inference_steps,
num_frames=input_config.num_frames,
guidance_scale=input_config.guidance_scale,
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
**optional_kwargs,
).frames[0]
end = time.perf_counter()
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
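For reference, the i2v branch in the hunk above keeps the requested pixel budget (`height * width`) while preserving the input image's aspect ratio and snapping both sides down to a multiple of `mod_value`. Below is a worked sketch of that arithmetic; `mod_value = 16` and the image size are assumptions for illustration only, since the script derives `mod_value` from the pipeline's VAE scale factor and patch size.

```python
import math

# Worked example of the resolution logic above; all inputs are assumed for illustration.
# math.sqrt stands in for the np.sqrt call in the script (numerically identical here).
mod_value = 16                # vae_scale_factor_spatial * patch_size[1] in a typical Wan config
max_area = 480 * 832          # requested height * width pixel budget
img_h, img_w = 1080, 1920     # input image height and width

aspect_ratio = img_h / img_w  # 0.5625
height = round(math.sqrt(max_area * aspect_ratio)) // mod_value * mod_value  # 474 snapped to 464
width = round(math.sqrt(max_area / aspect_ratio)) // mod_value * mod_value   # 843 snapped to 832

print(height, width)          # 464 832: close to the input aspect ratio, within the pixel budget
```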
@@ -243,7 +287,9 @@ def run_pipe(input_config, image):

output = run_pipe(input_config, image)
if is_dp_last_group():
export_to_video(output, "i2v_output.mp4", fps=16)
file_name = f"{args.task}_output.mp4"
export_to_video(output, file_name, fps=TASK_FPS[args.task])
print(f"Output video saved to {file_name}")

get_runtime_state().destroy_distributed_env()

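With the rename to `examples/wan_example.py`, one script now covers all three tasks. A hypothetical launch command is sketched below; `--task` and `--img_file_path` come from this diff, while the remaining flags (`--model`, `--ulysses_degree`, `--ring_degree`, resolution and sampling options) are assumed from the usual xFuser example arguments and may differ in name or defaults.

```bash
# Hypothetical invocation; flag names other than --task/--img_file_path are assumptions.
torchrun --nproc_per_node=4 examples/wan_example.py \
    --task t2v \
    --model Wan-AI/Wan2.1-T2V-14B-Diffusers \
    --ulysses_degree 2 --ring_degree 2 \
    --height 480 --width 832 --num_frames 81 \
    --num_inference_steps 50 \
    --prompt "A cat walks on the grass, realistic style."
# For i2v (or ti2v with a conditioning image), add --img_file_path <local path or URL>
# and point --model at the matching Wan checkpoint.
```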