Commit 2edb434

define is_causal in init
1 parent 72edb13 · commit 2edb434

File tree

2 files changed: +4, -18 lines

src/transformers/models/decision_transformer/modeling_decision_transformer.py

Lines changed: 2 additions & 9 deletions
@@ -121,7 +121,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):

         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
-        self.is_causal = True
+        self.is_causal = not is_cross_attention

     def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
@@ -234,13 +234,6 @@ def forward(
             if is_cross_attention:
                 past_key_values.is_updated[self.layer_idx] = True

-        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
-
-        # For flash attention backends, we must keep causal behavior even when a 2D padding mask is provided
-        # (flash attention uses `is_causal` for the triangular mask and expects an optional 2D key padding mask)
-        if self.config._attn_implementation in {"flash_attention_2", "flash_attention_3"} and not is_cross_attention:
-            is_causal = query_states.shape[-2] > 1
-
         using_eager = self.config._attn_implementation == "eager"
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -258,7 +251,7 @@ def forward(
             value_states,
             attention_mask,
             dropout=self.attn_dropout.p if self.training else 0.0,
-            is_causal=is_causal,
+            is_causal=self.is_causal,
             **kwargs,
         )

src/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 2 additions & 9 deletions
@@ -130,7 +130,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):

         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
-        self.is_causal = True
+        self.is_causal = not is_cross_attention

     def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
@@ -243,13 +243,6 @@ def forward(
             if is_cross_attention:
                 past_key_values.is_updated[self.layer_idx] = True

-        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
-
-        # For flash attention backends, we must keep causal behavior even when a 2D padding mask is provided
-        # (flash attention uses `is_causal` for the triangular mask and expects an optional 2D key padding mask)
-        if self.config._attn_implementation in {"flash_attention_2", "flash_attention_3"} and not is_cross_attention:
-            is_causal = query_states.shape[-2] > 1
-
         using_eager = self.config._attn_implementation == "eager"
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -267,7 +260,7 @@ def forward(
             value_states,
             attention_mask,
             dropout=self.attn_dropout.p if self.training else 0.0,
-            is_causal=is_causal,
+            is_causal=self.is_causal,
             **kwargs,
         )
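For context, a minimal sketch of the pattern both files converge on after this change: causality is decided once in the constructor (self.is_causal = not is_cross_attention) and handed straight to the attention call, rather than being recomputed on every forward pass. The ToyAttention class, the simplified eager attention math, and the SimpleNamespace config below are illustrative assumptions, not the actual Transformers modules or their attention_interface plumbing.

import torch
import torch.nn as nn
from types import SimpleNamespace


class ToyAttention(nn.Module):
    """Illustrative stand-in for the GPT-2 / Decision Transformer attention modules."""

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # After this commit: causality is fixed at construction time.
        # Self-attention layers are causal, cross-attention layers are not.
        self.is_causal = not is_cross_attention

    def forward(self, query_states, key_states, value_states, attention_mask=None):
        # No per-call recomputation of is_causal; the constructor flag is used directly.
        scale = query_states.size(-1) ** -0.5
        attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
        if self.is_causal:
            q_len, k_len = query_states.size(-2), key_states.size(-2)
            # Lower-triangular mask, offset so earlier cached keys (k_len > q_len) stay visible.
            causal = torch.ones(q_len, k_len, dtype=torch.bool, device=attn_weights.device)
            causal = causal.tril(diagonal=k_len - q_len)
            attn_weights = attn_weights.masked_fill(~causal, torch.finfo(attn_weights.dtype).min)
        if attention_mask is not None:  # additive mask, e.g. for padding
            attn_weights = attn_weights + attention_mask
        attn_weights = self.attn_dropout(attn_weights.softmax(dim=-1))
        # resid_dropout is applied here for brevity; the real modules apply it after the output projection.
        return self.resid_dropout(torch.matmul(attn_weights, value_states))


cfg = SimpleNamespace(attn_pdrop=0.1, resid_pdrop=0.1)
self_attn = ToyAttention(cfg)                              # self.is_causal == True
cross_attn = ToyAttention(cfg, is_cross_attention=True)    # self.is_causal == False
q = k = v = torch.randn(1, 2, 5, 8)                        # (batch, heads, seq_len, head_dim)
out = self_attn(q, k, v)

Net effect per the diff: the forward-time heuristics (mask presence, query length, flash-attention special cases) are dropped, and the attention backends receive a flag that depends only on whether the layer is cross-attention.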
