
Commit 40dc11c

Fix Gemma (huggingface#42847)
fix
1 parent c247063 commit 40dc11c

2 files changed: 18 additions & 22 deletions


src/transformers/models/gemma/modeling_gemma.py

Lines changed: 9 additions & 11 deletions
@@ -410,16 +410,14 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        # It may already have been prepared by e.g. `generate`
-        if not isinstance(causal_mask_mapping := attention_mask, dict):
-            causal_mask_mapping = create_causal_mask(
-                config=self.config,
-                input_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                cache_position=cache_position,
-                past_key_values=past_key_values,
-                position_ids=position_ids,
-            )
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -434,7 +432,7 @@ def forward(
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping,
+                attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
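
For context: the branch removed above let a caller hand in a ready-made mapping of masks (the deleted comment points at `generate` as one such preparer). The walrus-assignment isinstance check passed any dict through untouched and only built a mask otherwise; after this commit, Gemma always rebuilds a single causal mask from the current inputs. A minimal, self-contained sketch of the two control flows, where `create_causal_mask_stub` and the `"full_attention"` key are hypothetical stand-ins for the real builder and mask mapping:

    # Sketch only: create_causal_mask_stub stands in for the real
    # create_causal_mask called by the Gemma model (keyword args as in the
    # diff above); "full_attention" is an illustrative mapping key.
    def create_causal_mask_stub(**kwargs):
        return "freshly built causal mask"

    def old_prepare(attention_mask):
        # Removed behaviour: a dict prepared upstream was passed through as-is;
        # a mask was only built when the input was not already a mapping.
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            causal_mask_mapping = create_causal_mask_stub(attention_mask=attention_mask)
        return causal_mask_mapping

    def new_prepare(attention_mask):
        # Current behaviour: always rebuild one mask from the inputs.
        return create_causal_mask_stub(attention_mask=attention_mask)

    print(old_prepare({"full_attention": "precomputed"}))  # dict passed through untouched
    print(new_prepare({"full_attention": "precomputed"}))  # freshly built causal mask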

src/transformers/models/gemma/modular_gemma.py

Lines changed: 9 additions & 11 deletions
@@ -267,16 +267,14 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        # It may already have been prepared by e.g. `generate`
-        if not isinstance(causal_mask_mapping := attention_mask, dict):
-            causal_mask_mapping = create_causal_mask(
-                config=self.config,
-                input_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                cache_position=cache_position,
-                past_key_values=past_key_values,
-                position_ids=position_ids,
-            )
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -291,7 +289,7 @@ def forward(
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping,
+                attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
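
Both files carry identical hunks because modeling_gemma.py is auto-generated from modular_gemma.py under transformers' modular-model setup, so the two must stay in sync for the repo's consistency checks. Assuming the converter entry point described in the repo's modular-model docs (the exact script path and flag are assumptions worth verifying locally), regeneration is typically `python utils/modular_model_converter.py --files_to_parse src/transformers/models/gemma/modular_gemma.py`.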
