diff --git a/tests/utils/test_attention_masks.py b/tests/utils/test_attention_masks.py
new file mode 100644
index 000000000..dbd541961
--- /dev/null
+++ b/tests/utils/test_attention_masks.py
@@ -0,0 +1,272 @@
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+"""Unit tests for packed-attention mask helpers with sliding-window logic."""
+
+import math
+
+import torch
+
+from unsloth.utils import attention_dispatch
+from unsloth.utils import packing as packing_utils
+
+
+def _make_seq_info(lengths):
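+ # Helper: build packed-sequence metadata in the cu_seqlens style, i.e. the per-sequence
+ # lengths, the cumulative offsets [0, l0, l0 + l1, ...] and the maximum sequence length.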
+ lengths = torch.tensor(lengths, dtype = torch.int32)
+ cu = torch.cat(
+ [
+ torch.zeros(1, dtype = torch.int32),
+ torch.cumsum(lengths, dim = 0, dtype = torch.int32),
+ ]
+ )
+ max_len = int(lengths.max().item())
+ return lengths, cu, max_len
+
+
+def test_sdpa_packed_attention_mask_sliding_window():
+ seq_info = _make_seq_info([5, 3])
+ mask = packing_utils.build_sdpa_packed_attention_mask(
+ seq_info,
+ dtype = torch.float32,
+ device = torch.device("cpu"),
+ sliding_window = 3,
+ )
+
+ assert mask.shape == (1, 1, 8, 8)
+
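+ # Expected semantics checked below: with sliding_window = 3, query i only attends to
+ # keys in [i - 2, i] inside its own packed sequence, and never to positions belonging
+ # to the other packed sequence (e.g. query 0 vs key 6).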
+ block_first = mask[0, 0, :5, :5]
+ upper = torch.triu(torch.ones_like(block_first), diagonal = 1).bool()
+ assert torch.all(block_first[upper] == float("-inf"))
+ assert block_first[3, 0].item() == float("-inf")
+ assert block_first[4, 1].item() == float("-inf")
+ assert block_first[4, 2].item() > -math.inf
+ assert mask[0, 0, 0, 6].item() == float("-inf")
+
+
+def test_xformers_block_mask_sliding_window(monkeypatch):
+ class _FakeMask:
+ def __init__(self, lengths, window = None):
+ self.lengths = lengths
+ self.window = window
+
+ @classmethod
+ def from_seqlens(cls, lengths):
+ return cls(tuple(lengths))
+
+ def make_local_attention(self, window_size):
+ return _FakeMask(self.lengths, window = window_size)
+
+ monkeypatch.setattr(packing_utils, "_XFormersBlockMask", _FakeMask, raising = False)
+
+ seq_info = _make_seq_info([4, 4])
+ mask = packing_utils.build_xformers_block_causal_mask(
+ seq_info,
+ sliding_window = 2,
+ )
+
+ assert isinstance(mask, _FakeMask)
+ assert mask.window == 2
+
+
+def test_run_attention_sdpa_passes_sliding_window(monkeypatch):
+ seq_info = _make_seq_info([3, 2])
+ sliding_window = 2
+
+ original_builder = attention_dispatch.build_sdpa_packed_attention_mask
+ captured = {}
+
+ def _capture_builder(seq_info_arg, *, dtype, device, sliding_window = None):
+ captured["window"] = sliding_window
+ return original_builder(
+ seq_info_arg,
+ dtype = dtype,
+ device = device,
+ sliding_window = sliding_window,
+ )
+
+ monkeypatch.setattr(
+ attention_dispatch,
+ "build_sdpa_packed_attention_mask",
+ _capture_builder,
+ )
+
+ def _fake_sdpa(Q, K, V, **kwargs):
+ captured["mask"] = kwargs.get("attn_mask")
+ return torch.zeros_like(Q)
+
+ monkeypatch.setattr(attention_dispatch, "scaled_dot_product_attention", _fake_sdpa)
+
+ config = attention_dispatch.AttentionConfig(
+ backend = attention_dispatch.SDPA,
+ n_kv_heads = 1,
+ n_groups = 1,
+ )
+
+ context = attention_dispatch.AttentionContext(
+ bsz = 1,
+ q_len = 5,
+ kv_seq_len = 5,
+ n_heads = 1,
+ head_dim = 1,
+ requires_grad = False,
+ seq_info = seq_info,
+ attention_mask = None,
+ causal_mask = None,
+ sliding_window = sliding_window,
+ )
+
+ Q = torch.zeros(1, 1, 5, 1)
+ K = torch.zeros_like(Q)
+ V = torch.zeros_like(Q)
+
+ attention_dispatch.run_attention(
+ config = config,
+ context = context,
+ Q = Q,
+ K = K,
+ V = V,
+ )
+
+ assert captured["window"] == sliding_window
+ mask = captured["mask"]
+ assert mask is not None and mask.shape == (1, 1, 5, 5)
+ assert mask[0, 0, 4, 1].item() == float("-inf")
+
+
+def test_run_attention_xformers_passes_sliding_window(monkeypatch):
+ seq_info = _make_seq_info([4])
+ sliding_window = 3
+
+ class _FakeBias:
+ pass
+
+ captured = {}
+
+ def _fake_builder(seq_info_arg, *, sliding_window = None, base_mask = None):
+ captured["window"] = sliding_window
+ captured["base"] = base_mask
+ return _FakeBias()
+
+ def _fake_attention(Q, K, V, attn_bias = None, **_):
+ captured["bias"] = attn_bias
+ return torch.zeros_like(Q)
+
+ monkeypatch.setattr(
+ attention_dispatch, "build_xformers_block_causal_mask", _fake_builder
+ )
+ monkeypatch.setattr(
+ attention_dispatch, "xformers_attention", _fake_attention, raising = False
+ )
+ monkeypatch.setattr(
+ attention_dispatch, "XFORMERS_BLOCK_DIAG_CLS", _FakeBias, raising = False
+ )
+
+ config = attention_dispatch.AttentionConfig(
+ backend = attention_dispatch.XFORMERS,
+ n_kv_heads = 1,
+ n_groups = 1,
+ )
+
+ context = attention_dispatch.AttentionContext(
+ bsz = 1,
+ q_len = 4,
+ kv_seq_len = 4,
+ n_heads = 1,
+ head_dim = 1,
+ requires_grad = False,
+ seq_info = seq_info,
+ attention_mask = None,
+ causal_mask = None,
+ sliding_window = sliding_window,
+ )
+
+ Q = torch.zeros(1, 1, 4, 1)
+ K = torch.zeros_like(Q)
+ V = torch.zeros_like(Q)
+
+ attention_dispatch.run_attention(
+ config = config,
+ context = context,
+ Q = Q,
+ K = K,
+ V = V,
+ )
+
+ assert captured["window"] == sliding_window
+ assert isinstance(captured["bias"], _FakeBias)
+
+
+def test_run_attention_flash_varlen_receives_window_and_softcap(monkeypatch):
+ seq_info = _make_seq_info([4])
+ sliding_window = 3
+ softcap = 0.5
+ window_tuple = (sliding_window, sliding_window)
+
+ captured = {}
+
+ def _fake_flash_varlen(Q, K, V, cu_q, cu_k, max_q, max_k, **kwargs):
+ captured["kwargs"] = kwargs
+ return torch.zeros_like(Q)
+
+ monkeypatch.setattr(
+ attention_dispatch,
+ "flash_attn_varlen_func",
+ _fake_flash_varlen,
+ )
+ monkeypatch.setattr(attention_dispatch, "HAS_FLASH_ATTENTION", True)
+
+ config = attention_dispatch.AttentionConfig(
+ backend = attention_dispatch.FLASH_VARLEN,
+ n_kv_heads = 1,
+ n_groups = 1,
+ flash_varlen_kwargs = {
+ "dropout_p": 0.0,
+ "softmax_scale": 1.0,
+ "causal": True,
+ "softcap": softcap,
+ "window_size": window_tuple,
+ },
+ )
+
+ context = attention_dispatch.AttentionContext(
+ bsz = 1,
+ q_len = 4,
+ kv_seq_len = 4,
+ n_heads = 1,
+ head_dim = 2,
+ requires_grad = False,
+ seq_info = seq_info,
+ attention_mask = None,
+ causal_mask = None,
+ sliding_window = sliding_window,
+ )
+
+ Q = torch.zeros(1, 1, 4, 2)
+ K = torch.zeros_like(Q)
+ V = torch.zeros_like(Q)
+
+ attention_dispatch.run_attention(
+ config = config,
+ context = context,
+ Q = Q,
+ K = K,
+ V = V,
+ )
+
+ assert captured["kwargs"]["softcap"] == softcap
+ assert captured["kwargs"]["window_size"] == window_tuple
+
diff --git a/tests/utils/test_packing.py b/tests/utils/test_packing.py
new file mode 100644
index 000000000..e9c3264ac
--- /dev/null
+++ b/tests/utils/test_packing.py
@@ -0,0 +1,341 @@
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from unsloth import FastLanguageModel
+from unsloth.utils import attention_dispatch as attention_dispatch_utils
+from unsloth.utils.packing import (
+ configure_sample_packing,
+ enable_sample_packing,
+ mask_packed_sequence_boundaries,
+)
+
+from collections.abc import Iterable
+from contextlib import ExitStack
+from types import SimpleNamespace
+from unittest.mock import patch
+
+import pytest
+import torch
+from datasets import Dataset
+from trl import SFTConfig, SFTTrainer
+from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
+
+
+def _build_packed_training_setup(tmp_path, device):
+ dtype = None
+ if device.type == "cuda":
+ if torch.cuda.is_bf16_supported():
+ dtype = torch.bfloat16
+ else:
+ dtype = torch.float16
+
+ try:
+ model, tokenizer = FastLanguageModel.from_pretrained(
+ model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM",
+ max_seq_length = 64,
+ load_in_4bit = False,
+ dtype = dtype,
+ )
+ except OSError as exc: # pragma: no cover - offline CI
+ pytest.skip(f"Requires access to tiny llama checkpoint: {exc}")
+
+ model.to(device)
+
+ dataset = Dataset.from_dict(
+ {
+ "text": [
+ "Hello world!",
+ "Short sample.",
+ "This is a slightly longer packed example to test batching.",
+ "Another response to include in the batch.",
+ ]
+ }
+ )
+
+ training_args = SFTConfig(
+ per_device_train_batch_size = 1,
+ per_device_eval_batch_size = 1,
+ gradient_accumulation_steps = 1,
+ dataset_text_field = "text",
+ max_length = 64,
+ logging_steps = 1,
+ max_steps = 1,
+ fp16 = device.type == "cuda" and not torch.cuda.is_bf16_supported(),
+ bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported(),
+ dataset_num_proc = 1,
+ output_dir = str(tmp_path),
+ )
+ configure_sample_packing(training_args)
+
+ trainer = SFTTrainer(
+ model = model,
+ processing_class = tokenizer,
+ train_dataset = dataset,
+ args = training_args,
+ )
+
+ enable_sample_packing(model, trainer)
+
+ dataloader = trainer.get_train_dataloader()
+ batch = next(iter(dataloader))
+
+ model_device = next(model.parameters()).device
+
+ for key, value in list(batch.items()):
+ if torch.is_tensor(value):
+ batch[key] = value.to(model_device)
+
+ from unsloth.models import llama as llama_mod
+
+ return model, batch, trainer, llama_mod
+
+
+def _trim_batch_to_total_tokens(data, total_tokens):
+ def _trim_tensor(t: torch.Tensor):
+ if t.ndim >= 2 and t.size(1) > total_tokens:
+ return t[:, :total_tokens].contiguous()
+ return t
+
+ trimmed = {}
+ for key, value in data.items():
+ if torch.is_tensor(value):
+ trimmed[key] = _trim_tensor(value)
+ else:
+ trimmed[key] = value
+ return trimmed
+
+
+def test_mask_packed_sequence_boundaries_marks_single_row():
+ shift_labels = torch.arange(6, dtype = torch.long).view(1, 6)
+ changed = mask_packed_sequence_boundaries(
+ shift_labels,
+ torch.tensor([2, 1, 3], dtype = torch.int32),
+ )
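+ # Lengths [2, 1, 3] place sequence boundaries at flat shift-label indices 1, 2 and 5
+ # (cumsum - 1); those positions must become -100 so the loss never predicts the first
+ # token of the next packed sequence from the last token of the previous one.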
+ assert changed is True
+ flat = shift_labels.view(-1)
+ assert flat[1].item() == -100
+ assert flat[2].item() == -100
+ assert flat[5].item() == -100
+ assert flat[0].item() != -100
+
+
+def test_mask_packed_sequence_boundaries_across_multiple_rows():
+ shift_labels = torch.arange(10, dtype = torch.long).view(2, 5)
+ lengths = torch.tensor([3, 2, 4, 1], dtype = torch.int32)
+ changed = mask_packed_sequence_boundaries(shift_labels, lengths)
+ assert changed is True
+ flat = shift_labels.view(-1)
+ for idx in (2, 4, 8, 9):
+ assert flat[idx].item() == -100
+ assert torch.any(flat != -100)
+
+
+def test_configure_sample_packing():
+ config = SimpleNamespace()
+ configure_sample_packing(config)
+
+ assert config.packing is True
+ assert config.padding_free is True
+ assert config.remove_unused_columns is False
+
+
+class _DummyChild(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.max_seq_length = 8
+
+
+class _DummyModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.max_seq_length = 16
+ self.child = _DummyChild()
+ self.config = SimpleNamespace(_attn_implementation = "sdpa")
+ self.generation_config = SimpleNamespace(attn_implementation = "sdpa")
+
+
+class _DummyTrainer:
+ def __init__(self):
+ self.args = SimpleNamespace(remove_unused_columns = True)
+ self.data_collator = DataCollatorForLanguageModeling(
+ pad_token_id = 0,
+ completion_only_loss = False,
+ padding_free = True,
+ return_position_ids = False,
+ return_tensors = "pt",
+ )
+
+
+def test_enable_sample_packing():
+ model = _DummyModel()
+ trainer = _DummyTrainer()
+
+ enable_sample_packing(model, trainer)
+
+ # model hierarchy should now allow packed overlength inputs
+ assert getattr(model, "_unsloth_allow_packed_overlength") is True
+ assert getattr(model.child, "_unsloth_allow_packed_overlength") is True
+
+ collator = trainer.data_collator
+ assert collator.return_position_ids is True
+ assert getattr(collator, "_unsloth_packing_wrapped") is True
+
+ examples = [
+ {
+ "input_ids": [0, 1, 2],
+ "labels": [0, 1, 2],
+ "seq_lengths": [2, 1],
+ },
+ {
+ "input_ids": [3, 4, 5],
+ "labels": [3, 4, 5],
+ "seq_lengths": [3],
+ },
+ ]
+ batch = collator.torch_call(examples)
+
+ # packed lengths are aggregated into a single tensor
+ assert "packed_seq_lengths" in batch
+ assert torch.equal(
+ batch["packed_seq_lengths"],
+ torch.tensor([2, 1, 3], dtype = torch.int32),
+ )
+
+ assert batch["input_ids"].shape == (1, 6)
+ expected_positions = torch.tensor([0, 1, 0, 0, 1, 2], dtype = torch.long)
+ assert torch.equal(batch["position_ids"].view(-1)[:6], expected_positions)
+
+
+def test_enable_sample_packing_trl_collator(tmp_path):
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model, _, trainer, _ = _build_packed_training_setup(tmp_path, device)
+
+ enable_sample_packing(model, trainer)
+
+ examples = [
+ {
+ "input_ids": [0, 1, 2],
+ "labels": [0, 1, 2],
+ "seq_lengths": [2, 1],
+ },
+ {
+ "input_ids": [3, 4, 5],
+ "labels": [3, 4, 5],
+ "seq_lengths": [3],
+ },
+ ]
+
+ batch = trainer.data_collator.torch_call(examples)
+
+ assert batch["input_ids"].shape == (1, 6)
+ assert torch.equal(
+ batch["packed_seq_lengths"],
+ torch.tensor([2, 1, 3], dtype = torch.int32),
+ )
+
+ expected_positions = torch.tensor([0, 1, 0, 0, 1, 2], dtype = torch.long)
+ assert torch.equal(batch["position_ids"].view(-1)[:6], expected_positions)
+
+ if hasattr(trainer, "accelerator"):
+ trainer.accelerator.free_memory()
+
+
+def test_packing_sdpa(tmp_path):
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model, batch, trainer, llama_mod = _build_packed_training_setup(tmp_path, device)
+
+ assert "packed_seq_lengths" in batch
+ assert "attention_mask" not in batch
+ assert batch["packed_seq_lengths"].dtype == torch.int32
+
+ total_tokens = batch["input_ids"].size(-1)
+ assert int(batch["packed_seq_lengths"].sum().item()) == total_tokens
+
+ packed_tokens = int(batch["packed_seq_lengths"].sum().item())
+ assert "position_ids" in batch
+ flat_positions = batch["position_ids"].reshape(-1)[:packed_tokens]
+ expected_positions = torch.cat(
+ [
+ torch.arange(length, dtype = torch.long)
+ for length in batch["packed_seq_lengths"].tolist()
+ ]
+ )
+ assert torch.equal(flat_positions.cpu(), expected_positions)
+ inputs = _trim_batch_to_total_tokens(batch, packed_tokens)
+
+ seq_info = llama_mod.get_packed_info_from_kwargs(
+ {"packed_seq_lengths": batch["packed_seq_lengths"]},
+ inputs["input_ids"].shape[0] * inputs["input_ids"].shape[1],
+ inputs["input_ids"].device,
+ )
+ assert seq_info is not None
+
+ original_mask = attention_dispatch_utils.build_sdpa_packed_attention_mask
+ mask_calls = []
+ captured_loss_labels = {}
+
+ def _capture_mask(seq_info, dtype, device, *, sliding_window = None):
+ mask_calls.append(tuple(seq_info[0].tolist()))
+ return original_mask(
+ seq_info,
+ dtype = dtype,
+ device = device,
+ sliding_window = sliding_window,
+ )
+
+ def _capture_loss(*, logits, labels, **loss_kwargs):
+ captured_loss_labels["labels"] = labels.detach().to("cpu")
+ return torch.zeros((), device = logits.device, dtype = logits.dtype)
+
+ with ExitStack() as stack:
+ stack.enter_context(
+ patch.object(attention_dispatch_utils, "HAS_FLASH_ATTENTION", False)
+ )
+ stack.enter_context(
+ patch.object(attention_dispatch_utils, "HAS_XFORMERS", False)
+ )
+ stack.enter_context(
+ patch.object(
+ attention_dispatch_utils,
+ "build_sdpa_packed_attention_mask",
+ side_effect = _capture_mask,
+ )
+ )
+ stack.enter_context(
+ patch.object(
+ llama_mod,
+ "fast_cross_entropy_loss",
+ side_effect = _capture_loss,
+ )
+ )
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ assert mask_calls, "SDPA packed mask was not constructed"
+ assert outputs.loss is not None
+ assert "labels" in captured_loss_labels
+ flat_loss_labels = captured_loss_labels["labels"].reshape(-1)
+ boundaries = (
+ torch.cumsum(
+ batch["packed_seq_lengths"].to(device = "cpu", dtype = torch.long), dim = 0
+ )
+ - 1
+ )
+ for idx in boundaries.tolist():
+ assert flat_loss_labels[idx].item() == -100
+ assert torch.any(flat_loss_labels != -100)
+
+ if hasattr(trainer, "accelerator"):
+ trainer.accelerator.free_memory()
diff --git a/unsloth-cli.py b/unsloth-cli.py
index fb6e39266..358e61fae 100644
--- a/unsloth-cli.py
+++ b/unsloth-cli.py
@@ -34,29 +34,30 @@
def run(args):
- import torch
from unsloth import FastLanguageModel
from datasets import load_dataset
from transformers.utils import strtobool
from trl import SFTTrainer, SFTConfig
- from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
+ from unsloth.utils import configure_sample_packing, enable_sample_packing
+ from unsloth.models.loader_utils import prepare_device_map
import logging
logging.getLogger("hf-to-gguf").setLevel(logging.WARNING)
# Load model and tokenizer
+ device_map, distributed = prepare_device_map()
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = args.model_name,
max_seq_length = args.max_seq_length,
dtype = args.dtype,
load_in_4bit = args.load_in_4bit,
+ device_map = device_map,
)
# Configure PEFT model
model = FastLanguageModel.get_peft_model(
model,
- r = args.r,
target_modules = [
"q_proj",
"k_proj",
@@ -112,6 +113,7 @@ def formatting_prompts_func(examples):
# Configure training arguments
training_args = SFTConfig(
per_device_train_batch_size = args.per_device_train_batch_size,
+ per_device_eval_batch_size = args.per_device_eval_batch_size,
gradient_accumulation_steps = args.gradient_accumulation_steps,
warmup_steps = args.warmup_steps,
max_steps = args.max_steps,
@@ -127,9 +129,15 @@ def formatting_prompts_func(examples):
report_to = args.report_to,
max_length = args.max_seq_length,
dataset_num_proc = 2,
- packing = False,
+ ddp_find_unused_parameters = False if distributed else None,
)
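+ # Sample packing: configure_sample_packing flips the packing / padding_free flags on
+ # the SFTConfig, while enable_sample_packing (called after the trainer is built)
+ # patches the model and the TRL data collator for packed batches.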
+ if args.sample_packing:
+ print(
+ f"Unsloth Packing: Sample packing enabled (max_length={args.max_seq_length})."
+ )
+ configure_sample_packing(training_args)
+
# Initialize trainer
trainer = SFTTrainer(
model = model,
@@ -138,8 +146,10 @@ def formatting_prompts_func(examples):
args = training_args,
)
- # Train model
- trainer_stats = trainer.train()
+ if args.sample_packing:
+ enable_sample_packing(model, trainer)
+
+ trainer.train()
# Save model
if args.save_model:
@@ -164,13 +174,15 @@ def formatting_prompts_func(examples):
else:
print(f"Saving model with quantization method: {args.quantization}")
model.save_pretrained_gguf(
- args.save_path, tokenizer, quantization_method = args.quantization
+ args.save_path,
+ tokenizer,
+ quantization_method = args.quantization,
)
if args.push_model:
model.push_to_hub_gguf(
hub_path = args.hub_path,
hub_token = args.hub_token,
- quantization_method = quantization_method,
+ quantization_method = args.quantization,
)
else:
model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
@@ -181,7 +193,6 @@ def formatting_prompts_func(examples):
if __name__ == "__main__":
- # Define argument parser
parser = argparse.ArgumentParser(
description = "🦥 Fine-tune your llm faster using unsloth!"
)
@@ -218,7 +229,8 @@ def formatting_prompts_func(examples):
)
lora_group = parser.add_argument_group(
- "🧠LoRA Options", "These options are used to configure the LoRA model."
+ "🧠LoRA Options",
+ "These options are used to configure the LoRA model.",
)
lora_group.add_argument(
"--r",
@@ -239,7 +251,10 @@ def formatting_prompts_func(examples):
help = "LoRA dropout rate, default is 0.0 which is optimized.",
)
lora_group.add_argument(
- "--bias", type = str, default = "none", help = "Bias setting for LoRA"
+ "--bias",
+ type = str,
+ default = "none",
+ help = "Bias setting for LoRA",
)
lora_group.add_argument(
"--use_gradient_checkpointing",
@@ -254,10 +269,15 @@ def formatting_prompts_func(examples):
help = "Random state for reproducibility, default is 3407.",
)
lora_group.add_argument(
- "--use_rslora", action = "store_true", help = "Use rank stabilized LoRA"
+ "--use_rslora",
+ action = "store_true",
+ help = "Use rank stabilized LoRA",
)
lora_group.add_argument(
- "--loftq_config", type = str, default = None, help = "Configuration for LoftQ"
+ "--loftq_config",
+ type = str,
+ default = None,
+ help = "Configuration for LoftQ",
)
training_group = parser.add_argument_group("🎓 Training Options")
@@ -267,6 +287,12 @@ def formatting_prompts_func(examples):
default = 2,
help = "Batch size per device during training, default is 2.",
)
+ training_group.add_argument(
+ "--per_device_eval_batch_size",
+ type = int,
+ default = 4,
+ help = "Batch size per device during evaluation, default is 4.",
+ )
training_group.add_argument(
"--gradient_accumulation_steps",
type = int,
@@ -280,7 +306,10 @@ def formatting_prompts_func(examples):
help = "Number of warmup steps, default is 5.",
)
training_group.add_argument(
- "--max_steps", type = int, default = 400, help = "Maximum number of training steps."
+ "--max_steps",
+ type = int,
+ default = 400,
+ help = "Maximum number of training steps.",
)
training_group.add_argument(
"--learning_rate",
@@ -289,7 +318,10 @@ def formatting_prompts_func(examples):
help = "Learning rate, default is 2e-4.",
)
training_group.add_argument(
- "--optim", type = str, default = "adamw_8bit", help = "Optimizer type."
+ "--optim",
+ type = str,
+ default = "adamw_8bit",
+ help = "Optimizer type.",
)
training_group.add_argument(
"--weight_decay",
@@ -309,8 +341,12 @@ def formatting_prompts_func(examples):
default = 3407,
help = "Seed for reproducibility, default is 3407.",
)
+ training_group.add_argument(
+ "--sample_packing",
+ action = "store_true",
+ help = "Enable cross-example packing using TRL's dataset bin packing.",
+ )
- # Report/Logging arguments
report_group = parser.add_argument_group("📊 Report Options")
report_group.add_argument(
"--report_to",
@@ -331,19 +367,31 @@ def formatting_prompts_func(examples):
"all",
"none",
],
- help = "The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.",
+ help = (
+ "The list of integrations to report the results and logs to. Supported platforms are:\n\t\t "
+ "'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', "
+ "'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations "
+ "installed, 'none' for no integrations."
+ ),
)
report_group.add_argument(
- "--logging_steps", type = int, default = 1, help = "Logging steps, default is 1"
+ "--logging_steps",
+ type = int,
+ default = 1,
+ help = "Logging steps, default is 1",
)
- # Saving and pushing arguments
save_group = parser.add_argument_group("💾 Save Model Options")
save_group.add_argument(
- "--output_dir", type = str, default = "outputs", help = "Output directory"
+ "--output_dir",
+ type = str,
+ default = "outputs",
+ help = "Output directory",
)
save_group.add_argument(
- "--save_model", action = "store_true", help = "Save the model after training"
+ "--save_model",
+ action = "store_true",
+ help = "Save the model after training",
)
save_group.add_argument(
"--save_method",
@@ -358,14 +406,20 @@ def formatting_prompts_func(examples):
help = "Convert the model to GGUF after training",
)
save_group.add_argument(
- "--save_path", type = str, default = "model", help = "Path to save the model"
+ "--save_path",
+ type = str,
+ default = "model",
+ help = "Path to save the model",
)
save_group.add_argument(
"--quantization",
type = str,
default = "q8_0",
nargs = "+",
- help = "Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ",
+ help = (
+ "Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), "
+ "Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf"
+ ),
)
push_group = parser.add_argument_group("🚀 Push Model Options")
@@ -386,7 +440,9 @@ def formatting_prompts_func(examples):
help = "Path on Hugging Face hub to push the model",
)
push_group.add_argument(
- "--hub_token", type = str, help = "Token for pushing the model to Hugging Face hub"
+ "--hub_token",
+ type = str,
+ help = "Token for pushing the model to Hugging Face hub",
)
args = parser.parse_args()
diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py
index 8eecec10c..a33f94016 100644
--- a/unsloth/kernels/rope_embedding.py
+++ b/unsloth/kernels/rope_embedding.py
@@ -15,8 +15,96 @@
import triton
import triton.language as tl
import torch
+from ..device_type import DEVICE_COUNT
from .utils import calculate_settings, torch_gpu_device, torch_device_stream
+
+@triton.heuristics(
+ {
+ "BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),
+ "HAS_ROPE_INDICES": lambda args: bool(args["HAS_ROPE_INDICES"]),
+ }
+)
+@triton.jit
+def _rope_embedding_QK(
+ Q,
+ Q_batch_stride,
+ Q_head_stride,
+ Q_seq_stride,
+ K,
+ K_batch_stride,
+ K_head_stride,
+ K_seq_stride,
+ cos,
+ cos_row_stride,
+ sin,
+ sin_row_stride,
+ rope_embedding_indices,
+ seqlen,
+ head_dim: tl.constexpr,
+ n_heads_K: tl.constexpr,
+ BACKWARD_PASS: tl.constexpr,
+ HAS_ROPE_INDICES: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+):
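+ # In-place RoPE for Q and K in a single launch: program axis 0 indexes a flattened
+ # (batch, seq) token, axis 1 indexes a Q head, and K is rotated only while the head
+ # index is below n_heads_K. The rotation position comes from a per-token index when
+ # HAS_ROPE_INDICES is set (packed sequences); otherwise it is the token's position
+ # within its row. The backward pass reuses the same kernel with sin negated.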
+ row_position = tl.program_id(0)
+ head_position = tl.program_id(1)
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ half_head_dim = head_dim // 2
+ mask = col_offsets < half_head_dim
+
+ if HAS_ROPE_INDICES:
+ rot_position = tl.load(
+ rope_embedding_indices + row_position,
+ eviction_policy = "evict_first",
+ ).to(tl.int32)
+ else:
+ rot_position = row_position % seqlen
+
+ cos_ptr = cos + rot_position * cos_row_stride
+ sin_ptr = sin + rot_position * sin_row_stride
+ sin1 = tl.load(
+ sin_ptr + col_offsets,
+ mask = mask,
+ other = 0,
+ eviction_policy = "evict_first",
+ )
+ cos1 = tl.load(
+ cos_ptr + col_offsets,
+ mask = mask,
+ other = 0,
+ eviction_policy = "evict_first",
+ )
+ if BACKWARD_PASS:
+ sin1 = -sin1
+
+ batch_id = row_position // seqlen
+ seq_index = row_position - batch_id * seqlen
+
+ q_ptr = (
+ Q
+ + batch_id * Q_batch_stride
+ + head_position * Q_head_stride
+ + seq_index * Q_seq_stride
+ )
+ q0 = tl.load(q_ptr + col_offsets, mask = mask, other = 0)
+ q1 = tl.load(q_ptr + half_head_dim + col_offsets, mask = mask, other = 0)
+ tl.store(q_ptr + col_offsets, q0 * cos1 - q1 * sin1, mask = mask)
+ tl.store(q_ptr + half_head_dim + col_offsets, q1 * cos1 + q0 * sin1, mask = mask)
+
+ if head_position < n_heads_K:
+ k_ptr = (
+ K
+ + batch_id * K_batch_stride
+ + head_position * K_head_stride
+ + seq_index * K_seq_stride
+ )
+ k0 = tl.load(k_ptr + col_offsets, mask = mask, other = 0)
+ k1 = tl.load(k_ptr + half_head_dim + col_offsets, mask = mask, other = 0)
+ tl.store(k_ptr + col_offsets, k0 * cos1 - k1 * sin1, mask = mask)
+ tl.store(k_ptr + half_head_dim + col_offsets, k1 * cos1 + k0 * sin1, mask = mask)
+
+
ROPE_GROUP_SIZE: int = 4
@@ -102,7 +190,7 @@ def forward(ctx, Q, cos, sin):
n_heads: int
head_dim: int
batch, seq_len, n_heads, head_dim = Q.shape
- Q = Q.view(batch * seq_len, n_heads * head_dim)
+ Q = Q.reshape(batch * seq_len, n_heads * head_dim)
n_rows: int
n_cols: int
n_rows, n_cols = Q.shape
@@ -143,7 +231,7 @@ def forward(ctx, Q, cos, sin):
ctx.n_groups = n_groups
ctx.cos = cos
ctx.sin = sin
- return Q.view(batch, seq_len, n_heads, head_dim)
+ return Q.reshape(batch, seq_len, n_heads, head_dim)
@staticmethod
def backward(ctx, dY):
@@ -153,7 +241,6 @@ def backward(ctx, dY):
head_dim: int
batch, seq_len, n_heads, head_dim = dY.shape
dY = dY.reshape(batch * seq_len, n_heads * head_dim)
- # Must be reshape not view
n_rows: int
n_cols: int
n_rows, n_cols = dY.shape
@@ -181,7 +268,7 @@ def backward(ctx, dY):
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
- dY = dY.view(batch, seq_len, n_heads, head_dim)
+ dY = dY.reshape(batch, seq_len, n_heads, head_dim)
return (
dY,
None,
@@ -191,12 +278,142 @@ def backward(ctx, dY):
# [TODO] Unsure why RoPE Embedding is not torch.compiling properly
@torch.compiler.disable
-def fast_rope_embedding(Q, K, cos, sin):
- Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
- K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
- # synchronize before cat to avoid race condition
- torch_device_stream(Q.device).synchronize()
- return Q, K
+def fast_rope_embedding(Q, K, cos, sin, rope_embedding_indices = None):
+ if rope_embedding_indices is not None:
+ Q_out, K_out = Fast_RoPE_Embedding_QK.apply(
+ Q, K, cos, sin, rope_embedding_indices
+ )
+ else:
+ Q_out = Fast_RoPE_Embedding.apply(
+ Q.transpose(1, 2).contiguous(), cos, sin
+ ).transpose(1, 2)
+ K_out = Fast_RoPE_Embedding.apply(
+ K.transpose(1, 2).contiguous(), cos, sin
+ ).transpose(1, 2)
+ if DEVICE_COUNT > 1:
+ torch_device_stream(Q.device).synchronize()
+ return Q_out, K_out
+
+
+class Fast_RoPE_Embedding_QK(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, Q, K, cos, sin, rope_indices):
+ has_indices = rope_indices is not None
+ cos, sin = cos.squeeze(), sin.squeeze()
+
+ batch, n_heads_Q, seq_len, head_dim = Q.shape
+ _, n_heads_K, _, _ = K.shape
+
+ Q_out = Q.clone()
+ K_out = K.clone()
+
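+ # Use the flattened per-token RoPE positions when packing is active; otherwise pass a
+ # one-element placeholder tensor that the kernel never reads (HAS_ROPE_INDICES = False).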
+ if has_indices:
+ rope_ptr = rope_indices.reshape(-1).to(dtype = torch.int32, device = Q.device)
+ else:
+ rope_ptr = cos.new_empty(1, dtype = torch.int32)
+
+ BLOCK_SIZE, num_warps = calculate_settings(head_dim)
+
+ Q_batch_stride, Q_head_stride, Q_seq_stride = (
+ Q_out.stride(0),
+ Q_out.stride(1),
+ Q_out.stride(2),
+ )
+ K_batch_stride, K_head_stride, K_seq_stride = (
+ K_out.stride(0),
+ K_out.stride(1),
+ K_out.stride(2),
+ )
+
+ with torch_gpu_device(Q.device):
+ _rope_embedding_QK[(batch * seq_len, n_heads_Q)](
+ Q_out,
+ Q_batch_stride,
+ Q_head_stride,
+ Q_seq_stride,
+ K_out,
+ K_batch_stride,
+ K_head_stride,
+ K_seq_stride,
+ cos,
+ cos.stride(0),
+ sin,
+ sin.stride(0),
+ rope_ptr,
+ seq_len,
+ head_dim = head_dim,
+ n_heads_K = n_heads_K,
+ BACKWARD_PASS = False,
+ HAS_ROPE_INDICES = has_indices,
+ BLOCK_SIZE = BLOCK_SIZE,
+ num_warps = num_warps,
+ )
+
+ ctx.block_size = BLOCK_SIZE
+ ctx.num_warps = num_warps
+ ctx.has_indices = has_indices
+ ctx.cos = cos
+ ctx.sin = sin
+ ctx.rope_indices = rope_ptr if has_indices else None
+ ctx.seq_len = seq_len
+ ctx.n_heads_Q = n_heads_Q
+ ctx.n_heads_K = n_heads_K
+
+ return (
+ Q_out,
+ K_out,
+ )
+
+ @staticmethod
+ def backward(ctx, dQ, dK):
+ batch, n_heads_Q, seq_len, head_dim = dQ.shape
+ _, n_heads_K, _, _ = dK.shape
+
+ dQ_out = dQ.clone()
+ dK_out = dK.clone()
+
+ rope_ptr = (
+ ctx.rope_indices
+ if ctx.has_indices
+ else ctx.cos.new_empty(1, dtype = torch.int32)
+ )
+
+ Q_batch_stride, Q_head_stride, Q_seq_stride = (
+ dQ_out.stride(0),
+ dQ_out.stride(1),
+ dQ_out.stride(2),
+ )
+ K_batch_stride, K_head_stride, K_seq_stride = (
+ dK_out.stride(0),
+ dK_out.stride(1),
+ dK_out.stride(2),
+ )
+
+ with torch_gpu_device(dQ.device):
+ _rope_embedding_QK[(batch * ctx.seq_len, ctx.n_heads_Q)](
+ dQ_out,
+ Q_batch_stride,
+ Q_head_stride,
+ Q_seq_stride,
+ dK_out,
+ K_batch_stride,
+ K_head_stride,
+ K_seq_stride,
+ ctx.cos,
+ ctx.cos.stride(0),
+ ctx.sin,
+ ctx.sin.stride(0),
+ rope_ptr,
+ ctx.seq_len,
+ head_dim = head_dim,
+ n_heads_K = ctx.n_heads_K,
+ BACKWARD_PASS = True,
+ HAS_ROPE_INDICES = ctx.has_indices,
+ BLOCK_SIZE = ctx.block_size,
+ num_warps = ctx.num_warps,
+ )
+
+ return (dQ_out, dK_out, None, None, None)
class Slow_RoPE_Embedding(torch.autograd.Function):
@@ -206,8 +423,8 @@ def forward(ctx, Q, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
+ sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
# Q * cos + rotate_half(Q) * sin
half = Q.shape[-1] // 2
diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py
index 54188305a..1d8858bd8 100644
--- a/unsloth/models/cohere.py
+++ b/unsloth/models/cohere.py
@@ -16,6 +16,13 @@
from ._utils import __version__
from unsloth_zoo.hf_utils import dtype_from_config
from unsloth_zoo.utils import _get_dtype
+from ..utils.packing import get_packed_info_from_kwargs
+from ..utils.attention_dispatch import (
+ AttentionConfig,
+ AttentionContext,
+ run_attention,
+ select_attention_backend,
+)
try:
from transformers.models.cohere.modeling_cohere import (
@@ -104,6 +111,7 @@ def CohereAttention_fast_forward(
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+ seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, Q.device)
if self.use_qk_norm:
Q = fast_layernorm_compiled(self.q_norm, Q)
K = fast_layernorm_compiled(self.k_norm, K)
@@ -112,12 +120,17 @@ def CohereAttention_fast_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = position_embeddings
- if position_ids is None:
- Q, K = fast_rope_embedding(Q, K, cos, sin)
+ # Extend RoPE dynamically to fit in VRAM
+ if position_embeddings:
+ cos, sin = position_embeddings
else:
- cos, sin = cos[position_ids], sin[position_ids]
- Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
+ cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index)
+
+ rope_position_ids = (
+ position_ids if position_ids is not None else kwargs.get("position_ids")
+ )
+ # Useful for LongRoPE
+ Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
@@ -125,54 +138,33 @@ def CohereAttention_fast_forward(
past_key_value = (K, V) if use_cache else None
# Attention module
- if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None:
- # Xformers memory efficient attention
- # Also has Flash Attention v2 dispatching
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
-
- # Group query attention
- if n_groups != 1:
- K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- if hidden_states.requires_grad:
- K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
- V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
- else:
- Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
- A = xformers_attention(Q, K, V, attn_bias = causal_mask)
- A = A.view(bsz, q_len, n_heads, head_dim)
-
- elif HAS_FLASH_ATTENTION and attention_mask is None:
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- A = flash_attn_func(Q, K, V, causal = True)
- else:
- # Grouped query attention
- if n_groups != 1:
- K = K[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, kv_seq_len, head_dim
- )
- V = V[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, kv_seq_len, head_dim
- )
- K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
- V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
- pass
- # Must be contiguous or else results are False!
- # https://github.com/pytorch/pytorch/issues/112577
- Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
- # Needs (batch_size, n_heads, seq_len, head_dim)
- # is_casual and attention_mask must not be both set!
- A = scaled_dot_product_attention(
- Q, K, V, attn_mask = attention_mask, is_causal = False
- )
- # Go back to (batch_size, seq_len, n_heads, head_dim)
- A = A.transpose(1, 2).contiguous()
+ use_varlen = seq_info is not None and past_key_value is None
+ backend = select_attention_backend(use_varlen)
+ attention_config = AttentionConfig(
+ backend = backend,
+ n_kv_heads = n_kv_heads,
+ n_groups = n_groups,
+ flash_dense_kwargs = {"causal": True},
+ flash_varlen_kwargs = {
+ "dropout_p": 0.0,
+ "causal": True,
+ "softmax_scale": getattr(self, "softmax_scale", None),
+ },
+ )
+ context = AttentionContext(
+ bsz = bsz,
+ q_len = q_len,
+ kv_seq_len = kv_seq_len,
+ n_heads = n_heads,
+ head_dim = head_dim,
+ requires_grad = hidden_states.requires_grad,
+ seq_info = seq_info,
+ attention_mask = attention_mask,
+ causal_mask = causal_mask,
+ )
+
+ A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
+
attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
@@ -215,6 +207,7 @@ def CohereDecoderLayer_fast_forward(
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
+ **kwargs,
)
# Fully Connected
@@ -234,6 +227,7 @@ def CohereDecoderLayer_fast_forward(
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
+ **kwargs,
)
# Fully Connected
diff --git a/unsloth/models/falcon_h1.py b/unsloth/models/falcon_h1.py
index 70223a3c2..384edd454 100644
--- a/unsloth/models/falcon_h1.py
+++ b/unsloth/models/falcon_h1.py
@@ -17,6 +17,14 @@
from ._utils import __version__
from unsloth_zoo.utils import Version, _get_dtype
from unsloth_zoo.hf_utils import dtype_from_config
+from ..utils.packing import get_packed_info_from_kwargs
+from ..utils.attention_dispatch import (
+ AttentionConfig,
+ AttentionContext,
+ run_attention,
+ select_attention_backend,
+ SDPA,
+)
from .llama import (
LlamaRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
@@ -98,13 +106,10 @@ def FalconH1Attention_fast_forward(
assert n_kv_heads * n_groups == n_heads
Q, K, V = self.apply_qkv(self, hidden_states)
- Q = Q.view(
- bsz, q_len, n_heads, head_dim
- ) # .transpose(1, 2) # we will transpose after normalisation
- K = K.view(
- bsz, q_len, n_kv_heads, head_dim
- ) # .transpose(1, 2) # we will transpose after normalisation
+ Q = Q.view(bsz, q_len, n_heads, head_dim)
+ K = K.view(bsz, q_len, n_kv_heads, head_dim)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+ seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, hidden_states.device)
# Falcon H1 multiplies key states by a multiplier
K = K * self.config.key_multiplier
@@ -116,20 +121,19 @@ def FalconH1Attention_fast_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
+ # Extend RoPE dynamically to fit in VRAM
if position_embeddings:
cos, sin = position_embeddings
else:
- # Extend RoPE dynamically to fit in VRA
rotary_emb = self.rotary_emb
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
- device_index = Q.device.index
+ cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)
- if position_ids is None:
- # Useful for LongRoPE
- cos, sin = rotary_emb.get_cached(kv_seq_len, device_index)
- else:
- cos, sin = rotary_emb.get_cached(kv_seq_len, device_index)
- Q, K = fast_rope_embedding(Q, K, cos, sin)
+ rope_position_ids = (
+ position_ids if position_ids is not None else kwargs.get("position_ids")
+ )
+ # Useful for LongRoPE
+ Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
@@ -137,54 +141,45 @@ def FalconH1Attention_fast_forward(
past_key_value = (K, V) if use_cache else None
# Attention module
- if not HAS_FLASH_ATTENTION and attention_mask is None:
- # Xformers memory efficient attention
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- K_M = V_M = bsz * kv_seq_len
- Q_M = bsz * q_len
-
- # Group query attention
- K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- if hidden_states.requires_grad:
- K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
- V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
- else:
- # Xformers does support the forward pass though
- Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
-
- A = xformers_attention(Q, K, V, attn_bias = causal_mask)
- A = A.view(bsz, q_len, n_heads, head_dim)
-
- elif HAS_FLASH_ATTENTION and attention_mask is None:
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- sw = kv_seq_len
- window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
- A = flash_attn_func(Q, K, V, causal = True, window_size = window)
- else:
- # Grouped query attention
- # if n_groups != 1:
- K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
- V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
- K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
- V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
- # pass
- # Must be contiguous or else results are False!
- # https://github.com/pytorch/pytorch/issues/112577
- Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
- # Needs (batch_size, n_heads, seq_len, head_dim)
- # is_casual and attention_mask must not be both set!
- A = scaled_dot_product_attention(
- Q, K, V, attn_mask = attention_mask, is_causal = False
- )
- # Go back to (batch_size, seq_len, n_heads, head_dim)
- A = A.transpose(1, 2).contiguous()
+ window = (-1, -1)
+ use_varlen = (
+ attention_mask is None
+ and seq_info is not None
+ and past_key_value is None
+ and window == (-1, -1)
+ )
+
+ backend = (
+ SDPA if attention_mask is not None else select_attention_backend(use_varlen)
+ )
+ attention_config = AttentionConfig(
+ backend = backend,
+ n_kv_heads = n_kv_heads,
+ n_groups = n_groups,
+ flash_dense_kwargs = {
+ "causal": True,
+ "window_size": (kv_seq_len, kv_seq_len),
+ },
+ flash_varlen_kwargs = {
+ "dropout_p": 0.0,
+ "softmax_scale": None,
+ "causal": True,
+ },
+ sdpa_kwargs = {} if attention_mask is None else {"attn_mask": attention_mask},
+ )
+ context = AttentionContext(
+ bsz = bsz,
+ q_len = q_len,
+ kv_seq_len = kv_seq_len,
+ n_heads = n_heads,
+ head_dim = head_dim,
+ requires_grad = hidden_states.requires_grad,
+ seq_info = seq_info,
+ attention_mask = attention_mask,
+ causal_mask = causal_mask,
+ )
+
+ A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
attn_output = self.apply_o(self, attn_output)
@@ -442,6 +437,7 @@ def FalconH1DecoderLayer_fast_forward(
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
+ **kwargs,
)
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
@@ -486,6 +482,7 @@ def FalconH1DecoderLayer_fast_forward(
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
+ **kwargs,
)
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py
index cf5bdd8eb..bb92ef6e6 100644
--- a/unsloth/models/gemma.py
+++ b/unsloth/models/gemma.py
@@ -16,6 +16,11 @@
from ._utils import __version__
from unsloth_zoo.utils import _get_dtype
from unsloth_zoo.hf_utils import dtype_from_config
+from ..utils.packing import (
+ build_sdpa_packed_attention_mask,
+ build_xformers_block_causal_mask,
+ get_packed_info_from_kwargs,
+)
import math
try:
@@ -110,6 +115,7 @@ def GemmaDecoderLayer_fast_forward(
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
+ **kwargs,
)
hidden_states += residual
@@ -134,6 +140,7 @@ def GemmaDecoderLayer_fast_forward(
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
+ **kwargs,
)
hidden_states = residual + hidden_states
diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index de70f7873..98928b122 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -16,6 +16,14 @@
from ._utils import __version__
from unsloth_zoo.utils import _get_dtype
from unsloth_zoo.hf_utils import dtype_from_config
+from ..utils.packing import get_packed_info_from_kwargs
+from ..utils.attention_dispatch import (
+ AttentionConfig,
+ AttentionContext,
+ run_attention,
+ select_attention_backend,
+ SDPA,
+)
from .gemma import (
GemmaFixedRotaryEmbedding,
GemmaFixedLinearScalingRotaryEmbedding,
@@ -98,19 +106,25 @@ def Gemma2Attention_fast_forward(
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+ seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, Q.device)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
device_index = Q.device.index
- if position_ids is None:
- cos = self.rotary_emb.multi_gpu_cos_cached[device_index]
- sin = self.rotary_emb.multi_gpu_sin_cached[device_index]
- Q, K = fast_rope_embedding(Q, K, cos, sin)
+ cos = self.rotary_emb.multi_gpu_cos_cached[device_index]
+ sin = self.rotary_emb.multi_gpu_sin_cached[device_index]
+
+ rope_position_ids = (
+ position_ids if position_ids is not None else kwargs.get("position_ids")
+ )
+ if rope_position_ids is not None:
+ # Useful for LongRoPE
+ cos_var, sin_var = self.rotary_emb.get_cached(kv_seq_len, device_index)
+ Q, K = fast_rope_embedding(Q, K, cos_var, sin_var, rope_position_ids)
else:
- cos, sin = self.rotary_emb.get_cached(kv_seq_len, device_index)
- Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
+ Q, K = fast_rope_embedding(Q, K, cos, sin)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
@@ -118,32 +132,68 @@ def Gemma2Attention_fast_forward(
past_key_value = (K, V) if use_cache else None
# Only enable if the attention_mask is True
- has_sliding_window = type(causal_mask) is bool and causal_mask is True
- if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
+ use_sliding_window = kwargs.get("use_sliding_window")
+ has_sliding_window = (
+ use_sliding_window
+ if use_sliding_window is not None
+ else isinstance(causal_mask, bool) and causal_mask is True
+ )
+
+ use_flash = HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None
+
+ if use_flash:
window = (-1, -1)
+ sliding_window = getattr(self.config, "sliding_window", None)
if has_sliding_window:
- sw = getattr(self.config, "sliding_window", None)
- sw = kv_seq_len if (sw is None or sw == "null") else sw
- window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
+ sliding_window = (
+ sliding_window if sliding_window is not None else kv_seq_len
+ )
+ window = (
+ (-1, -1)
+ if kv_seq_len <= sliding_window
+ else (sliding_window, sliding_window)
+ )
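+ # flash-attn convention: (-1, -1) disables windowing; otherwise (sw, sw) restricts
+ # attention to a local window of `sliding_window` keys around each query.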
- # FA uses 1 / sqrt for softmax_scale!
if not hasattr(self, "_flash_attention_softmax_scale"):
self._flash_attention_softmax_scale = 1.0 / (
self.config.query_pre_attn_scalar**0.5
)
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- A = flash_attn_func(
- Q,
- K,
- V,
- causal = True,
- softcap = self.config.attn_logit_softcapping,
- softmax_scale = self._flash_attention_softmax_scale,
- window_size = window,
+ use_varlen = seq_info is not None and past_key_value is None
+
+ attention_config = AttentionConfig(
+ backend = select_attention_backend(use_varlen),
+ n_kv_heads = n_kv_heads,
+ n_groups = n_groups,
+ flash_dense_kwargs = {
+ "causal": True,
+ "softcap": self.config.attn_logit_softcapping,
+ "softmax_scale": self._flash_attention_softmax_scale,
+ "window_size": window,
+ },
+ flash_varlen_kwargs = {
+ "dropout_p": 0.0,
+ "softmax_scale": self._flash_attention_softmax_scale,
+ "causal": True,
+ "softcap": self.config.attn_logit_softcapping,
+ "window_size": window,
+ },
+ )
+
+ context = AttentionContext(
+ bsz = bsz,
+ q_len = q_len,
+ kv_seq_len = kv_seq_len,
+ n_heads = n_heads,
+ head_dim = head_dim,
+ requires_grad = hidden_states.requires_grad,
+ seq_info = seq_info,
+ attention_mask = attention_mask,
+ causal_mask = causal_mask,
+ sliding_window = sliding_window,
)
+
+ A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
A = A.reshape(bsz, q_len, n_heads * head_dim)
else:
fx = (
@@ -192,6 +242,7 @@ def Gemma2DecoderLayer_fast_forward(
use_cache = use_cache,
padding_mask = padding_mask,
_flag_for_generation = self._flag_for_generation,
+ **kwargs,
)
hidden_states = fast_rms_layernorm_inference_gemma(
self.post_attention_layernorm, hidden_states, out_weight
@@ -222,6 +273,7 @@ def Gemma2DecoderLayer_fast_forward(
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
+ **kwargs,
)
hidden_states = fast_rms_layernorm(
self.post_attention_layernorm, hidden_states, gemma = True
diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py
index 0a816e399..e513927a6 100644
--- a/unsloth/models/granite.py
+++ b/unsloth/models/granite.py
@@ -17,6 +17,14 @@
from ._utils import __version__
from unsloth_zoo.utils import _get_dtype
from unsloth_zoo.hf_utils import dtype_from_config
+from ..utils.packing import get_packed_info_from_kwargs
+from ..utils.attention_dispatch import (
+ AttentionConfig,
+ AttentionContext,
+ run_attention,
+ select_attention_backend,
+ SDPA,
+)
from .llama import (
LlamaRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
@@ -96,6 +104,7 @@ def GraniteAttention_fast_forward(
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+ seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, Q.device)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
@@ -103,10 +112,14 @@ def GraniteAttention_fast_forward(
assert position_embeddings is not None
cos, sin = position_embeddings
- if position_ids is None:
- Q, K = fast_rope_embedding(Q, K, cos, sin)
+ rope_position_ids = (
+ position_ids if position_ids is not None else kwargs.get("position_ids")
+ )
+ if rope_position_ids is not None:
+ # Useful for LongRoPE
+ Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
else:
- Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
+ Q, K = fast_rope_embedding(Q, K, cos, sin)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
@@ -114,69 +127,59 @@ def GraniteAttention_fast_forward(
past_key_value = (K, V) if use_cache else None
# Attention module
- if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None:
- # Xformers memory efficient attention
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- K_M = V_M = bsz * kv_seq_len
- Q_M = bsz * q_len
-
- # Group query attention
- K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- if hidden_states.requires_grad:
- K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
- V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
- else:
- # Xformers does support the forward pass though
- Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
-
- A = xformers_attention(
- Q, K, V, attn_bias = causal_mask, scale = self.scaling, p = dropout_p
- )
- A = A.view(bsz, q_len, n_heads, head_dim)
-
- elif HAS_FLASH_ATTENTION and attention_mask is None:
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- window = (kv_seq_len, kv_seq_len)
- A = flash_attn_func(
- Q,
- K,
- V,
- causal = True,
- window_size = window,
- softmax_scale = self.scaling,
- dropout_p = dropout_p,
- )
- else:
- # Grouped query attention
- # if n_groups != 1:
- K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
- V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
- K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
- V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
- # pass
- # Must be contiguous or else results are False!
- # https://github.com/pytorch/pytorch/issues/112577
- Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
- # Needs (batch_size, n_heads, seq_len, head_dim)
- # is_casual and attention_mask must not be both set!
- A = scaled_dot_product_attention(
- Q,
- K,
- V,
- attn_mask = attention_mask,
- scale = self.scaling,
- is_causal = False,
- dropout_p = dropout_p,
- )
- # Go back to (batch_size, seq_len, n_heads, head_dim)
- A = A.transpose(1, 2).contiguous()
+ use_varlen = (
+ attention_mask is None and seq_info is not None and past_key_value is None
+ )
+
+ backend = (
+ SDPA if attention_mask is not None else select_attention_backend(use_varlen)
+ )
+
+ window = (kv_seq_len, kv_seq_len)
+ softmax_scale = getattr(self, "scaling", None)
+ attention_config = AttentionConfig(
+ backend = backend,
+ n_kv_heads = n_kv_heads,
+ n_groups = n_groups,
+ flash_dense_kwargs = {
+ "causal": True,
+ "softmax_scale": softmax_scale,
+ "dropout_p": dropout_p,
+ "window_size": window,
+ },
+ flash_varlen_kwargs = {
+ "dropout_p": 0.0,
+ "softmax_scale": softmax_scale,
+ "causal": True,
+ },
+ sdpa_kwargs = {
+ k: v
+ for k, v in {
+ "attn_mask": attention_mask,
+ "scale": softmax_scale,
+ "dropout_p": dropout_p,
+ }.items()
+ if v is not None
+ },
+ xformers_kwargs = {
+ "scale": softmax_scale,
+ "p": dropout_p,
+ },
+ )
+
+ context = AttentionContext(
+ bsz = bsz,
+ q_len = q_len,
+ kv_seq_len = kv_seq_len,
+ n_heads = n_heads,
+ head_dim = head_dim,
+ requires_grad = hidden_states.requires_grad,
+ seq_info = seq_info,
+ attention_mask = attention_mask,
+ causal_mask = causal_mask,
+ )
+
+ A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
attn_output = self.apply_o(self, attn_output)
@@ -222,6 +225,7 @@ def GraniteDecoderLayer_fast_forward(
padding_mask = padding_mask,
position_embeddings = position_embeddings,
_flag_for_generation = self._flag_for_generation,
+ **kwargs,
)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
@@ -245,6 +249,7 @@ def GraniteDecoderLayer_fast_forward(
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
+ **kwargs,
)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 3af9cba66..5b54526f8 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -16,12 +16,23 @@
import gc
import math
import functools
-from typing import Any, Dict, Optional, Tuple, List, Union
+from typing import Optional, Tuple, List, Union
+
from ._utils import *
from ._utils import patch_unsloth_smart_gradient_checkpointing
from ._utils import __version__, importlib_version
from ._utils import move_to_device
from ._utils import _prepare_model_for_qat
+from ..utils.packing import (
+ get_packed_info_from_kwargs,
+ mask_packed_sequence_boundaries,
+)
+from ..utils.attention_dispatch import (
+ AttentionConfig,
+ AttentionContext,
+ run_attention,
+ select_attention_backend,
+)
from torch.nn.functional import scaled_dot_product_attention
from transformers import __version__ as transformers_version
from unsloth_zoo.utils import Version, _get_dtype
@@ -58,9 +69,6 @@
)
from ..kernels import *
from ..tokenizer_utils import *
-
-if HAS_FLASH_ATTENTION:
- from flash_attn import flash_attn_func
from .vision import FastBaseModel
# Final patching code
@@ -561,6 +569,7 @@ def LlamaAttention_fast_forward(
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+ seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, Q.device)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
@@ -569,23 +578,20 @@ def LlamaAttention_fast_forward(
if position_embeddings:
cos, sin = position_embeddings
else:
- # Extend RoPE dynamically to fit in VRA
rotary_emb = self.rotary_emb
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
-
- # if position_ids is None:
- # # Useful for LongRoPE
- # cos, sin = rotary_emb.get_cached(kv_seq_len, device = Q.device)
- # else:
- # cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device)
cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)
+ rope_position_ids = position_ids
+ if rope_position_ids is None and seq_info is not None:
+ rope_position_ids = kwargs.get("position_ids")
+
# Q, K = (
# fast_rope_embedding(Q, K, cos, sin)
- # if position_ids is None
- # else inplace_rope_embedding(Q, K, cos, sin, position_ids)
+ # if rope_position_ids is None
+ # else inplace_rope_embedding(Q, K, cos, sin, rope_position_ids)
# )
- Q, K = fast_rope_embedding(Q, K, cos, sin)
+ Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
@@ -593,76 +599,28 @@ def LlamaAttention_fast_forward(
past_key_value = (K, V) if use_cache else None
# Attention module
- if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None:
- # Xformers memory efficient attention
- # Also has Flash Attention v2 dispatching
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
-
- # Group query attention
- if n_groups != 1:
- K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- if hidden_states.requires_grad:
- K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
- V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
- else:
- Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
- A = xformers_attention(Q, K, V, attn_bias = causal_mask)
- A = A.view(bsz, q_len, n_heads, head_dim)
-
- elif HAS_FLASH_ATTENTION and attention_mask is None:
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- A = flash_attn_func(Q, K, V, causal = True)
- else:
- # when qlen==vlen and attn_mask is None, we should use causal attention
- Q_len = Q.shape[-2]
- K_len = K.shape[-2]
- if attention_mask is None and Q_len == K_len:
- is_causal = True
- else:
- is_causal = False
- # Grouped query attention
- if SDPA_HAS_GQA:
- # Needs (batch_size, n_heads, seq_len, head_dim)
- # is_casual and attention_mask must not be both set!
- A = scaled_dot_product_attention(
- Q,
- K,
- V,
- attn_mask = attention_mask,
- is_causal = is_causal,
- enable_gqa = n_groups != 1,
- )
- # Go back to (batch_size, seq_len, n_heads, head_dim)
- A = A.transpose(1, 2) # .contiguous()
- else:
- if n_groups != 1:
- K = K[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, kv_seq_len, head_dim
- )
- V = V[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, kv_seq_len, head_dim
- )
- K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
- V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
- pass
- # Must be contiguous or else results are False!
- # https://github.com/pytorch/pytorch/issues/112577
- Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
- # Needs (batch_size, n_heads, seq_len, head_dim)
- # is_casual and attention_mask must not be both set!
- A = scaled_dot_product_attention(
- Q, K, V, attn_mask = attention_mask, is_causal = is_causal
- )
- # Go back to (batch_size, seq_len, n_heads, head_dim)
- A = A.transpose(1, 2).contiguous()
- pass
+ use_varlen = seq_info is not None and past_key_value is None
+ backend = select_attention_backend(use_varlen)
+ config = AttentionConfig(
+ backend = backend,
+ n_kv_heads = n_kv_heads,
+ n_groups = n_groups,
+ flash_dense_kwargs = {"causal": True},
+ flash_varlen_kwargs = {"dropout_p": 0.0, "causal": True},
+ )
+ context = AttentionContext(
+ bsz = bsz,
+ q_len = q_len,
+ kv_seq_len = kv_seq_len,
+ n_heads = n_heads,
+ head_dim = head_dim,
+ requires_grad = hidden_states.requires_grad,
+ seq_info = seq_info,
+ attention_mask = attention_mask,
+ causal_mask = causal_mask,
+ )
+
+ A = run_attention(config = config, context = context, Q = Q, K = K, V = V)
attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
@@ -712,6 +670,7 @@ def LlamaDecoderLayer_fast_forward(
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
+ **kwargs,
)
hidden_states += residual
@@ -735,6 +694,7 @@ def LlamaDecoderLayer_fast_forward(
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
+ **kwargs,
)
hidden_states = residual + hidden_states
@@ -812,8 +772,11 @@ def LlamaModel_fast_forward(
seq_length_with_past = seq_length
- # Fix out of bounds tokenization
- if hasattr(self, "max_seq_length"):
+ # Fix out of bounds tokenization unless we were given packed metadata
+ allow_overlength = getattr(self, "_unsloth_allow_packed_overlength", False) or (
+ "packed_seq_lengths" in kwargs
+ )
+ if hasattr(self, "max_seq_length") and not allow_overlength:
if seq_length > self.max_seq_length:
shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
logger.warning_once(
@@ -1065,10 +1028,12 @@ def LlamaModel_fast_forward(
mask = causal_mask
if IS_GEMMA2:
- if idx % 2 == 0:
+ use_sliding_window = idx % 2 == 0
+ if use_sliding_window:
mask = self.SWA_mask if use_static_mask else dynamic_SWA_mask
else:
mask = self.GA_mask if use_static_mask else dynamic_GA_mask
+ kwargs["use_sliding_window"] = use_sliding_window
if gradient_checkpointing and not isinstance(
decoder_layer, GradientCheckpointingLayer
@@ -1082,6 +1047,7 @@ def custom_forward(*inputs):
output_attentions,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
+ **kwargs,
)
return custom_forward
@@ -1108,6 +1074,7 @@ def custom_forward(*inputs):
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
+ **kwargs,
)
hidden_states = layer_outputs[0]
@@ -1299,6 +1266,7 @@ def _CausalLM_fast_forward(
past_key_values,
position_ids = position_ids,
attention_mask = attention_mask,
+ **kwargs,
)
else:
causal_mask = (
@@ -1331,6 +1299,7 @@ def _CausalLM_fast_forward(
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
return_dict = return_dict,
+ **kwargs,
)
hidden_states = outputs[0]
@@ -1439,6 +1408,10 @@ def _CausalLM_fast_forward(
shift_labels = torch.empty_like(labels)
shift_labels[..., :-1] = labels[..., 1:]
shift_labels[..., -1] = -100
+ mask_packed_sequence_boundaries(
+ shift_labels,
+ kwargs.get("packed_seq_lengths"),
+ )
# shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
n_items = kwargs.get("num_items_in_batch", None)
if n_items is None:
@@ -1996,8 +1969,8 @@ def unsloth_fast_generate(
> self.config.max_position_embeddings
):
raise ValueError(
- f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\n'
- 'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.'
+ f"Unsloth: input length {kwargs['input_ids'].shape[-1]} + max_new_tokens {kwargs['max_new_tokens']} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\n"
+ "You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`."
)
# Must patch accelerate for Xformers
@@ -2235,8 +2208,6 @@ def from_pretrained(
model_config.model_name = model_name
model_max_seq_length = model_config.max_position_embeddings
- verify_fp8_support_if_applicable(model_config)
-
# Check if RoPE Scaling is even allowed
model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__]
IS_FALCON_H1 = model_config.model_type.startswith("falcon_h1")
diff --git a/unsloth/models/loader_utils.py b/unsloth/models/loader_utils.py
index 631f614cb..d587d0016 100644
--- a/unsloth/models/loader_utils.py
+++ b/unsloth/models/loader_utils.py
@@ -13,6 +13,9 @@
# limitations under the License.
from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
+from ..device_type import DEVICE_TYPE_TORCH
+import os
+import torch
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from packaging.version import Version
@@ -21,6 +24,9 @@
transformers_version = Version(transformers_version)
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
+LOCAL_RANK_KEYS = ("LOCAL_RANK", "RANK")
+WORLD_SIZE_KEYS = ("WORLD_SIZE",)
+
BAD_MAPPINGS = {
"unsloth/Qwen3-32B-unsloth-bnb-4bit".lower(): "unsloth/Qwen3-32B-bnb-4bit".lower(), # 32B dynamic quant is way too big
"unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit".lower(): "unsloth/Qwen3-30B-A3B".lower(), # HF loads MoEs too slowly
@@ -30,6 +36,50 @@
}
+def _get_env_int(keys):
+ for key in keys:
+ value = os.environ.get(key)
+ if value is None:
+ continue
+ try:
+ return int(value)
+ except ValueError:
+ continue
+ return None
+
+
+def _infer_distributed_ranks():
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ try:
+ return torch.distributed.get_rank(), torch.distributed.get_world_size()
+ except Exception:
+ pass
+ return _get_env_int(LOCAL_RANK_KEYS), _get_env_int(WORLD_SIZE_KEYS)
+
+
+def is_distributed():
+ rank, world_size = _infer_distributed_ranks()
+ return (world_size or 1) > 1 or (rank is not None and rank > 0)
+
+
+def prepare_device_map():
+ rank, world_size = _infer_distributed_ranks()
+ distributed = (world_size or 1) > 1 or (rank is not None and rank > 0)
+ if not distributed:
+ return None, False
+
+ local_rank = 0 if rank is None else rank
+ device_map = {"": f"{DEVICE_TYPE_TORCH}:{local_rank}"}
+ try:
+ if DEVICE_TYPE_TORCH == "cuda":
+ torch.cuda.set_device(local_rank)
+ elif DEVICE_TYPE_TORCH == "xpu" and hasattr(torch, "xpu"):
+ torch.xpu.set_device(local_rank)
+ except Exception:
+ pass
+ return device_map, True
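+
+# Illustrative sketch of how these helpers behave (assuming ``torch.distributed`` is not
+# yet initialised, so the ranks come from the environment): under a two-process
+# ``torchrun`` launch with ``LOCAL_RANK=1`` and ``WORLD_SIZE=2``, ``is_distributed()``
+# returns True and ``prepare_device_map()`` returns ``({"": f"{DEVICE_TYPE_TORCH}:1"}, True)``;
+# in a plain single-process run both variables are unset, so the results are
+# False and ``(None, False)``.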
+
+
def __get_model_name(
model_name,
load_in_4bit = True,
diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py
index 1945e6a74..0f7ea3c0b 100644
--- a/unsloth/models/mistral.py
+++ b/unsloth/models/mistral.py
@@ -17,6 +17,16 @@
from ._utils import __version__
from unsloth_zoo.utils import _get_dtype
from unsloth_zoo.hf_utils import dtype_from_config
+from ..utils.packing import (
+ get_packed_info_from_kwargs,
+ mask_packed_sequence_boundaries,
+)
+from ..utils.attention_dispatch import (
+ AttentionConfig,
+ AttentionContext,
+ run_attention,
+ select_attention_backend,
+)
from .llama import (
LlamaRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
@@ -76,6 +86,7 @@ def MistralAttention_fast_forward(
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+ seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, Q.device)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
@@ -83,12 +94,13 @@ def MistralAttention_fast_forward(
# Extend RoPE dynamically to fit in VRAM
self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
-
cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index)
- if position_ids is None:
- Q, K = fast_rope_embedding(Q, K, cos, sin)
- else:
- Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
+
+ rope_position_ids = (
+ position_ids if position_ids is not None else kwargs.get("position_ids")
+ )
+    # Pass explicit position ids when available (useful for LongRoPE and packed batches)
+ Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
@@ -96,68 +108,38 @@ def MistralAttention_fast_forward(
past_key_value = (K, V) if use_cache else None
# Attention module
- if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None:
- # Xformers memory efficient attention
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- K_M = V_M = bsz * kv_seq_len
- Q_M = bsz * q_len
-
- has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
-
- # Group query attention
- K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- if hidden_states.requires_grad:
- K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
- V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
-
- if has_swa:
- Q = Q.view(1, Q_M, n_heads, head_dim)
- K = K.view(1, K_M, n_heads, head_dim)
- V = V.view(1, V_M, n_heads, head_dim)
- else:
- # Xformers does support the forward pass though
- Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
-
- if has_swa:
- Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
- K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
- V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
-
- A = xformers_attention(Q, K, V, attn_bias = causal_mask)
- A = A.view(bsz, q_len, n_heads, head_dim)
-
- elif HAS_FLASH_ATTENTION and attention_mask is None:
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- sw = getattr(self.config, "sliding_window", None)
- sw = kv_seq_len if (sw is None or sw == "null") else sw
- window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
- A = flash_attn_func(Q, K, V, causal = True, window_size = window)
- else:
- # Grouped query attention
- # if n_groups != 1:
- K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
- V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
- K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
- V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
- # pass
- # Must be contiguous or else results are False!
- # https://github.com/pytorch/pytorch/issues/112577
- Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
- # Needs (batch_size, n_heads, seq_len, head_dim)
- # is_casual and attention_mask must not be both set!
- A = scaled_dot_product_attention(
- Q, K, V, attn_mask = attention_mask, is_causal = False
- )
- # Go back to (batch_size, seq_len, n_heads, head_dim)
- A = A.transpose(1, 2).contiguous()
+ sw_cfg = getattr(self.config, "sliding_window", None)
+ sw = kv_seq_len if (sw_cfg is None or sw_cfg == "null") else sw_cfg
+ window_size = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
+ use_varlen = (
+ seq_info is not None and past_key_value is None and window_size == (-1, -1)
+ )
+ backend = select_attention_backend(use_varlen)
+ attention_config = AttentionConfig(
+ backend = backend,
+ n_kv_heads = n_kv_heads,
+ n_groups = n_groups,
+ flash_dense_kwargs = {"causal": True, "window_size": window_size},
+ flash_varlen_kwargs = {
+ "dropout_p": 0.0,
+ "causal": True,
+ "softmax_scale": getattr(self, "softmax_scale", None),
+ },
+ )
+ context = AttentionContext(
+ bsz = bsz,
+ q_len = q_len,
+ kv_seq_len = kv_seq_len,
+ n_heads = n_heads,
+ head_dim = head_dim,
+ requires_grad = hidden_states.requires_grad,
+ seq_info = seq_info,
+ attention_mask = attention_mask,
+ causal_mask = causal_mask,
+ )
+
+ A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
@@ -376,6 +358,10 @@ def MistralForCausalLM_fast_forward(
shift_labels = torch.empty_like(labels)
shift_labels[..., :-1] = labels[..., 1:]
shift_labels[..., -1] = -100
+ mask_packed_sequence_boundaries(
+ shift_labels,
+ kwargs.get("packed_seq_lengths"),
+ )
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py
index 3d905fe06..d9f54f267 100644
--- a/unsloth/models/qwen3.py
+++ b/unsloth/models/qwen3.py
@@ -16,6 +16,13 @@
import os
from ._utils import __version__
from unsloth_zoo.utils import Version, _get_dtype
+from ..utils.packing import get_packed_info_from_kwargs
+from ..utils.attention_dispatch import (
+ AttentionConfig,
+ AttentionContext,
+ run_attention,
+ select_attention_backend,
+)
from .llama import (
LlamaRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
@@ -95,6 +102,7 @@ def Qwen3Attention_fast_forward(
bsz, q_len, n_kv_heads, head_dim
) # .transpose(1, 2) # we will transpose after normalisation
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+ seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, hidden_states.device)
# Qwen3 has QKNorm. This seems to be the only difference from Qwen2.
# Note that using fast_layernorm_compiled causes issues as the dimensions don't match up.
@@ -110,20 +118,19 @@ def Qwen3Attention_fast_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
+ # Extend RoPE dynamically to fit in VRAM
if position_embeddings:
cos, sin = position_embeddings
else:
- # Extend RoPE dynamically to fit in VRA
rotary_emb = self.rotary_emb
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
- device_index = Q.device.index
+ cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)
- if position_ids is None:
- # Useful for LongRoPE
- cos, sin = rotary_emb.get_cached(kv_seq_len, device_index)
- else:
- cos, sin = rotary_emb.get_cached(kv_seq_len, device_index)
- Q, K = fast_rope_embedding(Q, K, cos, sin)
+ rope_position_ids = (
+ position_ids if position_ids is not None else kwargs.get("position_ids")
+ )
+    # Pass explicit position ids when available (useful for LongRoPE and packed batches)
+ Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
@@ -131,74 +138,32 @@ def Qwen3Attention_fast_forward(
past_key_value = (K, V) if use_cache else None
# Attention module
- if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None:
- # Xformers memory efficient attention
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- K_M = V_M = bsz * kv_seq_len
- Q_M = bsz * q_len
-
- has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
-
- # Group query attention
- K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
- K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
- if hidden_states.requires_grad:
- K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
- V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
-
- if has_swa:
- Q = Q.view(1, Q_M, n_heads, head_dim)
- K = K.view(1, K_M, n_heads, head_dim)
- V = V.view(1, V_M, n_heads, head_dim)
- else:
- # Xformers does support the forward pass though
- Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
-
- if has_swa:
- Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
- K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
- V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
-
- A = xformers_attention(Q, K, V, attn_bias = causal_mask)
- A = A.view(bsz, q_len, n_heads, head_dim)
-
- elif HAS_FLASH_ATTENTION and attention_mask is None:
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
- V = V.transpose(1, 2)
- sw = kv_seq_len
- window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
- A = flash_attn_func(Q, K, V, causal = True, window_size = window)
- else:
- # Grouped query attention
- # if n_groups != 1:
- K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
- V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
- K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
- V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
- # pass
- # Must be contiguous or else results are False!
- # https://github.com/pytorch/pytorch/issues/112577
- Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
- # Needs (batch_size, n_heads, seq_len, head_dim)
- # is_casual and attention_mask must not be both set!
- # when qlen==vlen and attn_mask is None, we should use causal attention
- Q_len = Q.shape[-2]
- K_len = K.shape[-2]
- if attention_mask is None and Q_len == K_len:
- is_causal = True
- else:
- is_causal = False
+ use_varlen = seq_info is not None and past_key_value is None
+ backend = select_attention_backend(use_varlen)
+ attention_config = AttentionConfig(
+ backend = backend,
+ n_kv_heads = n_kv_heads,
+ n_groups = n_groups,
+ flash_dense_kwargs = {"causal": True},
+ flash_varlen_kwargs = {
+ "dropout_p": 0.0,
+ "causal": True,
+ "softmax_scale": getattr(self, "softmax_scale", None),
+ },
+ )
+ context = AttentionContext(
+ bsz = bsz,
+ q_len = q_len,
+ kv_seq_len = kv_seq_len,
+ n_heads = n_heads,
+ head_dim = head_dim,
+ requires_grad = hidden_states.requires_grad,
+ seq_info = seq_info,
+ attention_mask = attention_mask,
+ causal_mask = causal_mask,
+ )
- A = scaled_dot_product_attention(
- Q, K, V, attn_mask = attention_mask, is_causal = is_causal
- )
- # Go back to (batch_size, seq_len, n_heads, head_dim)
- A = A.transpose(1, 2).contiguous()
+ A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
attn_output = self.apply_o(self, attn_output)
diff --git a/unsloth/utils/__init__.py b/unsloth/utils/__init__.py
index e69de29bb..7870a0879 100644
--- a/unsloth/utils/__init__.py
+++ b/unsloth/utils/__init__.py
@@ -0,0 +1,24 @@
+from .packing import configure_sample_packing, enable_sample_packing
+from .attention_dispatch import (
+ AttentionConfig,
+ AttentionContext,
+ FLASH_DENSE,
+ FLASH_VARLEN,
+ SDPA,
+ XFORMERS,
+ run_attention,
+ select_attention_backend,
+)
+
+__all__ = [
+ "configure_sample_packing",
+ "enable_sample_packing",
+ "AttentionConfig",
+ "AttentionContext",
+ "FLASH_VARLEN",
+ "FLASH_DENSE",
+ "XFORMERS",
+ "SDPA",
+ "run_attention",
+ "select_attention_backend",
+]
diff --git a/unsloth/utils/attention_dispatch.py b/unsloth/utils/attention_dispatch.py
new file mode 100644
index 000000000..45aef551e
--- /dev/null
+++ b/unsloth/utils/attention_dispatch.py
@@ -0,0 +1,271 @@
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see .
+
+"""Shared helpers for attention backend selection and execution."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple
+
+from torch import Tensor
+from torch.nn.functional import scaled_dot_product_attention
+
+from ..models._utils import *
+from ..utils.packing import (
+ build_sdpa_packed_attention_mask,
+ build_xformers_block_causal_mask,
+)
+
+if HAS_FLASH_ATTENTION:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+HAS_XFORMERS = xformers is not None
+BlockDiagonalCausalMask = None
+if HAS_XFORMERS:
+ BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask
+SDPA_HAS_GQA = "enable_gqa" in (scaled_dot_product_attention.__doc__ or "")
+
+FLASH_VARLEN = "flash_varlen"
+FLASH_DENSE = "flash_dense"
+XFORMERS = "xformers"
+SDPA = "sdpa"
+
+
+XFORMERS_BLOCK_DIAG_CLS = BlockDiagonalCausalMask
+
+
+@dataclass
+class AttentionConfig:
+ """
+ Per-layer attention metadata.
+
+    NOTE(djsaunde): I had originally intended this to be populated once per layer,
+    but we currently construct it on every forward pass, since it can become invalid
+    from one pass to the next (e.g. when switching from training to inference). For
+    now, I'm keeping it separate from AttentionContext for the sake of better
+    parameter grouping.
+ """
+
+ backend: str
+ n_kv_heads: int
+ n_groups: int
+ flash_dense_kwargs: Optional[dict[str, Any]] = None
+ flash_varlen_kwargs: Optional[dict[str, Any]] = None
+ sdpa_kwargs: Optional[dict[str, Any]] = None
+ xformers_kwargs: Optional[dict[str, Any]] = None
+
+
+@dataclass
+class AttentionContext:
+ """Per-call info required to run attention."""
+
+ bsz: int
+ q_len: int
+ kv_seq_len: int
+ n_heads: int
+ head_dim: int
+ requires_grad: bool
+ seq_info: Optional[Tuple[Tensor, Tensor, int]]
+ attention_mask: Optional[Tensor]
+ causal_mask: Optional[Any]
+ sliding_window: Optional[int] = None
+
+
+def select_attention_backend(use_varlen: bool = False) -> str:
+ """Return attention backend based on availability / priority order."""
+
+ if HAS_FLASH_ATTENTION:
+ if use_varlen:
+ return FLASH_VARLEN
+ else:
+ return FLASH_DENSE
+ if HAS_XFORMERS:
+ return XFORMERS
+ return SDPA
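+
+# For reference, the selection above resolves as follows:
+#   select_attention_backend(use_varlen = True)   -> FLASH_VARLEN  (flash-attn installed)
+#   select_attention_backend(use_varlen = False)  -> FLASH_DENSE   (flash-attn installed)
+#   select_attention_backend(...)                 -> XFORMERS      (only xformers installed)
+#   select_attention_backend(...)                 -> SDPA          (neither installed)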
+
+
+def run_attention(
+ *,
+ config: AttentionConfig,
+ context: AttentionContext,
+ Q: Tensor,
+ K: Tensor,
+ V: Tensor,
+) -> Tensor:
+ """Run attention using config / context info."""
+
+ backend = config.backend
+ if backend == FLASH_VARLEN and context.seq_info is None:
+ backend = FLASH_DENSE if HAS_FLASH_ATTENTION else SDPA
+ flash_dense_kwargs = config.flash_dense_kwargs or {}
+ flash_varlen_kwargs = config.flash_varlen_kwargs or {}
+ sdpa_kwargs = config.sdpa_kwargs or {}
+ xformers_kwargs = config.xformers_kwargs or {}
+
+ bsz = context.bsz
+ n_heads = context.n_heads
+ q_len = context.q_len
+ head_dim = context.head_dim
+ kv_seq_len = context.kv_seq_len
+ requires_grad = context.requires_grad
+ sliding_window = context.sliding_window
+
+ if backend == FLASH_VARLEN:
+ Q_f = Q.transpose(1, 2).reshape(bsz * q_len, n_heads, head_dim)
+ K_f = K.transpose(1, 2).reshape(bsz * q_len, config.n_kv_heads, head_dim)
+ V_f = V.transpose(1, 2).reshape(bsz * q_len, config.n_kv_heads, head_dim)
+ _, cu_seqlens, max_seqlen = context.seq_info
+ return flash_attn_varlen_func(
+ Q_f,
+ K_f,
+ V_f,
+ cu_seqlens,
+ cu_seqlens,
+ max_seqlen,
+ max_seqlen,
+ **flash_varlen_kwargs,
+ ).view(bsz, q_len, n_heads, head_dim)
+ elif backend == FLASH_DENSE:
+ Q_t = Q.transpose(1, 2)
+ K_t = K.transpose(1, 2)
+ V_t = V.transpose(1, 2)
+ return flash_attn_func(Q_t, K_t, V_t, **flash_dense_kwargs).reshape(
+ bsz, q_len, n_heads, head_dim
+ )
+ elif backend == XFORMERS:
+ attn_bias = build_xformers_block_causal_mask(
+ context.seq_info,
+ sliding_window = sliding_window,
+ base_mask = context.causal_mask,
+ )
+
+ Q_t = Q.transpose(1, 2)
+ K_t = K.transpose(1, 2)
+ V_t = V.transpose(1, 2)
+
+ K_mod = K_t
+ V_mod = V_t
+ Q_mod = Q_t
+
+ if config.n_groups != 1:
+ K_mod = K_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim)
+ V_mod = V_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim)
+ K_mod = K_mod.expand(
+ bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
+ )
+ V_mod = V_mod.expand(
+ bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
+ )
+
+ if requires_grad:
+ K_mod = K_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)
+ V_mod = V_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)
+ else:
+ Q_mod = Q_t.view(
+ bsz, q_len, config.n_kv_heads, config.n_groups, head_dim
+ )
+
+ has_block = XFORMERS_BLOCK_DIAG_CLS is not None and isinstance(
+ attn_bias, XFORMERS_BLOCK_DIAG_CLS
+ )
+
+ if config.n_groups != 1 and not requires_grad and has_block:
+ Q_mod = Q_mod.view(
+ 1, bsz * q_len, config.n_kv_heads, config.n_groups, head_dim
+ )
+ K_mod = K_mod.view(
+ 1, bsz * kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
+ )
+ V_mod = V_mod.view(
+ 1, bsz * kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
+ )
+ elif config.n_groups != 1 and requires_grad and has_block:
+ Q_mod = Q_mod.view(1, bsz * q_len, n_heads, head_dim)
+ K_mod = K_mod.view(1, bsz * kv_seq_len, n_heads, head_dim)
+ V_mod = V_mod.view(1, bsz * kv_seq_len, n_heads, head_dim)
+
+ out = xformers_attention(
+ Q_mod,
+ K_mod,
+ V_mod,
+ attn_bias = attn_bias,
+ **xformers_kwargs,
+ )
+
+        if config.n_groups != 1 and not requires_grad:
+            out = out.view(bsz, q_len, config.n_kv_heads, config.n_groups, head_dim)
+            out = out.reshape(bsz, q_len, n_heads, head_dim)
+        else:
+            out = out.view(bsz, q_len, n_heads, head_dim)
+        return out
+ else:
+ local_mask = context.attention_mask
+ is_causal_local = False
+ if context.seq_info is not None and local_mask is None:
+ local_mask = build_sdpa_packed_attention_mask(
+ context.seq_info,
+ dtype = Q.dtype,
+ device = Q.device,
+ sliding_window = sliding_window,
+ )
+ else:
+ q_len_local = Q.shape[-2]
+ k_len_local = K.shape[-2]
+ is_causal_local = local_mask is None and q_len_local == k_len_local
+
+ kwargs = dict(sdpa_kwargs)
+ kwargs.setdefault("attn_mask", local_mask)
+ kwargs.setdefault("is_causal", is_causal_local)
+
+ if SDPA_HAS_GQA:
+ kwargs.setdefault("enable_gqa", config.n_groups != 1)
+ out = scaled_dot_product_attention(Q, K, V, **kwargs)
+ return out.transpose(1, 2)
+
+ K_mod = K
+ V_mod = V
+ if config.n_groups != 1:
+ K_mod = K[:, :, None, :, :].expand(
+ bsz, config.n_kv_heads, config.n_groups, kv_seq_len, head_dim
+ )
+ V_mod = V[:, :, None, :, :].expand(
+ bsz, config.n_kv_heads, config.n_groups, kv_seq_len, head_dim
+ )
+ K_mod = K_mod.reshape(bsz, n_heads, kv_seq_len, head_dim)
+ V_mod = V_mod.reshape(bsz, n_heads, kv_seq_len, head_dim)
+
+ out = scaled_dot_product_attention(
+ Q.contiguous(),
+ K_mod.contiguous(),
+ V_mod.contiguous(),
+ **kwargs,
+ )
+ return out.transpose(1, 2).contiguous()
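+
+# Minimal usage sketch (illustrative only; shapes follow the conventions of the model
+# forwards in this repo, and the tensors Q/K/V are placeholders supplied by the caller):
+#
+#     config = AttentionConfig(
+#         backend = select_attention_backend(use_varlen = False),
+#         n_kv_heads = 8,
+#         n_groups = 4,
+#     )
+#     context = AttentionContext(
+#         bsz = 2, q_len = 128, kv_seq_len = 128, n_heads = 32, head_dim = 64,
+#         requires_grad = False, seq_info = None, attention_mask = None,
+#         causal_mask = None,
+#     )
+#     # Q is (bsz, n_heads, q_len, head_dim); K/V are (bsz, n_kv_heads, kv_seq_len, head_dim)
+#     A = run_attention(config = config, context = context, Q = Q, K = K, V = V)
+#     # A is (bsz, q_len, n_heads, head_dim)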
+
+
+__all__ = [
+    "AttentionConfig",
+    "AttentionContext",
+    "FLASH_VARLEN",
+    "FLASH_DENSE",
+    "XFORMERS",
+    "SDPA",
+    "select_attention_backend",
+    "run_attention",
+]
diff --git a/unsloth/utils/packing.py b/unsloth/utils/packing.py
new file mode 100644
index 000000000..26f8542f4
--- /dev/null
+++ b/unsloth/utils/packing.py
@@ -0,0 +1,253 @@
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see .
+
+"""Utilities for enabling packed (padding-free) batches across Unsloth."""
+
+from __future__ import annotations
+
+import logging
+from typing import Any, Optional, Sequence, Tuple
+
+import torch
+
+try:
+ from xformers.ops.fmha.attn_bias import (
+ BlockDiagonalCausalMask as _XFormersBlockMask,
+ )
+except Exception:
+ try:
+ from xformers.attn_bias import BlockDiagonalCausalMask as _XFormersBlockMask
+ except Exception:
+ _XFormersBlockMask = None
+
+
+class _TrlPackingWarningFilter(logging.Filter):
+ _NEEDLES = (
+ "Padding-free training is enabled, but the attention implementation is not set to 'flash_attention_2'",
+ "You are using packing, but the attention implementation is not set to 'flash_attention_2' or 'kernels-community/vllm-flash-attn3'",
+ )
+
+ def filter(self, record: logging.LogRecord) -> bool: # pragma: no cover - trivial
+ message = record.getMessage()
+ return not any(needle in message for needle in self._NEEDLES)
+
+
+_TRL_FILTER_INSTALLED = False
+
+
+def _ensure_trl_warning_filter():
+ global _TRL_FILTER_INSTALLED
+ if _TRL_FILTER_INSTALLED:
+ return
+ logging.getLogger("trl.trainer.sft_trainer").addFilter(_TrlPackingWarningFilter())
+ _TRL_FILTER_INSTALLED = True
+
+
+def configure_sample_packing(config):
+ """Mutate an ``SFTConfig`` so TRL prepares packed batches."""
+ _ensure_trl_warning_filter()
+ setattr(config, "packing", True)
+ setattr(config, "padding_free", True)
+ setattr(config, "remove_unused_columns", False)
+
+
+def enable_sample_packing(model, trainer):
+ """Enable runtime support for packed batches on an existing trainer."""
+
+ def _mark_allow_overlength(module):
+ if hasattr(module, "max_seq_length"):
+ setattr(module, "_unsloth_allow_packed_overlength", True)
+ for child in module.children():
+ _mark_allow_overlength(child)
+
+ _mark_allow_overlength(model)
+
+ collator = getattr(trainer, "data_collator", None)
+ if (
+ collator is None
+ or not hasattr(collator, "torch_call")
+ or getattr(collator, "_unsloth_packing_wrapped", False)
+ ):
+ return
+
+ if hasattr(collator, "return_position_ids"):
+ collator.return_position_ids = True
+
+ original_torch_call = collator.torch_call
+
+ def torch_call_with_lengths(examples: Sequence[dict]):
+ batch = original_torch_call(examples)
+ if examples and isinstance(examples[0], dict):
+ seq_lengths: list[int] = []
+ for example in examples:
+ seq_lengths.extend(example["seq_lengths"])
+ if seq_lengths:
+ batch["packed_seq_lengths"] = torch.tensor(
+ seq_lengths, dtype = torch.int32
+ )
+ return batch
+
+ collator.torch_call = torch_call_with_lengths
+ collator._unsloth_packing_wrapped = True
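+
+# Typical wiring (a sketch; assumes a TRL ``SFTConfig``/``SFTTrainer`` pair named
+# ``config`` and ``trainer`` plus an already loaded Unsloth ``model``):
+#
+#     configure_sample_packing(config)       # before the trainer is constructed
+#     trainer = SFTTrainer(model = model, args = config, train_dataset = dataset, ...)
+#     enable_sample_packing(model, trainer)  # after the trainer (and its collator) exist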
+
+
+def get_packed_info_from_kwargs(
+ kwargs: dict,
+ total_tokens: int,
+ device: torch.device,
+) -> Optional[Tuple[torch.Tensor, torch.Tensor, int]]:
+ """Extract packed sequence information from attention kwargs."""
+
+ seq_lengths = kwargs.get("packed_seq_lengths")
+ if seq_lengths is None:
+ return None
+
+ if isinstance(seq_lengths, torch.Tensor):
+ lengths = seq_lengths.to(device = device, dtype = torch.int32)
+ else:
+ lengths = torch.tensor(seq_lengths, device = device, dtype = torch.int32)
+
+ if lengths.ndim > 1:
+ lengths = lengths.reshape(-1)
+
+ if lengths.numel() == 0:
+ return None
+
+ if int(lengths.sum().item()) != total_tokens:
+ return None
+
+ cu_seqlens = torch.cat(
+ [
+ torch.zeros(1, dtype = torch.int32, device = device),
+ torch.cumsum(lengths, dim = 0, dtype = torch.int32),
+ ]
+ )
+ max_seqlen = int(lengths.max().item())
+ return lengths, cu_seqlens, max_seqlen
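+
+# Example (illustrative): for ``kwargs = {"packed_seq_lengths": [3, 2]}`` and
+# ``total_tokens = 5`` this returns ``lengths = [3, 2]``, ``cu_seqlens = [0, 3, 5]``
+# (int32 tensors) and ``max_seqlen = 3``; if the lengths do not sum to
+# ``total_tokens`` the helper returns ``None``.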
+
+
+def build_xformers_block_causal_mask(
+ seq_info: Optional[Tuple[torch.Tensor, torch.Tensor, int]],
+ *,
+ sliding_window: Optional[int] = None,
+ base_mask: Optional[Any] = None,
+):
+ if _XFormersBlockMask is None:
+ return None
+ if seq_info is not None:
+ seq_lengths, _, _ = seq_info
+ lengths = seq_lengths.to("cpu", torch.int32).tolist()
+ if not lengths:
+ return None
+ mask = _XFormersBlockMask.from_seqlens(lengths)
+ else:
+ mask = base_mask
+
+ if (
+ sliding_window is not None
+ and sliding_window > 0
+ and mask is not None
+ and hasattr(mask, "make_local_attention")
+ ):
+ mask = mask.make_local_attention(window_size = sliding_window)
+ return mask
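+
+# Example (illustrative): with packed lengths ``[4, 4]`` this builds
+# ``BlockDiagonalCausalMask.from_seqlens([4, 4])``; passing ``sliding_window = 2``
+# additionally narrows it via ``make_local_attention(window_size = 2)``. With no
+# ``seq_info`` the ``base_mask`` (e.g. a model-level causal mask) is reused instead.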
+
+
+def build_sdpa_packed_attention_mask(
+ seq_info: Tuple[torch.Tensor, torch.Tensor, int],
+ *,
+ dtype: torch.dtype,
+ device: torch.device,
+ sliding_window: Optional[int] = None,
+) -> torch.Tensor:
+ seq_lengths, _, _ = seq_info
+ total_tokens = int(seq_lengths.sum().item())
+ mask = torch.full(
+ (total_tokens, total_tokens),
+ float("-inf"),
+ dtype = dtype,
+ device = device,
+ )
+ offset = 0
+ for length in seq_lengths.tolist():
+ length = int(length)
+ if length <= 0:
+ continue
+ block = torch.zeros((length, length), dtype = dtype, device = device)
+ upper = torch.triu(
+ torch.ones((length, length), device = device), diagonal = 1
+ ).bool()
+ block = block.masked_fill(upper, float("-inf"))
+ if (
+ sliding_window is not None
+ and sliding_window > 0
+ and length > sliding_window
+ ):
+ idx = torch.arange(length, device = device)
+ dist = idx.unsqueeze(1) - idx.unsqueeze(0)
+ window_mask = dist >= sliding_window
+ block = block.masked_fill(window_mask, float("-inf"))
+ mask[offset : offset + length, offset : offset + length] = block
+ offset += length
+ return mask.unsqueeze(0).unsqueeze(0)
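+
+# Example (illustrative): for two packed sequences of lengths ``[2, 2]`` the returned
+# additive mask has shape ``(1, 1, 4, 4)`` with values
+#
+#     [[   0, -inf, -inf, -inf],
+#      [   0,    0, -inf, -inf],
+#      [-inf, -inf,    0, -inf],
+#      [-inf, -inf,    0,    0]]
+#
+# i.e. causal within each block and fully masked across blocks.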
+
+
+def _normalize_packed_lengths(
+ seq_lengths: Any,
+ *,
+ device: torch.device,
+) -> Optional[torch.Tensor]:
+ if seq_lengths is None:
+ return None
+ if isinstance(seq_lengths, torch.Tensor):
+ lengths = seq_lengths.to(device = device, dtype = torch.int64)
+ else:
+ lengths = torch.tensor(seq_lengths, device = device, dtype = torch.int64)
+ if lengths.ndim != 1:
+ lengths = lengths.reshape(-1)
+ if lengths.numel() == 0:
+ return None
+ return lengths
+
+
+def mask_packed_sequence_boundaries(
+ shift_labels: torch.Tensor,
+ seq_lengths: Any,
+ *,
+ ignore_index: int = -100,
+) -> bool:
+ """Mark final token of every packed sample so CE ignores boundary predictions."""
+
+ lengths = _normalize_packed_lengths(seq_lengths, device = shift_labels.device)
+ if lengths is None:
+ return False
+
+ flat = shift_labels.reshape(-1)
+ total_tokens = flat.shape[0]
+ boundary_positions = torch.cumsum(lengths, dim = 0) - 1
+ valid = boundary_positions < total_tokens
+ if not torch.all(valid):
+ boundary_positions = boundary_positions[valid]
+ if boundary_positions.numel() == 0:
+ return False
+ flat[boundary_positions] = ignore_index
+ return True
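+
+# Example (illustrative): with ``seq_lengths = [3, 2]`` and shifted labels of length 5,
+# the positions ``cumsum([3, 2]) - 1 = [2, 4]`` are set to ``ignore_index`` (-100), so
+# the loss never asks the last token of one packed sample to predict the first token
+# of the next.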
+
+
+__all__ = [
+    "configure_sample_packing",
+    "enable_sample_packing",
+    "get_packed_info_from_kwargs",
+    "build_xformers_block_causal_mask",
+    "build_sdpa_packed_attention_mask",
+    "mask_packed_sequence_boundaries",
+]