diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index cc7b87cb20..cd6f6b4c3a 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -6,7 +6,6 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

-import functools
 from collections.abc import Callable
 from typing import ClassVar, NamedTuple

@@ -164,22 +163,19 @@ def forward(
         return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True)


-# We cannot do inner function/closure because we won't be able to cache it --
-# if we an inner function, a new closure will be created every time
-# `get_causal_mask_mod` is called.
-def _causal_mask(
-    b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
-) -> torch.Tensor:
-    """Causal mask that prevents attention to future tokens."""
-    return q_idx >= kv_idx
-
-
 def get_causal_mask_mod() -> _mask_mod_signature:
     """Returns a causal mask modifier for flex attention.

     Returns:
         A mask modifier function that implements causal masking.
     """
+
+    def _causal_mask(
+        b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
+    ) -> torch.Tensor:
+        """Causal mask that prevents attention to future tokens."""
+        return q_idx >= kv_idx
+
     return _causal_mask


@@ -268,13 +264,8 @@ def sliding_window_mod(
 _compiled_create_block_mask = torch.compile(create_block_mask)


-@functools.lru_cache(4)
 def create_attention_mask(*args, **kwargs):
-    """Create an attention mask using compiled create_block_mask.
-
-    This function is cached to avoid recreating BlockMasks for the same
-    arguments.
-    """
+    """Create an attention mask using compiled create_block_mask."""
     return _compiled_create_block_mask(*args, **kwargs)

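
For reference, a minimal usage sketch of the two helpers touched by this diff. Only `get_causal_mask_mod` and `create_attention_mask` come from the file above; the sequence length, the broadcast `None` arguments, and the positional call shape (mirroring `torch.nn.attention.flex_attention.create_block_mask`) are illustrative assumptions, not part of this change:

```python
# Minimal sketch, assuming torchtitan.models.attention exposes these names
# as shown in the diff above; sizes and call shape are illustrative only.
from torchtitan.models.attention import create_attention_mask, get_causal_mask_mod

# With _causal_mask now defined inside get_causal_mask_mod, each call
# returns a fresh closure.
mask_mod = get_causal_mask_mod()

# With the lru_cache removed, every call below builds a new BlockMask,
# so callers should reuse the result themselves where appropriate.
seq_len = 1024  # hypothetical sequence length
block_mask = create_attention_mask(
    mask_mod,
    None,  # B: broadcast over batch
    None,  # H: broadcast over heads
    seq_len,
    seq_len,
)
```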