@@ -121,7 +121,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
-        self.is_causal = True
+        self.is_causal = not is_cross_attention
 
     def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
@@ -234,13 +234,6 @@ def forward(
         if is_cross_attention:
             past_key_values.is_updated[self.layer_idx] = True
 
-        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
-
-        # For flash attention backends, we must keep causal behavior even when a 2D padding mask is provided
-        # (flash attention uses `is_causal` for the triangular mask and expects an optional 2D key padding mask)
-        if self.config._attn_implementation in {"flash_attention_2", "flash_attention_3"} and not is_cross_attention:
-            is_causal = query_states.shape[-2] > 1
-
         using_eager = self.config._attn_implementation == "eager"
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -258,7 +251,7 @@ def forward(
             value_states,
             attention_mask,
             dropout=self.attn_dropout.p if self.training else 0.0,
-            is_causal=is_causal,
+            is_causal=self.is_causal,
             **kwargs,
         )
 
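As background on the change above, here is a minimal, self-contained sketch (not part of the diff; it assumes PyTorch's `torch.nn.functional.scaled_dot_product_attention` and uses illustrative shapes and names) of why causal masking has to be off for cross-attention: the causal mask is lower-triangular over (query position, key position), so applying it against encoder keys would hide most of the encoder sequence from the decoder queries, which is what `self.is_causal = not is_cross_attention` avoids.

```python
# Illustrative sketch only; shapes and variable names are hypothetical, not from the diff above.
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 3, 5, 8
query = torch.randn(batch, heads, q_len, head_dim)   # decoder-side queries
key = torch.randn(batch, heads, kv_len, head_dim)    # encoder-side keys
value = torch.randn(batch, heads, kv_len, head_dim)  # encoder-side values

# Self-attention: a lower-triangular (causal) mask is correct, since each position
# should only attend to itself and earlier positions.
self_attn = F.scaled_dot_product_attention(query, query, query, is_causal=True)

# Cross-attention: every decoder query may look at every encoder key, so the causal
# mask must be disabled, mirroring `self.is_causal = not is_cross_attention`.
cross_attn = F.scaled_dot_product_attention(query, key, value, is_causal=False)

print(self_attn.shape, cross_attn.shape)  # torch.Size([1, 2, 3, 8]) for both
```

Setting the flag once in `__init__` also makes the removed per-forward recomputation of `is_causal` unnecessary, since the attention backend can combine the stored flag with the optional 2D padding mask on its own.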