diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 1eae97ff1..562755ba2 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -774,9 +774,12 @@ def LlamaModel_fast_forward(
             self.SWA_mask = True
             self.GA_mask = False
         elif attention_mask is not None:
-            # Fixes https://github.com/unslothai/unsloth/issues/853
             # Unsloth needs a 2D mask, not a [2, 1, n, n] mask!
+
+            # https://github.com/pytorch/pytorch/issues/103749
+            # Need to convert to float and not using bool
+            attention_mask = (1.0 - attention_mask.float()) * torch.finfo(inputs_embeds.dtype).min
             dynamic_SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                 attention_mask,
                 (batch_size, seq_length),
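
For context on the added line: as I understand the linked PyTorch issue, passing a bool mask straight into SDPA can misbehave, so the patch first turns the 2D `[batch, seq_len]` padding mask (1 = keep, 0 = pad) into an additive float mask where padded positions get `torch.finfo(dtype).min`. The snippet below is a minimal standalone sketch of that transform, not Unsloth code; the tensor shapes, `additive_mask` name, and the use of `float32` in place of `inputs_embeds.dtype` are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

batch_size, seq_len = 2, 4
dtype = torch.float32  # stands in for inputs_embeds.dtype (fp16/bf16 in training)

# 2D padding mask: 1 marks real tokens, 0 marks padded positions.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# Same transform as the added line: kept positions become 0.0,
# padded positions become the most negative value for the dtype.
additive_mask = (1.0 - attention_mask.float()) * torch.finfo(dtype).min

# Broadcast to [batch, 1, 1, kv_len] so SDPA can add it to the scores.
q = k = v = torch.randn(batch_size, 1, seq_len, 8, dtype=dtype)
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=additive_mask[:, None, None, :].to(dtype),
)
print(out.shape)  # torch.Size([2, 1, 4, 8])
```

After softmax, the `finfo(dtype).min` entries underflow to zero attention weight, which gives the same effect as a bool mask while staying on the float path; the patched code then hands this float mask to `_prepare_4d_causal_attention_mask_for_sdpa` as before.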