1 change: 1 addition & 0 deletions docs/source/en/attention_interface.md
@@ -25,6 +25,7 @@ The [`AttentionInterface`] provides optimized attention implementations. It deco
| `"flash_attention_2"` | tiles computations into smaller blocks and uses fast on-chip memory |
| `"flex_attention"` | framework for specifying custom attention patterns (sparse, block-local, sliding window) without writing low-level kernels by hand |
| `"sdpa"` | built-in PyTorch implementation of [scaled dot product attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) |
| <code>"paged&#124;flash_attention_3"</code> | Paged version of FlashAttention-3 |
| <code>"paged&#124;flash_attention_2"</code> | Paged version of FlashAttention-2 |
| <code>"paged&#124;sdpa"</code> | Paged version of SDPA |
| <code>"paged&#124;eager"</code> | Paged version of eager |
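Not part of the diff: a minimal usage sketch of the backends listed above, assuming a Llama-style checkpoint (the model id is illustrative). The backend string is passed via `attn_implementation`; the `"paged|…"` variants target the continuous-batching generation path.

```python
from transformers import AutoModelForCausalLM

# Request one of the backends from the table above; FlashAttention-3 additionally
# needs the flash-attn Hopper build installed, otherwise loading raises.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # illustrative checkpoint
    attn_implementation="flash_attention_3",
)
```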
3 changes: 2 additions & 1 deletion src/transformers/generation/continuous_batching/cache.py
@@ -17,6 +17,7 @@

from ...configuration_utils import PreTrainedConfig
from ...generation.configuration_utils import GenerationConfig
+ from ...utils.generic import is_flash_attention_requested
from ...utils.metrics import attach_tracer, traced
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import RequestState, get_device_and_memory_breakdown, logger
@@ -172,7 +173,7 @@ def __init__(
# Infer number of blocks and max batch tokens
page_size = self.head_dim * self.num_key_value_heads

if "flash" in self.config._attn_implementation:
if is_flash_attention_requested(self.config):
num_attention_masks = 0 # only used to compute the default memory footprint args
elif "sliding_attention" in group_types:
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
9 changes: 5 additions & 4 deletions src/transformers/generation/utils.py
@@ -53,6 +53,7 @@
is_torchdynamo_exporting,
logging,
)
+ from ..utils.generic import is_flash_attention_requested
from .candidate_generator import (
AssistantVocabTranslatorCache,
AssistedCandidateGenerator,
@@ -2172,13 +2173,13 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict[str, Any], generation_
# Finally: if we can compile, disable tokenizers parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "0"

- # If we use FA2 and a static cache, we cannot compile with fullgraph
- if self.config._attn_implementation == "flash_attention_2":
+ # If we use FA and a static cache, we cannot compile with fullgraph
+ if is_flash_attention_requested(self.config):
# only raise warning if the user passed an explicit compile-config
if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
logger.warning_once(
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
"When using Flash Attention and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA introduces graph breaks. We overrode the option with `fullgraph=False`."
)
generation_config.compile_config.fullgraph = False

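Not part of the diff: a hedged illustration of the guard above — when a flash backend is active together with a static cache, an explicit `CompileConfig(fullgraph=True)` is downgraded to `fullgraph=False` and the warning is emitted. Model id, prompt, and dtype handling are illustrative assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig, GenerationConfig

model_id = "meta-llama/Llama-3.2-1B"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
    dtype=torch.bfloat16,                     # older releases use torch_dtype=
    device_map="cuda",
)

generation_config = GenerationConfig(
    cache_implementation="static",                 # static cache triggers auto-compile
    compile_config=CompileConfig(fullgraph=True),  # overridden to False, per the warning above
    max_new_tokens=16,
)
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, generation_config=generation_config)
```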
2 changes: 1 addition & 1 deletion src/transformers/integrations/flash_attention.py
@@ -42,7 +42,7 @@ def flash_attention_forward(
) -> tuple[torch.Tensor, None]:
if kwargs.get("output_attentions", False):
logger.warning_once(
"`flash_attention_2` does not support `output_attentions=True`."
"Flash Attention does not support `output_attentions=True`."
" Please set your attention to `eager` if you want any of these features."
)

8 changes: 4 additions & 4 deletions src/transformers/masking_utils.py
@@ -20,7 +20,7 @@
from .cache_utils import Cache
from .configuration_utils import PreTrainedConfig
from .utils import is_torch_xpu_available, logging
- from .utils.generic import GeneralInterface
+ from .utils.generic import GeneralInterface, is_flash_attention_requested
from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_tracing


@@ -1112,10 +1112,10 @@ def create_chunked_causal_mask(
if chunk_size is None:
raise ValueError("Could not find an `attention_chunk_size` argument in the config, or it is not set")

- # Raise if using chunked attention on context too large with FA2
- if config._attn_implementation == "flash_attention_2" and kv_length + kv_offset > chunk_size:
+ # Raise if using chunked attention on context too large with FA
+ if is_flash_attention_requested(config) and kv_length + kv_offset > chunk_size:
raise ValueError(
"Flash attention 2 cannot handle chunked attention, and the key-value length is larger than the chunk size so the "
"Flash attention cannot handle chunked attention, and the key-value length is larger than the chunk size so the "
"chunked pattern cannot be respected. You should use another `attn_implementation` when instantiating the model"
)

4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
@@ -117,7 +117,7 @@
is_torch_xpu_available,
logging,
)
- from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
+ from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder, is_flash_attention_requested
from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
is_huggingface_hub_greater_or_equal,
@@ -1842,7 +1842,7 @@ def _check_and_adjust_attn_implementation(
)

# preload flash attention here to allow compile with fullgraph
if "flash" in applicable_attn_implementation:
if is_flash_attention_requested(requested_attention_implementation=applicable_attn_implementation):
lazy_import_flash_attention(applicable_attn_implementation)

return applicable_attn_implementation
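Not part of the diff: the helper introduced by this PR lives in `src/transformers/utils/generic.py` and its body is not shown in this section. Judging only from the call sites above — it is passed either a config or an explicit implementation string, and it replaces plain `"flash" in …` / `== "flash_attention_2"` checks — a hypothetical sketch of its behaviour might look like this.

```python
# Hypothetical reconstruction, inferred from the call sites in this diff;
# the actual implementation in utils/generic.py may differ.
def is_flash_attention_requested(config=None, requested_attention_implementation=None) -> bool:
    impl = requested_attention_implementation
    if impl is None and config is not None:
        # `_attn_implementation` holds strings such as "flash_attention_2",
        # "flash_attention_3" or "paged|flash_attention_2".
        impl = getattr(config, "_attn_implementation", None)
    return impl is not None and "flash" in impl
```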
2 changes: 1 addition & 1 deletion src/transformers/models/afmoe/modeling_afmoe.py
@@ -521,7 +521,7 @@ class AfmoePreTrainedModel(PreTrainedModel):
"expert_bias",
]
_supports_sdpa = True
- _supports_flash_attn_2 = True
+ _supports_flash_attn = True
_supports_flex_attn = True
_supports_attention_backend = True
supports_gradient_checkpointing = True
2 changes: 1 addition & 1 deletion src/transformers/models/afmoe/modular_afmoe.py
@@ -342,7 +342,7 @@ class AfmoePreTrainedModel(PreTrainedModel):
"expert_bias",
]
_supports_sdpa = True
- _supports_flash_attn_2 = True
+ _supports_flash_attn = True
_supports_flex_attn = True
_supports_attention_backend = True
supports_gradient_checkpointing = True
3 changes: 2 additions & 1 deletion src/transformers/models/altclip/modeling_altclip.py
@@ -33,6 +33,7 @@
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int
+ from ...utils.generic import is_flash_attention_requested
from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig


@@ -495,7 +496,7 @@ def forward(
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
- if self.config._attn_implementation != "flash_attention_2":
+ if not is_flash_attention_requested(self.config):
if attention_mask is not None and causal_attention_mask is not None:
attention_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
2 changes: 1 addition & 1 deletion src/transformers/models/auto/auto_factory.py
@@ -68,7 +68,7 @@

List options
attn_implementation (`str`, *optional*):
- The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
+ The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.

Examples:

3 changes: 2 additions & 1 deletion src/transformers/models/autoformer/modeling_autoformer.py
@@ -34,6 +34,7 @@
from ...modeling_utils import PreTrainedModel
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+ from ...utils.generic import is_flash_attention_requested
from .configuration_autoformer import AutoformerConfig


@@ -850,7 +851,7 @@ def _update_full_mask(
inputs_embeds: torch.Tensor,
):
if attention_mask is not None:
if "flash" in self.config._attn_implementation:
if is_flash_attention_requested(self.config):
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa":
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
4 changes: 2 additions & 2 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -42,7 +42,7 @@
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import maybe_autocast
+ from ...utils.generic import is_flash_attention_requested, maybe_autocast
from .configuration_bamba import BambaConfig


@@ -1259,7 +1259,7 @@ def _update_causal_mask(
past_key_values: HybridMambaAttentionDynamicCache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if is_flash_attention_requested(self.config):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
3 changes: 2 additions & 1 deletion src/transformers/models/bamba/modular_bamba.py
@@ -52,6 +52,7 @@
can_return_tuple,
logging,
)
+ from ...utils.generic import is_flash_attention_requested
from .configuration_bamba import BambaConfig


@@ -933,7 +934,7 @@ def _update_causal_mask(
past_key_values: HybridMambaAttentionDynamicCache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if is_flash_attention_requested(self.config):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
5 changes: 3 additions & 2 deletions src/transformers/models/bark/modeling_bark.py
@@ -39,6 +39,7 @@
is_torch_accelerator_available,
logging,
)
+ from ...utils.generic import is_flash_attention_requested
from ..auto import AutoModel
from .configuration_bark import (
BarkCoarseConfig,
@@ -497,7 +498,7 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if self.config._attn_implementation == "flash_attention_2":
+ if is_flash_attention_requested(self.config):
attention_mask = attention_mask if 0 in attention_mask else None
else:
attention_mask = attention_mask.view(batch_size, -1)
@@ -1095,7 +1096,7 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if self.config._attn_implementation == "flash_attention_2":
+ if is_flash_attention_requested(self.config):
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
3 changes: 2 additions & 1 deletion src/transformers/models/bloom/modeling_bloom.py
@@ -38,6 +38,7 @@
is_torch_flex_attn_available,
logging,
)
+ from ...utils.generic import is_flash_attention_requested
from .configuration_bloom import BloomConfig


@@ -568,7 +569,7 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if is_flash_attention_requested(self.config):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
3 changes: 2 additions & 1 deletion src/transformers/models/clipseg/modeling_clipseg.py
@@ -29,6 +29,7 @@
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int
+ from ...utils.generic import is_flash_attention_requested
from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig


@@ -322,7 +323,7 @@ def forward(
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
- if self.config._attn_implementation != "flash_attention_2":
+ if not is_flash_attention_requested(self.config):
if attention_mask is not None and causal_attention_mask is not None:
attention_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
3 changes: 2 additions & 1 deletion src/transformers/models/codegen/modeling_codegen.py
@@ -32,6 +32,7 @@
is_torch_flex_attn_available,
logging,
)
+ from ...utils.generic import is_flash_attention_requested
from .configuration_codegen import CodeGenConfig


@@ -428,7 +429,7 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if is_flash_attention_requested(self.config):
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
6 changes: 3 additions & 3 deletions src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
@@ -37,7 +37,7 @@
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
- from ...utils.generic import check_model_inputs, maybe_autocast
+ from ...utils.generic import check_model_inputs, is_flash_attention_requested, maybe_autocast
from .configuration_deepseek_v2 import DeepseekV2Config


@@ -372,7 +372,7 @@ def forward(
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

attention_interface: Callable = eager_attention_forward
@@ -390,7 +390,7 @@ def forward(
**kwargs,
)

- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim]

attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
6 changes: 3 additions & 3 deletions src/transformers/models/deepseek_v2/modular_deepseek_v2.py
@@ -23,7 +23,7 @@
from ...modeling_rope_utils import RopeParameters, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import is_grouped_mm_available, logging
- from ...utils.generic import maybe_autocast
+ from ...utils.generic import is_flash_attention_requested, maybe_autocast
from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import (
LlamaDecoderLayer,
@@ -397,7 +397,7 @@ def forward(
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

attention_interface: Callable = eager_attention_forward
@@ -415,7 +415,7 @@ def forward(
**kwargs,
)

- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim]

attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
6 changes: 3 additions & 3 deletions src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -29,7 +29,7 @@
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
- from ...utils.generic import check_model_inputs, maybe_autocast
+ from ...utils.generic import check_model_inputs, is_flash_attention_requested, maybe_autocast
from .configuration_deepseek_v3 import DeepseekV3Config


@@ -456,7 +456,7 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

attention_interface: Callable = eager_attention_forward
@@ -474,7 +474,7 @@ def forward(
**kwargs,
)

- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim]

attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
5 changes: 3 additions & 2 deletions src/transformers/models/deepseek_v3/modular_deepseek_v3.py
@@ -12,6 +12,7 @@
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import is_grouped_mm_available, logging
+ from ...utils.generic import is_flash_attention_requested
from ..llama.modeling_llama import (
LlamaDecoderLayer,
LlamaForCausalLM,
Expand Down Expand Up @@ -260,7 +261,7 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

attention_interface: Callable = eager_attention_forward
@@ -278,7 +279,7 @@ def forward(
**kwargs,
)

- if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
+ if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim]

attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()