diff --git a/autoparallel/activation_checkpointing.py b/autoparallel/activation_checkpointing.py index 1d960da..9f6de73 100644 --- a/autoparallel/activation_checkpointing.py +++ b/autoparallel/activation_checkpointing.py @@ -454,5 +454,6 @@ def ac_joint_pass( torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, } _apply_ac_policy(graph, save_list=save_list)