     minimum_diffusers_version = get_minimum_diffusers_version("wan")
     raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Wan.")
 
-from diffusers import WanImageToVideoPipeline
+from diffusers import WanImageToVideoPipeline, WanPipeline
 from diffusers.utils import export_to_video, load_image
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 
 )
 from xfuser.model_executor.models.transformers.transformer_wan import xFuserWanAttnProcessor
 
+TASK_FPS = {
+    "i2v": 16,
+    "t2v": 16,
+    "ti2v": 24,
+}
+
+TASK_FLOW_SHIFT = {
+    "i2v": 5,
+    "t2v": 12,
+    "ti2v": 5,
+}
+
 # Wrapper to only wrap the transformer in case it exists, i.e. Wan2.2
 def maybe_transformer_2(transformer_2):
     if transformer_2 is not None:
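Note (not part of the diff): the two tables added above centralize the per-task constants that later hunks consume, flow shift for the scheduler (pipe.scheduler.config.flow_shift = TASK_FLOW_SHIFT[args.task]) and fps for export_to_video. Because --task is declared with choices=["i2v", "t2v", "ti2v"], args.task is always a valid key, so the plain dict lookups cannot raise KeyError; a quick sanity check would be:

# keys of both tables stay in sync with the argparse choices for --task
assert set(TASK_FPS) == set(TASK_FLOW_SHIFT) == {"i2v", "t2v", "ti2v"}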
@@ -103,6 +115,18 @@ def new_forward(
         ], dim=1)
         hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
 
+        if ts_seq_len is not None:  # (wan2.2 ti2v)
+            temb = torch.cat([
+                temb,
+                torch.zeros(batch_size, sequence_pad_amount, temb.shape[2], device=temb.device, dtype=temb.dtype)
+            ], dim=1)
+            timestep_proj = torch.cat([
+                timestep_proj,
+                torch.zeros(batch_size, sequence_pad_amount, timestep_proj.shape[2], timestep_proj.shape[3], device=timestep_proj.device, dtype=timestep_proj.dtype)
+            ], dim=1)
+            temb = torch.chunk(temb, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+            timestep_proj = torch.chunk(timestep_proj, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+
         freqs_cos, freqs_sin = rotary_emb
 
         def get_rotary_emb_chunk(freqs, sequence_pad_amount):
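The added block above follows the same pad-then-chunk pattern already applied to hidden_states: pad the sequence dimension with zeros until it divides evenly across the sequence-parallel group, then keep only the local rank's slice. A standalone sketch of that pattern (not part of the commit; pad_and_chunk is an illustrative helper, and world_size/rank stand in for get_sequence_parallel_world_size()/get_sequence_parallel_rank()):

import torch

def pad_and_chunk(x: torch.Tensor, world_size: int, rank: int, dim: int = 1) -> torch.Tensor:
    seq_len = x.shape[dim]
    pad = (-seq_len) % world_size  # zeros needed to reach the next multiple of world_size
    if pad:
        pad_shape = list(x.shape)
        pad_shape[dim] = pad
        x = torch.cat([x, x.new_zeros(pad_shape)], dim=dim)
    return torch.chunk(x, world_size, dim=dim)[rank]

# e.g. a (batch=2, seq=10, hidden=8) tensor split across 4 ranks: each rank keeps seq=3
local = pad_and_chunk(torch.randn(2, 10, 8), world_size=4, rank=0)
assert local.shape == (2, 3, 8)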
@@ -183,47 +207,71 @@ def get_rotary_emb_chunk(freqs, sequence_pad_amount):
 
 def main():
     parser = FlexibleArgumentParser(description="xFuser Arguments")
+    parser.add_argument(
+        "--task",
+        type=str,
+        required=True,
+        choices=["i2v", "t2v", "ti2v"],
+        help="The task to run."
+    )
     args = xFuserArgs.add_cli_args(parser).parse_args()
     engine_args = xFuserArgs.from_cli_args(args)
     engine_config, input_config = engine_args.create_config()
     engine_config.runtime_config.dtype = torch.bfloat16
     local_rank = get_world_group().local_rank
     assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
 
-    if not args.img_file_path:
-        raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")
-
-    pipe = WanImageToVideoPipeline.from_pretrained(
+    is_i2v_task = args.task == "i2v" or (args.task == "ti2v" and args.img_file_path is not None)
+    task_pipeline = WanImageToVideoPipeline if is_i2v_task else WanPipeline
+    pipe = task_pipeline.from_pretrained(
         pretrained_model_name_or_path=engine_config.model_config.model,
-        torch_dtype=torch.bfloat16
+        torch_dtype=torch.bfloat16,
     )
-    pipe.scheduler.config.flow_shift = 5  # Match original implementation
+    pipe.scheduler.config.flow_shift = TASK_FLOW_SHIFT[args.task]
     initialize_runtime_state(pipe, engine_config)
     parallelize_transformer(pipe)
     pipe = pipe.to(f"cuda:{local_rank}")
 
-    image = load_image(args.img_file_path)
+    if not args.img_file_path and args.task == "i2v":
+        raise ValueError("Please provide an input image path via --img_file_path. This may be a local path or a URL.")
 
-    max_area = input_config.height * input_config.width
-    aspect_ratio = image.height / image.width
-    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-    image = image.resize((width, height))
+    if is_i2v_task:
+        image = load_image(args.img_file_path)
+        max_area = input_config.height * input_config.width
+        aspect_ratio = image.height / image.width
+        mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+        image = image.resize((width, height))
+        if is_dp_last_group():
+            print("Max area is calculated from input height and width values, but the aspect ratio for the output video is retained from the input image.")
+            print(f"Input image resolution: {image.height}x{image.width}")
+            print(f"Generating a video with resolution: {height}x{width}")
+    else:  # T2V or TI2V with no image
+        mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+        height = input_config.height // mod_value * mod_value
+        width = input_config.width // mod_value * mod_value
+        if height != input_config.height or width != input_config.width:
+            if is_dp_last_group():
+                print(f"Adjusting height and width to be multiples of {mod_value}. New dimensions: {height}x{width}")
+        image = None
     def run_pipe(input_config, image):
         torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
         start = time.perf_counter()
+        optional_kwargs = {}
+        if image is not None:
+            optional_kwargs["image"] = image
         output = pipe(
             height=height,
             width=width,
-            image=image,
             prompt=input_config.prompt,
             num_inference_steps=input_config.num_inference_steps,
             num_frames=input_config.num_frames,
             guidance_scale=input_config.guidance_scale,
             generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
+            **optional_kwargs,
         ).frames[0]
         end = time.perf_counter()
         peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
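For reference, a standalone sketch of the i2v sizing math earlier in this hunk: keep the input image's aspect ratio while targeting roughly height*width pixels, with both sides floored to a multiple of mod_value. fit_to_area is an illustrative helper, and mod_value=16 is only an assumption (vae_scale_factor_spatial=8 times patch_size[1]=2); the script reads the real value from the pipeline.

import numpy as np

def fit_to_area(img_h, img_w, target_h, target_w, mod_value=16):
    # mirrors the i2v branch: preserve aspect ratio, cap the pixel budget, snap to mod_value
    max_area = target_h * target_w
    aspect_ratio = img_h / img_w
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    return height, width

# a 512x768 image with a 480x832 pixel budget comes out at 512x768 (area 393216 <= 399360)
assert fit_to_area(512, 768, 480, 832) == (512, 768)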
@@ -243,7 +291,9 @@ def run_pipe(input_config, image):
 
     output = run_pipe(input_config, image)
     if is_dp_last_group():
-        export_to_video(output, "i2v_output.mp4", fps=16)
+        file_name = f"{args.task}_output.mp4"
+        export_to_video(output, file_name, fps=TASK_FPS[args.task])
+        print(f"Output video saved to {file_name}")
 
     get_runtime_state().destroy_distributed_env()
 