diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py
index f309172173..2b05505e4a 100644
--- a/tests/unit_tests/test_activation_checkpoint.py
+++ b/tests/unit_tests/test_activation_checkpoint.py
@@ -88,7 +88,6 @@ def get_bw_flops(model_fn):
             model_selective_ac,
             ac_config_no_force,
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
         flops_selective_ac = get_bw_flops(model_selective_ac)
@@ -106,7 +105,6 @@ def get_bw_flops(model_fn):
             model_with_force_first,
             ac_config_with_force_first,
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
         flops_with_force_first = get_bw_flops(model_with_force_first)
@@ -123,7 +121,6 @@ def get_bw_flops(model_fn):
             model_with_force_last,
             ac_config_with_force_last,
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
         flops_with_force_last = get_bw_flops(model_with_force_last)
@@ -138,7 +135,6 @@ def get_bw_flops(model_fn):
             model_with_full_ac,
             ac_config_full_ac,
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
         flops_full_ac = get_bw_flops(model_with_full_ac)
@@ -181,7 +177,6 @@ def get_act_mem(model_fn):
             model_selective_ac,
             ac_config_no_force,
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
         mem_selective_ac = get_act_mem(model_selective_ac)
@@ -198,7 +193,6 @@ def get_act_mem(model_fn):
             model_with_force_first,
             ac_config_with_force_first,
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
         mem_with_force_first = get_act_mem(model_with_force_first)
@@ -214,7 +208,6 @@ def get_act_mem(model_fn):
             model_with_force_last,
             ac_config_with_force_last,
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
         mem_with_force_last = get_act_mem(model_with_force_last)
@@ -228,7 +221,6 @@ def get_act_mem(model_fn):
             model_with_full_ac,
             ac_config_full_ac,
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
         mem_full_ac = get_act_mem(model_with_full_ac)
@@ -255,7 +247,6 @@ def test_correctness(self):
                 per_op_sac_force_recompute_mm_shapes_by_fqns=[],
             ),
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
         model_force_first = ToyModule()
@@ -268,7 +259,6 @@ def test_correctness(self):
                 per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
             ),
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
 
@@ -282,7 +272,6 @@ def test_correctness(self):
                 per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
             ),
             model_compile_enabled=False,
-            use_flex_attn=False,
             op_sac_save_list=_op_sac_save_list,
         )
 
diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py
index 8359f71730..0eecde9052 100644
--- a/torchtitan/distributed/activation_checkpoint.py
+++ b/torchtitan/distributed/activation_checkpoint.py
@@ -17,7 +17,7 @@
 )
 
 from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
-from torchtitan.tools.logging import logger, warn_once
+from torchtitan.tools.logging import logger
 
 
 _layer_sac_count = 0
@@ -155,88 +155,12 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
     )
 
 
-def _apply_op_sac_to_transformer_block_with_flex(
-    module: nn.Module,
-    ac_config: ACConfig,
-    *,
-    base_fqn: str | None = None,
-    model_compile_enabled: bool = False,
-    op_sac_save_list: set[torch._ops.OpOverload],
-) -> nn.Module:
-    """Apply SAC to the transformer block that uses FlexAttention.
-
-    Args:
-        module (nn.Module): The transformer block to apply SAC to.
-        ac_config (ACConfig): The Activation Checkpoint config.
-        base_fqn (str, optional): The base fqn of the module. Defaults to None.
-        model_compile_enabled (bool): Whether model compilation is enabled.
-            Defaults to False.
-        op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
-            of recomputing.
-
-    Returns:
-        nn.Module: The transformer block with SAC applied.
-    """
-
-    warn_once(
-        logger,
-        (
-            "Flex Attention requires compilation for good performance.\n"
-            "Thus, torch.compile is always used for Flex Attention, "
-            "regardless of the compile.enable flag.\n"
-            "However, when selective activation checkpointing (SAC) is enabled, "
-            "torch.compile may be invalidated:\n"
-            "1. If compile.enable is False, SAC will ignore any torch.compile "
-            "inside the SAC region.\n"
-            "2. If compile.enable is True but the transformer block contains an MoE module.\n\n"
-            "For both cases, we will not wrap the entire TransformerBlock with SAC:\n"
-            " - For case 1: SAC will be used for MoE and FeedForward modules, "
-            "while full AC will be used for the Attention module.\n"
-            " - For case 2: SAC will be applied to MoE and Attention modules if the block "
-            "is sparse. But we still apply SAC to an entire dense block.\n"
-        ),
-    )
-
-    def wrap_submodule(name: str, full_ac: bool = False) -> None:
-        submodule = getattr(module, name)
-        if full_ac:
-            submodule = _apply_full_ac(submodule, ac_config)
-        else:
-            submodule = _apply_op_sac(
-                submodule,
-                ac_config,
-                base_fqn=f"{base_fqn}.{name}" if base_fqn else name,
-                op_sac_save_list=op_sac_save_list,
-            )
-        module.register_module(name, submodule)
-
-    if hasattr(module, "moe"):
-        wrap_submodule("moe", full_ac=False)
-        if model_compile_enabled:
-            wrap_submodule("attention", full_ac=False)
-        else:
-            wrap_submodule("attention", full_ac=True)
-    else:
-        if model_compile_enabled:
-            module = _apply_op_sac(
-                module,
-                ac_config,
-                base_fqn=base_fqn,
-                op_sac_save_list=op_sac_save_list,
-            )
-        else:
-            wrap_submodule("feed_forward", full_ac=False)
-            wrap_submodule("attention", full_ac=True)
-    return module
-
-
 def _apply_ac_to_transformer_block(
     module: nn.Module,
     ac_config: ACConfig,
     *,
     base_fqn: str | None = None,
     model_compile_enabled: bool = False,
-    use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
 ) -> nn.Module:
     valid_ac_modes = ("full", "selective")
@@ -259,26 +183,9 @@ def _apply_ac_to_transformer_block(
 
     if use_op_sac:
         op_sac_save_list = op_sac_save_list or set()
-        if use_flex_attn:
-            """
-            For Flex Attention, we need to apply SAC carefully to avoid invalidating
-            torch.compile. Any torch.compile inside the SAC region will be ignored,
-            and any torch.compile outside the SAC region will also be ignored if the
-            SAC region contains a graph break (e.g., MoE).
-
-            TODO: remove this once SAC issues are resolved.
-            """
-            return _apply_op_sac_to_transformer_block_with_flex(
-                module,
-                ac_config,
-                base_fqn=base_fqn,
-                model_compile_enabled=model_compile_enabled,
-                op_sac_save_list=op_sac_save_list,
-            )
-        else:
-            return _apply_op_sac(
-                module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list
-            )
+        return _apply_op_sac(
+            module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list
+        )
 
     return _apply_layer_sac(module, ac_config)
 
@@ -288,21 +195,15 @@ def apply_ac(
     ac_config: ACConfig,
     *,
     model_compile_enabled: bool = False,
-    use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
     base_folder: str = "",
 ) -> None:
     """Apply activation checkpointing to the model.
 
-    Note that SAC, Flex Attention and model compilation have some conflicts.
-    We explicitly ask the user to pass these configs to warn as the wrapping
-    will be different.
-
     Args:
         model (nn.Module): The model to apply activation checkpointing to.
         ac_config (ACConfig): The activation checkpointing config.
         model_compile_enabled (bool): Whether torch.compile is enabled for the model.
-        use_flex_attn (bool): Whether flex attention is enabled for the model.
         op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
             of recomputing.
     Returns:
@@ -326,7 +227,6 @@ def apply_ac(
             ac_config,
             base_fqn=f"layers.{layer_id}",
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
             op_sac_save_list=op_sac_save_list,
         )
         model.layers.register_module(layer_id, transformer_block)
diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py
index 232cba9ff7..4d1177d1ab 100644
--- a/torchtitan/experiments/gpt_oss/infra/parallelize.py
+++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py
@@ -47,6 +47,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
+    torch._higher_order_ops.inductor_compiled_code,
 }
 
 
@@ -110,14 +111,11 @@ def parallelize_gptoss(
         job_config.compile.enable and "model" in job_config.compile.components
     )
 
-    attn_type = getattr(model.model_args, "attn_type", "sdpa")
-    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
         )
 
diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
index bd9c936b78..484d3d4747 100644
--- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -34,6 +34,7 @@
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
     torch.ops.torch_attn._varlen_attn,
+    torch._higher_order_ops.inductor_compiled_code,
 }
 
 
@@ -106,8 +107,6 @@ def parallelize_llama(
         maybe_enable_async_tp(job_config, tp_mesh)
 
     if job_config.activation_checkpoint.mode != "none":
-        attn_type = getattr(model.model_args, "attn_type", "sdpa")
-        use_flex_attn = attn_type == "flex"
         model_compile_enabled = (
             job_config.compile.enable and "model" in job_config.compile.components
         )
@@ -115,7 +114,6 @@ def parallelize_llama(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )
diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py
index d418ad6edd..b6ada94d00 100644
--- a/torchtitan/experiments/vlm/infra/parallelize.py
+++ b/torchtitan/experiments/vlm/infra/parallelize.py
@@ -58,13 +58,11 @@ def parallelize_vlm(
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
-    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
         )
         apply_ac(model.encoder, job_config.activation_checkpoint)
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index cc7b87cb20..663ce54010 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -97,7 +97,14 @@ class FlexAttentionWrapper(torch.nn.Module):
     """
 
     _compiled_flex_attn: ClassVar[Callable] = torch.compile(
-        flex_attention, mode="max-autotune-no-cudagraphs"
+        flex_attention,
+        # These options also encapsulate max-autotune-no-cudagraphs.
+        options={
+            "wrap_inductor_compiled_regions": True,
+            "max_autotune": True,
+            "coordinate_descent_tuning": True,
+            "triton.cudagraphs": False,
+        },
     )
 
     def forward(
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 69273654e3..98db56b135 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -44,6 +44,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
+    torch._higher_order_ops.inductor_compiled_code,
 }
 
 
@@ -65,7 +66,6 @@ def parallelize_deepseekv3(
     """
 
     attn_type = getattr(model.model_args, "attn_type", "sdpa")
-    use_flex_attn = attn_type == "flex"
     if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa":
         raise NotImplementedError("CP support is only supported for SDPA.")
 
@@ -115,7 +115,6 @@ def parallelize_deepseekv3(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 1c381883b1..52a2dfe7e2 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -45,6 +45,7 @@
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
     torch.ops.torch_attn._varlen_attn.default,
+    torch._higher_order_ops.inductor_compiled_code,
 }
 
 
@@ -95,14 +96,11 @@ def parallelize_llama(
         job_config.compile.enable and "model" in job_config.compile.components
     )
 
-    attn_type = getattr(model.model_args, "attn_type", "sdpa")
-    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 28418d842e..0b15e0c9eb 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -52,6 +52,7 @@
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
+    torch._higher_order_ops.inductor_compiled_code,
 }
 
 
@@ -116,14 +117,11 @@ def parallelize_llama(
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
    )
-    attn_type = getattr(model.model_args, "attn_type", "sdpa")
-    use_flex_attn = attn_type == "flex"
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 12aca42777..6bb9eb5204 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -47,6 +47,7 @@
     torch.ops.aten.max.default,
     torch._higher_order_ops.flex_attention,
     torch.ops.torch_attn._varlen_attn.default,
+    torch._higher_order_ops.inductor_compiled_code,
 }
 
 
@@ -116,7 +117,6 @@ def parallelize_qwen3(
             model,
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
-            use_flex_attn=attn_type == "flex",
             op_sac_save_list=_op_sac_save_list,
             base_folder=job_config.job.dump_folder,
         )
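
Note (not part of the patch): the sketch below illustrates the call pattern after this change, assuming apply_ac and ACConfig keep the signatures shown in the diff; the apply_selective_ac helper and the abbreviated save list are hypothetical. With wrap_inductor_compiled_regions=True, the compiled FlexAttention region surfaces as the torch._higher_order_ops.inductor_compiled_code HOP, so adding that HOP to the per-op SAC save list lets SAC save the compiled region's output instead of recomputing it, and callers no longer pass use_flex_attn.

# Hypothetical usage sketch; assumes a PyTorch build that exposes
# torch._higher_order_ops.inductor_compiled_code.
import torch
import torch.nn as nn

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.distributed.activation_checkpoint import apply_ac

# Abbreviated per-op SAC save list (the full lists live in the parallelize.py
# files above). The inductor_compiled_code HOP marks Inductor-compiled regions
# such as the compiled FlexAttention call.
_op_sac_save_list = {
    torch.ops.aten.max.default,
    torch._higher_order_ops.flex_attention,
    torch._higher_order_ops.inductor_compiled_code,
}


def apply_selective_ac(model: nn.Module, model_compile_enabled: bool) -> None:
    # Same wrapping path for SDPA, varlen, and FlexAttention blocks;
    # there is no use_flex_attn flag after this change.
    apply_ac(
        model,
        ACConfig(
            mode="selective",
            selective_ac_option="op",
            per_op_sac_force_recompute_mm_shapes_by_fqns=[],
        ),
        model_compile_enabled=model_compile_enabled,
        op_sac_save_list=_op_sac_save_list,
    )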