Merged

Changes from 2 commits
11 changes: 0 additions & 11 deletions tests/unit_tests/test_activation_checkpoint.py
@@ -88,7 +88,6 @@ def get_bw_flops(model_fn):
model_selective_ac,
ac_config_no_force,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)
flops_selective_ac = get_bw_flops(model_selective_ac)
@@ -106,7 +105,6 @@ def get_bw_flops(model_fn):
model_with_force_first,
ac_config_with_force_first,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)
flops_with_force_first = get_bw_flops(model_with_force_first)
@@ -123,7 +121,6 @@ def get_bw_flops(model_fn):
model_with_force_last,
ac_config_with_force_last,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)
flops_with_force_last = get_bw_flops(model_with_force_last)
@@ -138,7 +135,6 @@ def get_bw_flops(model_fn):
model_with_full_ac,
ac_config_full_ac,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)
flops_full_ac = get_bw_flops(model_with_full_ac)
@@ -181,7 +177,6 @@ def get_act_mem(model_fn):
model_selective_ac,
ac_config_no_force,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)
mem_selective_ac = get_act_mem(model_selective_ac)
@@ -198,7 +193,6 @@ def get_act_mem(model_fn):
model_with_force_first,
ac_config_with_force_first,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)
mem_with_force_first = get_act_mem(model_with_force_first)
@@ -214,7 +208,6 @@ def get_act_mem(model_fn):
model_with_force_last,
ac_config_with_force_last,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)
mem_with_force_last = get_act_mem(model_with_force_last)
@@ -228,7 +221,6 @@ def get_act_mem(model_fn):
model_with_full_ac,
ac_config_full_ac,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)
mem_full_ac = get_act_mem(model_with_full_ac)
@@ -255,7 +247,6 @@ def test_correctness(self):
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
),
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)
model_force_first = ToyModule()
@@ -268,7 +259,6 @@ def test_correctness(self):
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
),
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)

@@ -282,7 +272,6 @@ def test_correctness(self):
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
),
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)

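The hunks above compare backward FLOPs (get_bw_flops) and activation memory (get_act_mem) across AC configurations. As a rough illustration of what a helper like get_bw_flops can measure, backward FLOPs can be counted with FlopCounterMode, which also captures recomputation triggered by activation checkpointing. This is a hypothetical stand-in, not the test's actual helper; the input shape and model interface are assumptions.

import torch
from torch.utils.flop_counter import FlopCounterMode

def get_bw_flops(model_fn):
    # Run the forward pass outside the counter, then count only the FLOPs
    # performed during backward; forward ops recomputed under AC are included.
    x = torch.randn(8, 512, requires_grad=True)
    loss = model_fn(x).sum()
    with FlopCounterMode(display=False) as flop_counter:
        loss.backward()
    return flop_counter.get_total_flops()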
108 changes: 4 additions & 104 deletions torchtitan/distributed/activation_checkpoint.py
@@ -17,7 +17,7 @@
)

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.tools.logging import logger, warn_once
from torchtitan.tools.logging import logger


_layer_sac_count = 0
@@ -155,88 +155,12 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
)


def _apply_op_sac_to_transformer_block_with_flex(
module: nn.Module,
ac_config: ACConfig,
*,
base_fqn: str | None = None,
model_compile_enabled: bool = False,
op_sac_save_list: set[torch._ops.OpOverload],
) -> nn.Module:
"""Apply SAC to the transformer block that uses FlexAttention.

Args:
module (nn.Module): The transformer block to apply SAC to.
ac_config (ACConfig): The Activation Checkpoint config.
base_fqn (str, optional): The base fqn of the module. Defaults to None.
model_compile_enabled (bool): Whether model compilation is enabled.
Defaults to False.
op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
of recomputing.

Returns:
nn.Module: The transformer block with SAC applied.
"""

warn_once(
logger,
(
"Flex Attention requires compilation for good performance.\n"
"Thus, torch.compile is always used for Flex Attention, "
"regardless of the compile.enable flag.\n"
"However, when selective activation checkpointing (SAC) is enabled, "
"torch.compile may be invalidated:\n"
"1. If compile.enable is False, SAC will ignore any torch.compile "
"inside the SAC region.\n"
"2. If compile.enable is True but the transformer block contains an MoE module.\n\n"
"For both cases, we will not wrap the entire TransformerBlock with SAC:\n"
" - For case 1: SAC will be used for MoE and FeedForward modules, "
"while full AC will be used for the Attention module.\n"
" - For case 2: SAC will be applied to MoE and Attention modules if the block "
"is sparse. But we still apply SAC to an entire dense block.\n"
),
)

def wrap_submodule(name: str, full_ac: bool = False) -> None:
submodule = getattr(module, name)
if full_ac:
submodule = _apply_full_ac(submodule, ac_config)
else:
submodule = _apply_op_sac(
submodule,
ac_config,
base_fqn=f"{base_fqn}.{name}" if base_fqn else name,
op_sac_save_list=op_sac_save_list,
)
module.register_module(name, submodule)

if hasattr(module, "moe"):
wrap_submodule("moe", full_ac=False)
if model_compile_enabled:
wrap_submodule("attention", full_ac=False)
else:
wrap_submodule("attention", full_ac=True)
else:
if model_compile_enabled:
module = _apply_op_sac(
module,
ac_config,
base_fqn=base_fqn,
op_sac_save_list=op_sac_save_list,
)
else:
wrap_submodule("feed_forward", full_ac=False)
wrap_submodule("attention", full_ac=True)
return module


def _apply_ac_to_transformer_block(
module: nn.Module,
ac_config: ACConfig,
*,
base_fqn: str | None = None,
model_compile_enabled: bool = False,
use_flex_attn: bool = False,
op_sac_save_list: set[torch._ops.OpOverload] | None = None,
) -> nn.Module:
valid_ac_modes = ("full", "selective")
@@ -259,26 +183,9 @@ def _apply_ac_to_transformer_block(

if use_op_sac:
op_sac_save_list = op_sac_save_list or set()
if use_flex_attn:
"""
For Flex Attention, we need to apply SAC carefully to avoid invalidating
torch.compile. Any torch.compile inside the SAC region will be ignored,
and any torch.compile outside the SAC region will also be ignored if the
SAC region contains a graph break (e.g., MoE).

TODO: remove this once SAC issues are resolved.
"""
return _apply_op_sac_to_transformer_block_with_flex(
module,
ac_config,
base_fqn=base_fqn,
model_compile_enabled=model_compile_enabled,
op_sac_save_list=op_sac_save_list,
)
else:
return _apply_op_sac(
module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list
)
return _apply_op_sac(
module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list
)

return _apply_layer_sac(module, ac_config)
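For context, here is a minimal, hypothetical sketch of how a save list such as op_sac_save_list is typically turned into a selective activation checkpointing policy with torch.utils.checkpoint. It is not torchtitan's _apply_op_sac; the op list and helper names are illustrative.

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative save list: outputs of these ops are stored for backward;
# everything else inside the checkpointed region is recomputed.
_example_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten.max.default,
}

def _policy(ctx, op, *args, **kwargs):
    # Called for every op executed under the checkpointed region.
    if op in _example_save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def op_sac_forward(module, *inputs):
    # context_fn must return fresh (forward, recompute) context managers per call.
    return checkpoint(
        module,
        *inputs,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_policy),
    )

This membership check is presumably also why torch._higher_order_ops.inductor_compiled_code is added to the save lists later in this diff: outputs of inductor-compiled regions should be saved rather than recomputed.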

@@ -288,21 +195,15 @@ def apply_ac(
ac_config: ACConfig,
*,
model_compile_enabled: bool = False,
use_flex_attn: bool = False,
op_sac_save_list: set[torch._ops.OpOverload] | None = None,
base_folder: str = "",
) -> None:
"""Apply activation checkpointing to the model.

Note that SAC, Flex Attention and model compilation have some conflicts.
We explicitly ask the user to pass these configs to warn as the wrapping
will be different.

Args:
model (nn.Module): The model to apply activation checkpointing to.
ac_config (ACConfig): The activation checkpointing config.
model_compile_enabled (bool): Whether torch.compile is enabled for the model.
use_flex_attn (bool): Whether flex attention is enabled for the model.
op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead
of recomputing.
Returns:
@@ -326,7 +227,6 @@ def apply_ac(
ac_config,
base_fqn=f"layers.{layer_id}",
model_compile_enabled=model_compile_enabled,
use_flex_attn=use_flex_attn,
op_sac_save_list=op_sac_save_list,
)
model.layers.register_module(layer_id, transformer_block)
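For reference, a hypothetical helper mirroring how the call sites later in this diff invoke the new signature; the job_config fields, model, and save list are assumed to be supplied by the caller.

from torchtitan.distributed.activation_checkpoint import apply_ac

def maybe_apply_ac(model, job_config, model_compile_enabled, op_sac_save_list):
    # Note that use_flex_attn is no longer passed.
    if job_config.activation_checkpoint.mode != "none":
        apply_ac(
            model,
            job_config.activation_checkpoint,
            model_compile_enabled=model_compile_enabled,
            op_sac_save_list=op_sac_save_list,
            base_folder=job_config.job.dump_folder,
        )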
4 changes: 1 addition & 3 deletions torchtitan/experiments/gpt_oss/infra/parallelize.py
@@ -47,6 +47,7 @@
# used to compute the scaling factor for quantization.
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
torch._higher_order_ops.inductor_compiled_code,
}


@@ -110,14 +111,11 @@ def parallelize_gptoss(
job_config.compile.enable and "model" in job_config.compile.components
)

attn_type = getattr(model.model_args, "attn_type", "sdpa")
use_flex_attn = attn_type == "flex"
if job_config.activation_checkpoint.mode != "none":
apply_ac(
model,
job_config.activation_checkpoint,
model_compile_enabled=model_compile_enabled,
use_flex_attn=use_flex_attn,
op_sac_save_list=_op_sac_save_list,
)

4 changes: 1 addition & 3 deletions torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -34,6 +34,7 @@
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
torch.ops.torch_attn._varlen_attn,
torch._higher_order_ops.inductor_compiled_code,
}


@@ -106,16 +107,13 @@ def parallelize_llama(
maybe_enable_async_tp(job_config, tp_mesh)

if job_config.activation_checkpoint.mode != "none":
attn_type = getattr(model.model_args, "attn_type", "sdpa")
use_flex_attn = attn_type == "flex"
model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
apply_ac(
model,
job_config.activation_checkpoint,
model_compile_enabled=model_compile_enabled,
use_flex_attn=use_flex_attn,
op_sac_save_list=_op_sac_save_list,
base_folder=job_config.job.dump_folder,
)
2 changes: 0 additions & 2 deletions torchtitan/experiments/vlm/infra/parallelize.py
@@ -58,13 +58,11 @@ def parallelize_vlm(
model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
use_flex_attn = attn_type == "flex"
if job_config.activation_checkpoint.mode != "none":
apply_ac(
model,
job_config.activation_checkpoint,
model_compile_enabled=model_compile_enabled,
use_flex_attn=use_flex_attn,
op_sac_save_list=_op_sac_save_list,
)
apply_ac(model.encoder, job_config.activation_checkpoint)
8 changes: 7 additions & 1 deletion torchtitan/models/attention.py
@@ -97,7 +97,13 @@ class FlexAttentionWrapper(torch.nn.Module):
"""

_compiled_flex_attn: ClassVar[Callable] = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
flex_attention,
options={
"wrap_inductor_compiled_regions": True,
"max_autotune": True,
"coordinate_descent_tuning": True,
"triton.cudagraphs": False,
},
)

def forward(

Review comments on the switch from mode="max-autotune-no-cudagraphs" to options={...}:

Contributor: Noob question: is this coordinate_descent_tuning also part of the "mode=max-autotune-no-cudagraphs" -> "options={...}" change?

Contributor: forgot to ask: what's the context of this change to "options={}"?

fegin (Contributor, PR author), Dec 8, 2025: Yes, https://github.com/pytorch/pytorch/blob/cf7bab873fa55051e1806f8db0c3f90dea452ac5/torch/_inductor/__init__.py#L361
We cannot pass mode and options at the same time; torch.compile forbids it. According to that link, max-autotune-no-cudagraphs is equivalent to these three options.
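To make the mode-to-options mapping discussed above concrete, here is a small sketch. It assumes a recent PyTorch build; the exact option set behind the mode can vary across versions, and wrap_inductor_compiled_regions appears to be the extra flag motivating the switch, since it cannot be expressed through the mode string.

import torch
import torch._inductor as inductor

# torch.compile rejects passing mode= and options= together, which is why the
# wrapper now spells out the options explicitly instead of using the mode string.
try:
    torch.compile(lambda x: x + 1, mode="max-autotune-no-cudagraphs", options={})
except RuntimeError as err:
    print(err)  # both cannot be specified at the same time

# Inspect what the mode expands to; on the build referenced in the comment above
# this is roughly {"max_autotune": True, "coordinate_descent_tuning": True,
# "triton.cudagraphs": False}.
print(inductor.list_mode_options("max-autotune-no-cudagraphs"))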
3 changes: 1 addition & 2 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -44,6 +44,7 @@
# used to compute the scaling factor for quantization.
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
torch._higher_order_ops.inductor_compiled_code,
}


@@ -65,7 +66,6 @@ def parallelize_deepseekv3(
"""

attn_type = getattr(model.model_args, "attn_type", "sdpa")
use_flex_attn = attn_type == "flex"
if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa":
raise NotImplementedError("CP support is only supported for SDPA.")

@@ -115,7 +115,6 @@
model,
job_config.activation_checkpoint,
model_compile_enabled=model_compile_enabled,
use_flex_attn=use_flex_attn,
op_sac_save_list=_op_sac_save_list,
base_folder=job_config.job.dump_folder,
)
4 changes: 1 addition & 3 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -45,6 +45,7 @@
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
torch.ops.torch_attn._varlen_attn.default,
torch._higher_order_ops.inductor_compiled_code,
}


@@ -95,14 +96,11 @@ def parallelize_llama(
job_config.compile.enable and "model" in job_config.compile.components
)

attn_type = getattr(model.model_args, "attn_type", "sdpa")
use_flex_attn = attn_type == "flex"
if job_config.activation_checkpoint.mode != "none":
apply_ac(
model,
job_config.activation_checkpoint,
model_compile_enabled=model_compile_enabled,
use_flex_attn=use_flex_attn,
op_sac_save_list=_op_sac_save_list,
base_folder=job_config.job.dump_folder,
)
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/debug_model.toml
@@ -62,7 +62,7 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
enable=false

Review comments on the selective_ac_option change:

Contributor: Is this intended? I think it's OK to change this but want to confirm.

Contributor (PR author): oops, accidentally committed this.