
Conversation

@eppaneamd
Contributor

In HunyuanVideo example, the model is compiled as follows:

        pipe.transformer = torch.compile(pipe.transformer,
                                         mode="max-autotune-no-cudagraphs")

However, combining torch.compile with model CPU offloading causes the following error:

[rank6]: Traceback (most recent call last):
[rank6]:   File "/home/repos/xDiT/examples/hunyuan_video_usp_example.py", line 328, in <module>
[rank6]:     main()
[rank6]:   File "/home/repos/xDiT/examples/hunyuan_video_usp_example.py", line 277, in main
[rank6]:     output = pipe(
[rank6]:   File "/home/repos/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
[rank6]:     return func(*args, **kwargs)
[rank6]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py", line 682, in __call__
[rank6]:     self.maybe_free_model_hooks()
[rank6]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/pipeline_utils.py", line 1118, in maybe_free_model_hooks
[rank6]:     self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
[rank6]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/pipeline_utils.py", line 1049, in enable_model_cpu_offload
[rank6]:     self.remove_all_hooks()
[rank6]:   File "/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/pipeline_utils.py", line 1016, in remove_all_hooks
[rank6]:     accelerate.hooks.remove_hook_from_module(model, recurse=True)
[rank6]:   File "/home/repos/accelerate/src/accelerate/hooks.py", line 203, in remove_hook_from_module
[rank6]:     delattr(module, "_hf_hook")
[rank6]:   File "/home/repos/pytorch/torch/nn/modules/module.py", line 2040, in __delattr__
[rank6]:     super().__delattr__(name)
[rank6]: AttributeError: _hf_hook

Repro command:

torchrun --nproc_per_node=8 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo \
--prompt "In the large cage, two puppies were wagging their tails at each other." \
--height 544 --width 960 --num_frames 129 --num_inference_steps 30 --warmup_steps 1 \
--ulysses_degree 8 --enable_tiling --enable_model_cpu_offload --use_torch_compile

This PR proposes modifying the model compilation to:

        pipe.transformer.compile()

which avoids the error. Additionally, max-autotune is a rather time-consuming process, so the default mode is used instead.
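A plausible minimal sketch of the failure mode, using illustrative stand-in classes rather than the real torch.compile / accelerate internals: `torch.compile(model)` returns a wrapper module whose attribute *reads* are forwarded to the wrapped module, but `delattr` is not forwarded, so accelerate's hook removal sees `_hf_hook` via `hasattr` and then fails to delete it. In-place `model.compile()` keeps the original module in the pipeline, so the hook lives on the object actually being modified:

```python
class Inner:
    """Stand-in for the original transformer module."""
    pass


class Wrapper:
    """Stand-in for the wrapper module returned by torch.compile(model)."""

    def __init__(self, inner):
        object.__setattr__(self, "_inner", inner)

    def __getattr__(self, name):
        # Attribute reads fall through to the wrapped module...
        return getattr(object.__getattribute__(self, "_inner"), name)


inner = Inner()
inner._hf_hook = object()            # hook installed by enable_model_cpu_offload
wrapped = Wrapper(inner)

assert hasattr(wrapped, "_hf_hook")  # forwarding makes the hook visible...

raised = False
try:
    delattr(wrapped, "_hf_hook")     # ...but deletion targets the wrapper itself
except AttributeError:
    raised = True                    # same AttributeError as in the traceback
assert raised

# With in-place compilation (pipe.transformer.compile()), the pipeline keeps
# holding the original module, so the hook is deleted from the right object.
delattr(inner, "_hf_hook")
```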

Additionally, HunyuanVideo currently has some limitations, which are outlined in the README. They can be removed once fixes are implemented.

@StrongerXi

Related: pytorch/pytorch#150711

Upgrading PyTorch should fix this as well. But yes, in general we recommend model.compile() over torch.compile(model).

@feifeibear feifeibear merged commit 50a1a3b into xdit-project:main Sep 8, 2025
@jcaraban jcaraban mentioned this pull request Nov 12, 2025