diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py index 03856f11..cc354aa3 100644 --- a/examples/hunyuan_video_usp_example.py +++ b/examples/hunyuan_video_usp_example.py @@ -92,14 +92,8 @@ def new_forward( get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] - encoder_attention_mask = encoder_attention_mask[0].to(torch.bool) - encoder_hidden_states_indices = torch.arange( - encoder_hidden_states.shape[1], - device=encoder_hidden_states.device) - encoder_hidden_states_indices = encoder_hidden_states_indices[ - encoder_attention_mask] - encoder_hidden_states = encoder_hidden_states[ - ..., encoder_hidden_states_indices, :] + encoder_attention_mask = encoder_attention_mask.to(torch.bool).any(dim=0) + encoder_hidden_states = encoder_hidden_states[:, encoder_attention_mask, :] if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size( ) != 0: get_runtime_state().split_text_embed_in_sp = False @@ -234,7 +228,7 @@ def main(): height=input_config.height, width=input_config.width, num_frames=input_config.num_frames, - batch_size=1, + batch_size=input_config.batch_size, num_inference_steps=input_config.num_inference_steps, split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, ) @@ -297,7 +291,7 @@ def main(): guidance_scale=input_config.guidance_scale, generator=torch.Generator(device="cuda").manual_seed( input_config.seed), - ).frames[0] + ) end_time = time.time() elapsed_time = end_time - start_time @@ -311,9 +305,10 @@ def main(): ) if is_dp_last_group(): resolution = f"{input_config.width}x{input_config.height}" - output_filename = f"results/hunyuan_video_{parallel_info}_{resolution}.mp4" - export_to_video(output, output_filename, fps=15) - print(f"output saved to {output_filename}") + for idx, frames in enumerate(output.frames, start=1): + output_filename = f"results/hunyuan_video_{idx:02d}_{parallel_info}_{resolution}.mp4" + export_to_video(frames, output_filename, fps=15) + print(f"output saved to {output_filename}") if get_world_group().rank == get_world_group().world_size - 1: print(