From cccdca3e6e3b8400f0e27e8e7787a51052741353 Mon Sep 17 00:00:00 2001
From: Erland366
Date: Thu, 20 Feb 2025 00:18:46 +0400
Subject: [PATCH] Convert mask to float

---
 unsloth/models/llama.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 1eae97ff1..562755ba2 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -774,9 +774,12 @@ def LlamaModel_fast_forward(
             self.SWA_mask = True
             self.GA_mask  = False
         elif attention_mask is not None:
-            # Fixes https://github.com/unslothai/unsloth/issues/853
             # Unsloth needs a 2D mask, not a [2, 1, n, n] mask!
+
+            # https://github.com/pytorch/pytorch/issues/103749
+            # Need to convert to float and not using bool
+            attention_mask = (1.0 - attention_mask.float()) * torch.finfo(inputs_embeds.dtype).min
             dynamic_SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                 attention_mask,
                 (batch_size, seq_length),
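
For context, here is a minimal standalone sketch (not Unsloth code) of what the added line does: SDPA can mishandle boolean masks (see the pytorch issue referenced in the patch), so the 2D padding mask is converted into an additive float mask whose padded positions hold a large negative value before it is handed to attention. The tensor shapes, the names q, k, v, dim, and the float32 dtype below are illustrative assumptions, not values taken from the patch.

import torch
import torch.nn.functional as F

batch_size, seq_length, dim = 2, 5, 8
dtype = torch.float32  # stands in for inputs_embeds.dtype in the patch

# 2D padding mask (1 = keep, 0 = pad): last two tokens of the second sequence are padding
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# Same conversion as the patch: 1 -> 0.0 (keep), 0 -> large negative (mask out)
float_mask = (1.0 - attention_mask.float()) * torch.finfo(dtype).min

# Broadcast to [batch, 1, 1, seq] so the mask applies to every head and query position
additive_mask = float_mask[:, None, None, :].to(dtype)

q = torch.randn(batch_size, 1, seq_length, dim, dtype=dtype)
k = torch.randn(batch_size, 1, seq_length, dim, dtype=dtype)
v = torch.randn(batch_size, 1, seq_length, dim, dtype=dtype)

# An additive float mask is added to the attention scores before softmax,
# driving masked positions toward zero attention weight
out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
print(out.shape)  # torch.Size([2, 1, 5, 8])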