diff --git a/tests/utils/test_attention_masks.py b/tests/utils/test_attention_masks.py new file mode 100644 index 000000000..dbd541961 --- /dev/null +++ b/tests/utils/test_attention_masks.py @@ -0,0 +1,272 @@ +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +"""Unit tests for packed-attention mask helpers with sliding-window logic.""" + +import math + +import torch + +from unsloth.utils import attention_dispatch +from unsloth.utils import packing as packing_utils + + +def _make_seq_info(lengths): + lengths = torch.tensor(lengths, dtype = torch.int32) + cu = torch.cat( + [ + torch.zeros(1, dtype = torch.int32), + torch.cumsum(lengths, dim = 0, dtype = torch.int32), + ] + ) + max_len = int(lengths.max().item()) + return lengths, cu, max_len + + +def test_sdpa_packed_attention_mask_sliding_window(): + seq_info = _make_seq_info([5, 3]) + mask = packing_utils.build_sdpa_packed_attention_mask( + seq_info, + dtype = torch.float32, + device = torch.device("cpu"), + sliding_window = 3, + ) + + assert mask.shape == (1, 1, 8, 8) + + block_first = mask[0, 0, :5, :5] + upper = torch.triu(torch.ones_like(block_first), diagonal = 1).bool() + assert torch.all(block_first[upper] == float("-inf")) + assert block_first[3, 0].item() == float("-inf") + assert block_first[4, 1].item() == float("-inf") + assert block_first[4, 2].item() > -math.inf + assert mask[0, 0, 0, 6].item() == float("-inf") + + +def test_xformers_block_mask_sliding_window(monkeypatch): + class _FakeMask: + def __init__(self, lengths, window = None): + self.lengths = lengths + self.window = window + + @classmethod + def from_seqlens(cls, lengths): + return cls(tuple(lengths)) + + def make_local_attention(self, window_size): + return _FakeMask(self.lengths, window = window_size) + + monkeypatch.setattr(packing_utils, "_XFormersBlockMask", _FakeMask, raising = False) + + seq_info = _make_seq_info([4, 4]) + mask = packing_utils.build_xformers_block_causal_mask( + seq_info, + sliding_window = 2, + ) + + assert isinstance(mask, _FakeMask) + assert mask.window == 2 + + +def test_run_attention_sdpa_passes_sliding_window(monkeypatch): + seq_info = _make_seq_info([3, 2]) + sliding_window = 2 + + original_builder = attention_dispatch.build_sdpa_packed_attention_mask + captured = {} + + def _capture_builder(seq_info_arg, *, dtype, device, sliding_window = None): + captured["window"] = sliding_window + return original_builder( + seq_info_arg, + dtype = dtype, + device = device, + sliding_window = sliding_window, + ) + + monkeypatch.setattr( + attention_dispatch, + "build_sdpa_packed_attention_mask", + _capture_builder, + ) + + def _fake_sdpa(Q, K, V, **kwargs): + captured["mask"] = kwargs.get("attn_mask") + return torch.zeros_like(Q) + + monkeypatch.setattr(attention_dispatch, "scaled_dot_product_attention", _fake_sdpa) + + config = attention_dispatch.AttentionConfig( + backend = 
attention_dispatch.SDPA, + n_kv_heads = 1, + n_groups = 1, + ) + + context = attention_dispatch.AttentionContext( + bsz = 1, + q_len = 5, + kv_seq_len = 5, + n_heads = 1, + head_dim = 1, + requires_grad = False, + seq_info = seq_info, + attention_mask = None, + causal_mask = None, + sliding_window = sliding_window, + ) + + Q = torch.zeros(1, 1, 5, 1) + K = torch.zeros_like(Q) + V = torch.zeros_like(Q) + + attention_dispatch.run_attention( + config = config, + context = context, + Q = Q, + K = K, + V = V, + ) + + assert captured["window"] == sliding_window + mask = captured["mask"] + assert mask is not None and mask.shape == (1, 1, 5, 5) + assert mask[0, 0, 4, 1].item() == float("-inf") + + +def test_run_attention_xformers_passes_sliding_window(monkeypatch): + seq_info = _make_seq_info([4]) + sliding_window = 3 + + class _FakeBias: + pass + + captured = {} + + def _fake_builder(seq_info_arg, *, sliding_window = None, base_mask = None): + captured["window"] = sliding_window + captured["base"] = base_mask + return _FakeBias() + + def _fake_attention(Q, K, V, attn_bias = None, **_): + captured["bias"] = attn_bias + return torch.zeros_like(Q) + + monkeypatch.setattr( + attention_dispatch, "build_xformers_block_causal_mask", _fake_builder + ) + monkeypatch.setattr( + attention_dispatch, "xformers_attention", _fake_attention, raising = False + ) + monkeypatch.setattr( + attention_dispatch, "XFORMERS_BLOCK_DIAG_CLS", _FakeBias, raising = False + ) + + config = attention_dispatch.AttentionConfig( + backend = attention_dispatch.XFORMERS, + n_kv_heads = 1, + n_groups = 1, + ) + + context = attention_dispatch.AttentionContext( + bsz = 1, + q_len = 4, + kv_seq_len = 4, + n_heads = 1, + head_dim = 1, + requires_grad = False, + seq_info = seq_info, + attention_mask = None, + causal_mask = None, + sliding_window = sliding_window, + ) + + Q = torch.zeros(1, 1, 4, 1) + K = torch.zeros_like(Q) + V = torch.zeros_like(Q) + + attention_dispatch.run_attention( + config = config, + context = context, + Q = Q, + K = K, + V = V, + ) + + assert captured["window"] == sliding_window + assert isinstance(captured["bias"], _FakeBias) + + +def test_run_attention_flash_varlen_receives_window_and_softcap(monkeypatch): + seq_info = _make_seq_info([4]) + sliding_window = 3 + softcap = 0.5 + window_tuple = (sliding_window, sliding_window) + + captured = {} + + def _fake_flash_varlen(Q, K, V, cu_q, cu_k, max_q, max_k, **kwargs): + captured["kwargs"] = kwargs + return torch.zeros_like(Q) + + monkeypatch.setattr( + attention_dispatch, + "flash_attn_varlen_func", + _fake_flash_varlen, + ) + monkeypatch.setattr(attention_dispatch, "HAS_FLASH_ATTENTION", True) + + config = attention_dispatch.AttentionConfig( + backend = attention_dispatch.FLASH_VARLEN, + n_kv_heads = 1, + n_groups = 1, + flash_varlen_kwargs = { + "dropout_p": 0.0, + "softmax_scale": 1.0, + "causal": True, + "softcap": softcap, + "window_size": window_tuple, + }, + ) + + context = attention_dispatch.AttentionContext( + bsz = 1, + q_len = 4, + kv_seq_len = 4, + n_heads = 1, + head_dim = 2, + requires_grad = False, + seq_info = seq_info, + attention_mask = None, + causal_mask = None, + sliding_window = sliding_window, + ) + + Q = torch.zeros(1, 1, 4, 2) + K = torch.zeros_like(Q) + V = torch.zeros_like(Q) + + attention_dispatch.run_attention( + config = config, + context = context, + Q = Q, + K = K, + V = V, + ) + + assert captured["kwargs"]["softcap"] == softcap + assert captured["kwargs"]["window_size"] == window_tuple + + +"""Unit tests for packed-attention mask 
helpers with sliding-window logic.""" diff --git a/tests/utils/test_packing.py b/tests/utils/test_packing.py new file mode 100644 index 000000000..e9c3264ac --- /dev/null +++ b/tests/utils/test_packing.py @@ -0,0 +1,341 @@ +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +from unsloth import FastLanguageModel +from unsloth.utils import attention_dispatch as attention_dispatch_utils +from unsloth.utils.packing import ( + configure_sample_packing, + enable_sample_packing, + mask_packed_sequence_boundaries, +) + +from collections.abc import Iterable +from contextlib import ExitStack +from types import SimpleNamespace +from unittest.mock import patch + +import pytest +import torch +from datasets import Dataset +from trl import SFTConfig, SFTTrainer +from trl.trainer.sft_trainer import DataCollatorForLanguageModeling + + +def _build_packed_training_setup(tmp_path, device): + dtype = None + if device.type == "cuda": + if torch.cuda.is_bf16_supported(): + dtype = torch.bfloat16 + else: + dtype = torch.float16 + + try: + model, tokenizer = FastLanguageModel.from_pretrained( + model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM", + max_seq_length = 64, + load_in_4bit = False, + dtype = dtype, + ) + except OSError as exc: # pragma: no cover - offline CI + pytest.skip(f"Requires access to tiny llama checkpoint: {exc}") + + model.to(device) + + dataset = Dataset.from_dict( + { + "text": [ + "Hello world!", + "Short sample.", + "This is a slightly longer packed example to test batching.", + "Another response to include in the batch.", + ] + } + ) + + training_args = SFTConfig( + per_device_train_batch_size = 1, + per_device_eval_batch_size = 1, + gradient_accumulation_steps = 1, + dataset_text_field = "text", + max_length = 64, + logging_steps = 1, + max_steps = 1, + fp16 = device.type == "cuda" and not torch.cuda.is_bf16_supported(), + bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported(), + dataset_num_proc = 1, + output_dir = str(tmp_path), + ) + configure_sample_packing(training_args) + + trainer = SFTTrainer( + model = model, + processing_class = tokenizer, + train_dataset = dataset, + args = training_args, + ) + + enable_sample_packing(model, trainer) + + dataloader = trainer.get_train_dataloader() + batch = next(iter(dataloader)) + + model_device = next(model.parameters()).device + + for key, value in list(batch.items()): + if torch.is_tensor(value): + batch[key] = value.to(model_device) + + from unsloth.models import llama as llama_mod + + return model, batch, trainer, llama_mod + + +def _trim_batch_to_total_tokens(data, total_tokens): + def _trim_tensor(t: torch.Tensor): + if t.ndim >= 2 and t.size(1) > total_tokens: + return t[:, :total_tokens].contiguous() + return t + + trimmed = {} + for key, value in data.items(): + if torch.is_tensor(value): + trimmed[key] = _trim_tensor(value) + else: + 
trimmed[key] = value + return trimmed + + +def test_mask_packed_sequence_boundaries_marks_single_row(): + shift_labels = torch.arange(6, dtype = torch.long).view(1, 6) + changed = mask_packed_sequence_boundaries( + shift_labels, + torch.tensor([2, 1, 3], dtype = torch.int32), + ) + assert changed is True + flat = shift_labels.view(-1) + assert flat[1].item() == -100 + assert flat[2].item() == -100 + assert flat[5].item() == -100 + assert flat[0].item() != -100 + + +def test_mask_packed_sequence_boundaries_across_multiple_rows(): + shift_labels = torch.arange(10, dtype = torch.long).view(2, 5) + lengths = torch.tensor([3, 2, 4, 1], dtype = torch.int32) + changed = mask_packed_sequence_boundaries(shift_labels, lengths) + assert changed is True + flat = shift_labels.view(-1) + for idx in (2, 4, 8, 9): + assert flat[idx].item() == -100 + assert torch.any(flat != -100) + + +def test_configure_sample_packing(): + config = SimpleNamespace() + configure_sample_packing(config) + + assert config.packing is True + assert config.padding_free is True + assert config.remove_unused_columns is False + + +class _DummyChild(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_seq_length = 8 + + +class _DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_seq_length = 16 + self.child = _DummyChild() + self.config = SimpleNamespace(_attn_implementation = "sdpa") + self.generation_config = SimpleNamespace(attn_implementation = "sdpa") + + +class _DummyTrainer: + def __init__(self): + self.args = SimpleNamespace(remove_unused_columns = True) + self.data_collator = DataCollatorForLanguageModeling( + pad_token_id = 0, + completion_only_loss = False, + padding_free = True, + return_position_ids = False, + return_tensors = "pt", + ) + + +def test_enable_sample_packing(): + model = _DummyModel() + trainer = _DummyTrainer() + + enable_sample_packing(model, trainer) + + # model hierarchy should now allow packed overlength inputs + assert getattr(model, "_unsloth_allow_packed_overlength") is True + assert getattr(model.child, "_unsloth_allow_packed_overlength") is True + + collator = trainer.data_collator + assert collator.return_position_ids is True + assert getattr(collator, "_unsloth_packing_wrapped") is True + + examples = [ + { + "input_ids": [0, 1, 2], + "labels": [0, 1, 2], + "seq_lengths": [2, 1], + }, + { + "input_ids": [3, 4, 5], + "labels": [3, 4, 5], + "seq_lengths": [3], + }, + ] + batch = collator.torch_call(examples) + + # packed lengths are aggregated into a single tensor + assert "packed_seq_lengths" in batch + assert torch.equal( + batch["packed_seq_lengths"], + torch.tensor([2, 1, 3], dtype = torch.int32), + ) + + assert batch["input_ids"].shape == (1, 6) + expected_positions = torch.tensor([0, 1, 0, 0, 1, 2], dtype = torch.long) + assert torch.equal(batch["position_ids"].view(-1)[:6], expected_positions) + + +def test_enable_sample_packing_trl_collator(tmp_path): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model, _, trainer, _ = _build_packed_training_setup(tmp_path, device) + + enable_sample_packing(model, trainer) + + examples = [ + { + "input_ids": [0, 1, 2], + "labels": [0, 1, 2], + "seq_lengths": [2, 1], + }, + { + "input_ids": [3, 4, 5], + "labels": [3, 4, 5], + "seq_lengths": [3], + }, + ] + + batch = trainer.data_collator.torch_call(examples) + + assert batch["input_ids"].shape == (1, 6) + assert torch.equal( + batch["packed_seq_lengths"], + torch.tensor([2, 1, 3], dtype = torch.int32), + ) + + 
expected_positions = torch.tensor([0, 1, 0, 0, 1, 2], dtype = torch.long) + assert torch.equal(batch["position_ids"].view(-1)[:6], expected_positions) + + if hasattr(trainer, "accelerator"): + trainer.accelerator.free_memory() + + +def test_packing_sdpa(tmp_path): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model, batch, trainer, llama_mod = _build_packed_training_setup(tmp_path, device) + + assert "packed_seq_lengths" in batch + assert "attention_mask" not in batch + assert batch["packed_seq_lengths"].dtype == torch.int32 + + total_tokens = batch["input_ids"].size(-1) + assert int(batch["packed_seq_lengths"].sum().item()) == total_tokens + + packed_tokens = int(batch["packed_seq_lengths"].sum().item()) + assert "position_ids" in batch + flat_positions = batch["position_ids"].reshape(-1)[:packed_tokens] + expected_positions = torch.cat( + [ + torch.arange(length, dtype = torch.long) + for length in batch["packed_seq_lengths"].tolist() + ] + ) + assert torch.equal(flat_positions.cpu(), expected_positions) + inputs = _trim_batch_to_total_tokens(batch, packed_tokens) + + seq_info = llama_mod.get_packed_info_from_kwargs( + {"packed_seq_lengths": batch["packed_seq_lengths"]}, + inputs["input_ids"].shape[0] * inputs["input_ids"].shape[1], + inputs["input_ids"].device, + ) + assert seq_info is not None + + original_mask = attention_dispatch_utils.build_sdpa_packed_attention_mask + mask_calls = [] + captured_loss_labels = {} + + def _capture_mask(seq_info, dtype, device, *, sliding_window = None): + mask_calls.append(tuple(seq_info[0].tolist())) + return original_mask( + seq_info, + dtype = dtype, + device = device, + sliding_window = sliding_window, + ) + + def _capture_loss(*, logits, labels, **loss_kwargs): + captured_loss_labels["labels"] = labels.detach().to("cpu") + return torch.zeros((), device = logits.device, dtype = logits.dtype) + + with ExitStack() as stack: + stack.enter_context( + patch.object(attention_dispatch_utils, "HAS_FLASH_ATTENTION", False) + ) + stack.enter_context( + patch.object(attention_dispatch_utils, "HAS_XFORMERS", False) + ) + stack.enter_context( + patch.object( + attention_dispatch_utils, + "build_sdpa_packed_attention_mask", + side_effect = _capture_mask, + ) + ) + stack.enter_context( + patch.object( + llama_mod, + "fast_cross_entropy_loss", + side_effect = _capture_loss, + ) + ) + with torch.no_grad(): + outputs = model(**inputs) + + assert mask_calls, "SDPA packed mask was not constructed" + assert outputs.loss is not None + assert "labels" in captured_loss_labels + flat_loss_labels = captured_loss_labels["labels"].reshape(-1) + boundaries = ( + torch.cumsum( + batch["packed_seq_lengths"].to(device = "cpu", dtype = torch.long), dim = 0 + ) + - 1 + ) + for idx in boundaries.tolist(): + assert flat_loss_labels[idx].item() == -100 + assert torch.any(flat_loss_labels != -100) + + if hasattr(trainer, "accelerator"): + trainer.accelerator.free_memory() diff --git a/unsloth-cli.py b/unsloth-cli.py index fb6e39266..358e61fae 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -34,29 +34,30 @@ def run(args): - import torch from unsloth import FastLanguageModel from datasets import load_dataset from transformers.utils import strtobool from trl import SFTTrainer, SFTConfig - from transformers import TrainingArguments from unsloth import is_bfloat16_supported + from unsloth.utils import configure_sample_packing, enable_sample_packing + from unsloth.models.loader_utils import prepare_device_map import logging 
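# A minimal sketch of the sample-packing flow this CLI wires up, reusing the tiny
# test checkpoint from tests/utils/test_packing.py; the dataset contents and
# hyperparameters here are illustrative placeholders, not part of this change.
# configure_sample_packing() flips the packing / padding-free flags on the SFTConfig
# before the trainer is built, and enable_sample_packing() then patches the trainer's
# collator so packed sequence lengths reach the model.
from datasets import Dataset
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
from unsloth.utils import configure_sample_packing, enable_sample_packing

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM",
    max_seq_length = 64,
    load_in_4bit = False,
)
dataset = Dataset.from_dict({"text": ["Hello world!", "Another short sample."]})

training_args = SFTConfig(
    per_device_train_batch_size = 1,
    max_steps = 1,
    dataset_text_field = "text",
    max_length = 64,
    output_dir = "outputs",
)
configure_sample_packing(training_args)  # sets packing = True, padding_free = True

trainer = SFTTrainer(
    model = model,
    processing_class = tokenizer,
    train_dataset = dataset,
    args = training_args,
)
enable_sample_packing(model, trainer)  # wraps the collator so packed lengths reach the model
trainer.train()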
logging.getLogger("hf-to-gguf").setLevel(logging.WARNING) # Load model and tokenizer + device_map, distributed = prepare_device_map() model, tokenizer = FastLanguageModel.from_pretrained( model_name = args.model_name, max_seq_length = args.max_seq_length, dtype = args.dtype, load_in_4bit = args.load_in_4bit, + device_map = device_map, ) # Configure PEFT model model = FastLanguageModel.get_peft_model( model, - r = args.r, target_modules = [ "q_proj", "k_proj", @@ -112,6 +113,7 @@ def formatting_prompts_func(examples): # Configure training arguments training_args = SFTConfig( per_device_train_batch_size = args.per_device_train_batch_size, + per_device_eval_batch_size = args.per_device_eval_batch_size, gradient_accumulation_steps = args.gradient_accumulation_steps, warmup_steps = args.warmup_steps, max_steps = args.max_steps, @@ -127,9 +129,15 @@ def formatting_prompts_func(examples): report_to = args.report_to, max_length = args.max_seq_length, dataset_num_proc = 2, - packing = False, + ddp_find_unused_parameters = False if distributed else None, ) + if args.sample_packing: + print( + f"Unsloth Packing: Sample packing enabled (max_length={args.max_seq_length})." + ) + configure_sample_packing(training_args) + # Initialize trainer trainer = SFTTrainer( model = model, @@ -138,8 +146,10 @@ def formatting_prompts_func(examples): args = training_args, ) - # Train model - trainer_stats = trainer.train() + if args.sample_packing: + enable_sample_packing(model, trainer) + + trainer.train() # Save model if args.save_model: @@ -164,13 +174,15 @@ def formatting_prompts_func(examples): else: print(f"Saving model with quantization method: {args.quantization}") model.save_pretrained_gguf( - args.save_path, tokenizer, quantization_method = args.quantization + args.save_path, + tokenizer, + quantization_method = args.quantization, ) if args.push_model: model.push_to_hub_gguf( hub_path = args.hub_path, hub_token = args.hub_token, - quantization_method = quantization_method, + quantization_method = args.quantization, ) else: model.save_pretrained_merged(args.save_path, tokenizer, args.save_method) @@ -181,7 +193,6 @@ def formatting_prompts_func(examples): if __name__ == "__main__": - # Define argument parser parser = argparse.ArgumentParser( description = "🦥 Fine-tune your llm faster using unsloth!" ) @@ -218,7 +229,8 @@ def formatting_prompts_func(examples): ) lora_group = parser.add_argument_group( - "🧠 LoRA Options", "These options are used to configure the LoRA model." 
+ "🧠 LoRA Options", + "These options are used to configure the LoRA model.", ) lora_group.add_argument( "--r", @@ -239,7 +251,10 @@ def formatting_prompts_func(examples): help = "LoRA dropout rate, default is 0.0 which is optimized.", ) lora_group.add_argument( - "--bias", type = str, default = "none", help = "Bias setting for LoRA" + "--bias", + type = str, + default = "none", + help = "Bias setting for LoRA", ) lora_group.add_argument( "--use_gradient_checkpointing", @@ -254,10 +269,15 @@ def formatting_prompts_func(examples): help = "Random state for reproducibility, default is 3407.", ) lora_group.add_argument( - "--use_rslora", action = "store_true", help = "Use rank stabilized LoRA" + "--use_rslora", + action = "store_true", + help = "Use rank stabilized LoRA", ) lora_group.add_argument( - "--loftq_config", type = str, default = None, help = "Configuration for LoftQ" + "--loftq_config", + type = str, + default = None, + help = "Configuration for LoftQ", ) training_group = parser.add_argument_group("🎓 Training Options") @@ -267,6 +287,12 @@ def formatting_prompts_func(examples): default = 2, help = "Batch size per device during training, default is 2.", ) + training_group.add_argument( + "--per_device_eval_batch_size", + type = int, + default = 4, + help = "Batch size per device during evaluation, default is 4.", + ) training_group.add_argument( "--gradient_accumulation_steps", type = int, @@ -280,7 +306,10 @@ def formatting_prompts_func(examples): help = "Number of warmup steps, default is 5.", ) training_group.add_argument( - "--max_steps", type = int, default = 400, help = "Maximum number of training steps." + "--max_steps", + type = int, + default = 400, + help = "Maximum number of training steps.", ) training_group.add_argument( "--learning_rate", @@ -289,7 +318,10 @@ def formatting_prompts_func(examples): help = "Learning rate, default is 2e-4.", ) training_group.add_argument( - "--optim", type = str, default = "adamw_8bit", help = "Optimizer type." + "--optim", + type = str, + default = "adamw_8bit", + help = "Optimizer type.", ) training_group.add_argument( "--weight_decay", @@ -309,8 +341,12 @@ def formatting_prompts_func(examples): default = 3407, help = "Seed for reproducibility, default is 3407.", ) + training_group.add_argument( + "--sample_packing", + action = "store_true", + help = "Enable cross-example packing using TRL's dataset bin packing.", + ) - # Report/Logging arguments report_group = parser.add_argument_group("📊 Report Options") report_group.add_argument( "--report_to", @@ -331,19 +367,31 @@ def formatting_prompts_func(examples): "all", "none", ], - help = "The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.", + help = ( + "The list of integrations to report the results and logs to. Supported platforms are:\n\t\t " + "'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', " + "'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations " + "installed, 'none' for no integrations." 
+ ), ) report_group.add_argument( - "--logging_steps", type = int, default = 1, help = "Logging steps, default is 1" + "--logging_steps", + type = int, + default = 1, + help = "Logging steps, default is 1", ) - # Saving and pushing arguments save_group = parser.add_argument_group("💾 Save Model Options") save_group.add_argument( - "--output_dir", type = str, default = "outputs", help = "Output directory" + "--output_dir", + type = str, + default = "outputs", + help = "Output directory", ) save_group.add_argument( - "--save_model", action = "store_true", help = "Save the model after training" + "--save_model", + action = "store_true", + help = "Save the model after training", ) save_group.add_argument( "--save_method", @@ -358,14 +406,20 @@ def formatting_prompts_func(examples): help = "Convert the model to GGUF after training", ) save_group.add_argument( - "--save_path", type = str, default = "model", help = "Path to save the model" + "--save_path", + type = str, + default = "model", + help = "Path to save the model", ) save_group.add_argument( "--quantization", type = str, default = "q8_0", nargs = "+", - help = "Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ", + help = ( + "Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), " + "Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf" + ), ) push_group = parser.add_argument_group("🚀 Push Model Options") @@ -386,7 +440,9 @@ def formatting_prompts_func(examples): help = "Path on Hugging Face hub to push the model", ) push_group.add_argument( - "--hub_token", type = str, help = "Token for pushing the model to Hugging Face hub" + "--hub_token", + type = str, + help = "Token for pushing the model to Hugging Face hub", ) args = parser.parse_args() diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py index 8eecec10c..a33f94016 100644 --- a/unsloth/kernels/rope_embedding.py +++ b/unsloth/kernels/rope_embedding.py @@ -15,8 +15,96 @@ import triton import triton.language as tl import torch +from ..device_type import DEVICE_COUNT from .utils import calculate_settings, torch_gpu_device, torch_device_stream + +@triton.heuristics( + { + "BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]), + "HAS_ROPE_INDICES": lambda args: bool(args["HAS_ROPE_INDICES"]), + } +) +@triton.jit +def _rope_embedding_QK( + Q, + Q_batch_stride, + Q_head_stride, + Q_seq_stride, + K, + K_batch_stride, + K_head_stride, + K_seq_stride, + cos, + cos_row_stride, + sin, + sin_row_stride, + rope_embedding_indices, + seqlen, + head_dim: tl.constexpr, + n_heads_K: tl.constexpr, + BACKWARD_PASS: tl.constexpr, + HAS_ROPE_INDICES: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row_position = tl.program_id(0) + head_position = tl.program_id(1) + col_offsets = tl.arange(0, BLOCK_SIZE) + half_head_dim = head_dim // 2 + mask = col_offsets < half_head_dim + + if HAS_ROPE_INDICES: + rot_position = tl.load( + rope_embedding_indices + row_position, + eviction_policy = "evict_first", + ).to(tl.int32) + else: + rot_position = row_position % seqlen + + cos_ptr = cos + rot_position * cos_row_stride + sin_ptr = sin + rot_position * sin_row_stride + sin1 = tl.load( + sin_ptr + col_offsets, + mask = mask, + other = 0, + eviction_policy = "evict_first", + ) + cos1 = tl.load( + cos_ptr + col_offsets, + mask = mask, + other = 
0, + eviction_policy = "evict_first", + ) + if BACKWARD_PASS: + sin1 = -sin1 + + batch_id = row_position // seqlen + seq_index = row_position - batch_id * seqlen + + q_ptr = ( + Q + + batch_id * Q_batch_stride + + head_position * Q_head_stride + + seq_index * Q_seq_stride + ) + q0 = tl.load(q_ptr + col_offsets, mask = mask, other = 0) + q1 = tl.load(q_ptr + half_head_dim + col_offsets, mask = mask, other = 0) + tl.store(q_ptr + col_offsets, q0 * cos1 - q1 * sin1, mask = mask) + tl.store(q_ptr + half_head_dim + col_offsets, q1 * cos1 + q0 * sin1, mask = mask) + + if head_position < n_heads_K: + k_ptr = ( + K + + batch_id * K_batch_stride + + head_position * K_head_stride + + seq_index * K_seq_stride + ) + k0 = tl.load(k_ptr + col_offsets, mask = mask, other = 0) + k1 = tl.load(k_ptr + half_head_dim + col_offsets, mask = mask, other = 0) + tl.store(k_ptr + col_offsets, k0 * cos1 - k1 * sin1, mask = mask) + tl.store(k_ptr + half_head_dim + col_offsets, k1 * cos1 + k0 * sin1, mask = mask) + + ROPE_GROUP_SIZE: int = 4 @@ -102,7 +190,7 @@ def forward(ctx, Q, cos, sin): n_heads: int head_dim: int batch, seq_len, n_heads, head_dim = Q.shape - Q = Q.view(batch * seq_len, n_heads * head_dim) + Q = Q.reshape(batch * seq_len, n_heads * head_dim) n_rows: int n_cols: int n_rows, n_cols = Q.shape @@ -143,7 +231,7 @@ def forward(ctx, Q, cos, sin): ctx.n_groups = n_groups ctx.cos = cos ctx.sin = sin - return Q.view(batch, seq_len, n_heads, head_dim) + return Q.reshape(batch, seq_len, n_heads, head_dim) @staticmethod def backward(ctx, dY): @@ -153,7 +241,6 @@ def backward(ctx, dY): head_dim: int batch, seq_len, n_heads, head_dim = dY.shape dY = dY.reshape(batch * seq_len, n_heads * head_dim) - # Must be reshape not view n_rows: int n_cols: int n_rows, n_cols = dY.shape @@ -181,7 +268,7 @@ def backward(ctx, dY): BLOCK_SIZE = ctx.BLOCK_SIZE, num_warps = ctx.num_warps, ) - dY = dY.view(batch, seq_len, n_heads, head_dim) + dY = dY.reshape(batch, seq_len, n_heads, head_dim) return ( dY, None, @@ -191,12 +278,142 @@ def backward(ctx, dY): # [TODO] Unsure why RoPE Embedding is not torch.compiling properly @torch.compiler.disable -def fast_rope_embedding(Q, K, cos, sin): - Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2) - K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2) - # synchronize before cat to avoid race condition - torch_device_stream(Q.device).synchronize() - return Q, K +def fast_rope_embedding(Q, K, cos, sin, rope_embedding_indices = None): + if rope_embedding_indices is not None: + Q_out, K_out = Fast_RoPE_Embedding_QK.apply( + Q, K, cos, sin, rope_embedding_indices + ) + else: + Q_out = Fast_RoPE_Embedding.apply( + Q.transpose(1, 2).contiguous(), cos, sin + ).transpose(1, 2) + K_out = Fast_RoPE_Embedding.apply( + K.transpose(1, 2).contiguous(), cos, sin + ).transpose(1, 2) + if DEVICE_COUNT > 1: + torch_device_stream(Q.device).synchronize() + return Q_out, K_out + + +class Fast_RoPE_Embedding_QK(torch.autograd.Function): + @staticmethod + def forward(ctx, Q, K, cos, sin, rope_indices): + has_indices = rope_indices is not None + cos, sin = cos.squeeze(), sin.squeeze() + + batch, n_heads_Q, seq_len, head_dim = Q.shape + _, n_heads_K, _, _ = K.shape + + Q_out = Q.clone() + K_out = K.clone() + + if has_indices: + rope_ptr = rope_indices.reshape(-1).to(dtype = torch.int32, device = Q.device) + else: + rope_ptr = cos.new_empty(1, dtype = torch.int32) + + BLOCK_SIZE, num_warps = calculate_settings(head_dim) + + Q_batch_stride, Q_head_stride, Q_seq_stride = ( 
+ Q_out.stride(0), + Q_out.stride(1), + Q_out.stride(2), + ) + K_batch_stride, K_head_stride, K_seq_stride = ( + K_out.stride(0), + K_out.stride(1), + K_out.stride(2), + ) + + with torch_gpu_device(Q.device): + _rope_embedding_QK[(batch * seq_len, n_heads_Q)]( + Q_out, + Q_batch_stride, + Q_head_stride, + Q_seq_stride, + K_out, + K_batch_stride, + K_head_stride, + K_seq_stride, + cos, + cos.stride(0), + sin, + sin.stride(0), + rope_ptr, + seq_len, + head_dim = head_dim, + n_heads_K = n_heads_K, + BACKWARD_PASS = False, + HAS_ROPE_INDICES = has_indices, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) + + ctx.block_size = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.has_indices = has_indices + ctx.cos = cos + ctx.sin = sin + ctx.rope_indices = rope_ptr if has_indices else None + ctx.seq_len = seq_len + ctx.n_heads_Q = n_heads_Q + ctx.n_heads_K = n_heads_K + + return ( + Q_out, + K_out, + ) + + @staticmethod + def backward(ctx, dQ, dK): + batch, n_heads_Q, seq_len, head_dim = dQ.shape + _, n_heads_K, _, _ = dK.shape + + dQ_out = dQ.clone() + dK_out = dK.clone() + + rope_ptr = ( + ctx.rope_indices + if ctx.has_indices + else ctx.cos.new_empty(1, dtype = torch.int32) + ) + + Q_batch_stride, Q_head_stride, Q_seq_stride = ( + dQ_out.stride(0), + dQ_out.stride(1), + dQ_out.stride(2), + ) + K_batch_stride, K_head_stride, K_seq_stride = ( + dK_out.stride(0), + dK_out.stride(1), + dK_out.stride(2), + ) + + with torch_gpu_device(dQ.device): + _rope_embedding_QK[(batch * ctx.seq_len, ctx.n_heads_Q)]( + dQ_out, + Q_batch_stride, + Q_head_stride, + Q_seq_stride, + dK_out, + K_batch_stride, + K_head_stride, + K_seq_stride, + ctx.cos, + ctx.cos.stride(0), + ctx.sin, + ctx.sin.stride(0), + rope_ptr, + ctx.seq_len, + head_dim = head_dim, + n_heads_K = ctx.n_heads_K, + BACKWARD_PASS = True, + HAS_ROPE_INDICES = ctx.has_indices, + BLOCK_SIZE = ctx.block_size, + num_warps = ctx.num_warps, + ) + + return (dQ_out, dK_out, None, None, None) class Slow_RoPE_Embedding(torch.autograd.Function): @@ -206,8 +423,8 @@ def forward(ctx, Q, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
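# A plain-PyTorch reference for the fused kernel above (an assumed sketch for
# checking outputs, not the implementation): each half of head_dim is rotated with
# the cos/sin row picked either by sequence position or, when rope indices /
# position ids are supplied, by a per-token index. Only the first half of each
# cached cos/sin row is used, matching the kernel's half_head_dim loads.
import torch

def rope_reference(Q, K, cos, sin, position_ids = None):
    # Q: (bsz, n_heads_q, seq_len, head_dim), K: (bsz, n_heads_k, seq_len, head_dim)
    bsz, _, seq_len, head_dim = Q.shape
    half = head_dim // 2
    cos, sin = cos.squeeze(), sin.squeeze()
    cos, sin = cos[..., :half], sin[..., :half]
    if position_ids is None:
        position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
    c = cos[position_ids].unsqueeze(1)  # (bsz, 1, seq_len, half)
    s = sin[position_ids].unsqueeze(1)

    def rotate(x):
        x0, x1 = x[..., :half], x[..., half:]
        return torch.cat((x0 * c - x1 * s, x1 * c + x0 * s), dim = -1)

    return rotate(Q), rotate(K)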
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] # Q * cos + rotate_half(Q) * sin half = Q.shape[-1] // 2 diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 54188305a..1d8858bd8 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -16,6 +16,13 @@ from ._utils import __version__ from unsloth_zoo.hf_utils import dtype_from_config from unsloth_zoo.utils import _get_dtype +from ..utils.packing import get_packed_info_from_kwargs +from ..utils.attention_dispatch import ( + AttentionConfig, + AttentionContext, + run_attention, + select_attention_backend, +) try: from transformers.models.cohere.modeling_cohere import ( @@ -104,6 +111,7 @@ def CohereAttention_fast_forward( Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, Q.device) if self.use_qk_norm: Q = fast_layernorm_compiled(self.q_norm, Q) K = fast_layernorm_compiled(self.k_norm, K) @@ -112,12 +120,17 @@ def CohereAttention_fast_forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = position_embeddings - if position_ids is None: - Q, K = fast_rope_embedding(Q, K, cos, sin) + # Extend RoPE dynamically to fit in VRAM + if position_embeddings: + cos, sin = position_embeddings else: - cos, sin = cos[position_ids], sin[position_ids] - Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index) + + rope_position_ids = ( + position_ids if position_ids is not None else kwargs.get("position_ids") + ) + # Useful for LongRoPE + Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -125,54 +138,33 @@ def CohereAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None: - # Xformers memory efficient attention - # Also has Flash Attention v2 dispatching - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - - # Group query attention - if n_groups != 1: - K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - if hidden_states.requires_grad: - K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) - V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) - else: - Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) - A = xformers_attention(Q, K, V, attn_bias = causal_mask) - A = A.view(bsz, q_len, n_heads, head_dim) - - elif HAS_FLASH_ATTENTION and attention_mask is None: - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - A = flash_attn_func(Q, K, V, causal = True) - else: - # Grouped query attention - if n_groups != 1: - K = K[:, :, None, :, :].expand( - bsz, n_kv_heads, n_groups, kv_seq_len, head_dim - ) - V = V[:, :, None, :, :].expand( - bsz, n_kv_heads, n_groups, kv_seq_len, head_dim - ) - K = K.reshape(bsz, n_heads, 
kv_seq_len, head_dim) - V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) - pass - # Must be contiguous or else results are False! - # https://github.com/pytorch/pytorch/issues/112577 - Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() - # Needs (batch_size, n_heads, seq_len, head_dim) - # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention( - Q, K, V, attn_mask = attention_mask, is_causal = False - ) - # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2).contiguous() + use_varlen = seq_info is not None and past_key_value is None + backend = select_attention_backend(use_varlen) + attention_config = AttentionConfig( + backend = backend, + n_kv_heads = n_kv_heads, + n_groups = n_groups, + flash_dense_kwargs = {"causal": True}, + flash_varlen_kwargs = { + "dropout_p": 0.0, + "causal": True, + "softmax_scale": getattr(self, "softmax_scale", None), + }, + ) + context = AttentionContext( + bsz = bsz, + q_len = q_len, + kv_seq_len = kv_seq_len, + n_heads = n_heads, + head_dim = head_dim, + requires_grad = hidden_states.requires_grad, + seq_info = seq_info, + attention_mask = attention_mask, + causal_mask = causal_mask, + ) + + A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V) + attn_output = A.reshape(bsz, q_len, n_heads * head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None @@ -215,6 +207,7 @@ def CohereDecoderLayer_fast_forward( output_attentions = output_attentions, use_cache = use_cache, padding_mask = padding_mask, + **kwargs, ) # Fully Connected @@ -234,6 +227,7 @@ def CohereDecoderLayer_fast_forward( output_attentions = output_attentions, use_cache = use_cache, padding_mask = padding_mask, + **kwargs, ) # Fully Connected diff --git a/unsloth/models/falcon_h1.py b/unsloth/models/falcon_h1.py index 70223a3c2..384edd454 100644 --- a/unsloth/models/falcon_h1.py +++ b/unsloth/models/falcon_h1.py @@ -17,6 +17,14 @@ from ._utils import __version__ from unsloth_zoo.utils import Version, _get_dtype from unsloth_zoo.hf_utils import dtype_from_config +from ..utils.packing import get_packed_info_from_kwargs +from ..utils.attention_dispatch import ( + AttentionConfig, + AttentionContext, + run_attention, + select_attention_backend, + SDPA, +) from .llama import ( LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, @@ -98,13 +106,10 @@ def FalconH1Attention_fast_forward( assert n_kv_heads * n_groups == n_heads Q, K, V = self.apply_qkv(self, hidden_states) - Q = Q.view( - bsz, q_len, n_heads, head_dim - ) # .transpose(1, 2) # we will transpose after normalisation - K = K.view( - bsz, q_len, n_kv_heads, head_dim - ) # .transpose(1, 2) # we will transpose after normalisation + Q = Q.view(bsz, q_len, n_heads, head_dim) + K = K.view(bsz, q_len, n_kv_heads, head_dim) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, hidden_states.device) # Falcon H1 multiplies key states by a multiplier K = K * self.config.key_multiplier @@ -116,20 +121,19 @@ def FalconH1Attention_fast_forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] + # Extend RoPE dynamically to fit in VRAM if position_embeddings: cos, sin = position_embeddings else: - # Extend RoPE dynamically to fit in VRA rotary_emb = self.rotary_emb rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) - device_index = Q.device.index + cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index) - if position_ids is 
None: - # Useful for LongRoPE - cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) - else: - cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) - Q, K = fast_rope_embedding(Q, K, cos, sin) + rope_position_ids = ( + position_ids if position_ids is not None else kwargs.get("position_ids") + ) + # Useful for LongRoPE + Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -137,54 +141,45 @@ def FalconH1Attention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if not HAS_FLASH_ATTENTION and attention_mask is None: - # Xformers memory efficient attention - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - K_M = V_M = bsz * kv_seq_len - Q_M = bsz * q_len - - # Group query attention - K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - if hidden_states.requires_grad: - K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) - V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) - else: - # Xformers does support the forward pass though - Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) - - A = xformers_attention(Q, K, V, attn_bias = causal_mask) - A = A.view(bsz, q_len, n_heads, head_dim) - - elif HAS_FLASH_ATTENTION and attention_mask is None: - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - sw = kv_seq_len - window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw) - A = flash_attn_func(Q, K, V, causal = True, window_size = window) - else: - # Grouped query attention - # if n_groups != 1: - K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) - V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) - # pass - # Must be contiguous or else results are False! - # https://github.com/pytorch/pytorch/issues/112577 - Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() - # Needs (batch_size, n_heads, seq_len, head_dim) - # is_casual and attention_mask must not be both set! 
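# The xformers / SDPA branches removed here all normalise grouped-query K/V to the
# full query-head count before attention. A small reference of that expand step
# (assumed to be handled inside run_attention after this change):
import torch
from torch.nn.functional import scaled_dot_product_attention

def expand_kv_for_gqa(K, V, n_groups):
    # (bsz, n_kv_heads, kv_seq_len, head_dim) -> (bsz, n_kv_heads * n_groups, kv_seq_len, head_dim)
    if n_groups == 1:
        return K, V
    bsz, n_kv_heads, kv_seq_len, head_dim = K.shape
    K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
    V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
    K = K.reshape(bsz, n_kv_heads * n_groups, kv_seq_len, head_dim)
    V = V.reshape(bsz, n_kv_heads * n_groups, kv_seq_len, head_dim)
    return K, V

Q = torch.randn(1, 8, 16, 64)  # 8 query heads
K = torch.randn(1, 2, 16, 64)  # 2 KV heads -> n_groups = 4
V = torch.randn_like(K)
K, V = expand_kv_for_gqa(K, V, n_groups = 4)
A = scaled_dot_product_attention(Q, K.contiguous(), V.contiguous(), is_causal = True)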
- A = scaled_dot_product_attention( - Q, K, V, attn_mask = attention_mask, is_causal = False - ) - # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2).contiguous() + window = (-1, -1) + use_varlen = ( + attention_mask is None + and seq_info is not None + and past_key_value is None + and window == (-1, -1) + ) + + backend = ( + SDPA if attention_mask is not None else select_attention_backend(use_varlen) + ) + attention_config = AttentionConfig( + backend = backend, + n_kv_heads = n_kv_heads, + n_groups = n_groups, + flash_dense_kwargs = { + "causal": True, + "window_size": (kv_seq_len, kv_seq_len), + }, + flash_varlen_kwargs = { + "dropout_p": 0.0, + "softmax_scale": None, + "causal": True, + }, + sdpa_kwargs = {} if attention_mask is None else {"attn_mask": attention_mask}, + ) + context = AttentionContext( + bsz = bsz, + q_len = q_len, + kv_seq_len = kv_seq_len, + n_heads = n_heads, + head_dim = head_dim, + requires_grad = hidden_states.requires_grad, + seq_info = seq_info, + attention_mask = attention_mask, + causal_mask = causal_mask, + ) + + A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V) attn_output = A.reshape(bsz, q_len, n_heads * head_dim) attn_output = self.apply_o(self, attn_output) @@ -442,6 +437,7 @@ def FalconH1DecoderLayer_fast_forward( use_cache = use_cache, padding_mask = padding_mask, position_embeddings = position_embeddings, + **kwargs, ) attention_hidden_states = attention_hidden_states * self.attn_out_multiplier @@ -486,6 +482,7 @@ def FalconH1DecoderLayer_fast_forward( use_cache = use_cache, padding_mask = padding_mask, position_embeddings = position_embeddings, + **kwargs, ) attention_hidden_states = attention_hidden_states * self.attn_out_multiplier diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index cf5bdd8eb..bb92ef6e6 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -16,6 +16,11 @@ from ._utils import __version__ from unsloth_zoo.utils import _get_dtype from unsloth_zoo.hf_utils import dtype_from_config +from ..utils.packing import ( + build_sdpa_packed_attention_mask, + build_xformers_block_causal_mask, + get_packed_info_from_kwargs, +) import math try: @@ -110,6 +115,7 @@ def GemmaDecoderLayer_fast_forward( output_attentions = output_attentions, use_cache = use_cache, padding_mask = padding_mask, + **kwargs, ) hidden_states += residual @@ -134,6 +140,7 @@ def GemmaDecoderLayer_fast_forward( output_attentions = output_attentions, use_cache = use_cache, padding_mask = padding_mask, + **kwargs, ) hidden_states = residual + hidden_states diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index de70f7873..98928b122 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -16,6 +16,14 @@ from ._utils import __version__ from unsloth_zoo.utils import _get_dtype from unsloth_zoo.hf_utils import dtype_from_config +from ..utils.packing import get_packed_info_from_kwargs +from ..utils.attention_dispatch import ( + AttentionConfig, + AttentionContext, + run_attention, + select_attention_backend, + SDPA, +) from .gemma import ( GemmaFixedRotaryEmbedding, GemmaFixedLinearScalingRotaryEmbedding, @@ -98,19 +106,25 @@ def Gemma2Attention_fast_forward( Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, Q.device) kv_seq_len = K.shape[-2] if past_key_value is not 
None: kv_seq_len += past_key_value[0].shape[-2] device_index = Q.device.index - if position_ids is None: - cos = self.rotary_emb.multi_gpu_cos_cached[device_index] - sin = self.rotary_emb.multi_gpu_sin_cached[device_index] - Q, K = fast_rope_embedding(Q, K, cos, sin) + cos = self.rotary_emb.multi_gpu_cos_cached[device_index] + sin = self.rotary_emb.multi_gpu_sin_cached[device_index] + + rope_position_ids = ( + position_ids if position_ids is not None else kwargs.get("position_ids") + ) + if rope_position_ids is not None: + # Useful for LongRoPE + cos_var, sin_var = self.rotary_emb.get_cached(kv_seq_len, device_index) + Q, K = fast_rope_embedding(Q, K, cos_var, sin_var, rope_position_ids) else: - cos, sin = self.rotary_emb.get_cached(kv_seq_len, device_index) - Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) + Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -118,32 +132,68 @@ def Gemma2Attention_fast_forward( past_key_value = (K, V) if use_cache else None # Only enable if the attention_mask is True - has_sliding_window = type(causal_mask) is bool and causal_mask is True - if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None: + use_sliding_window = kwargs.get("use_sliding_window") + has_sliding_window = ( + use_sliding_window + if use_sliding_window is not None + else isinstance(causal_mask, bool) and causal_mask is True + ) + + use_flash = HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None + + if use_flash: window = (-1, -1) + sliding_window = getattr(self.config, "sliding_window", None) if has_sliding_window: - sw = getattr(self.config, "sliding_window", None) - sw = kv_seq_len if (sw is None or sw == "null") else sw - window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw) + sliding_window = ( + sliding_window if sliding_window is not None else kv_seq_len + ) + window = ( + (-1, -1) + if kv_seq_len <= sliding_window + else (sliding_window, sliding_window) + ) - # FA uses 1 / sqrt for softmax_scale! 
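# Gemma-2 alternates sliding-window and global layers, and sliding_window is now
# threaded into the packed-mask builders. A hedged sketch of the mask semantics the
# unit tests above assert (the real helper is build_sdpa_packed_attention_mask in
# unsloth.utils.packing): block-diagonal causal attention over the packed segments,
# with keys more than sliding_window - 1 positions behind the query also masked.
import torch

def packed_sliding_window_mask(lengths, sliding_window = None, dtype = torch.float32):
    total = int(sum(lengths))
    seq_ids = torch.repeat_interleave(torch.arange(len(lengths)), torch.tensor(lengths))
    pos = torch.arange(total)
    same_seq = seq_ids[:, None] == seq_ids[None, :]
    causal = pos[:, None] >= pos[None, :]
    allowed = same_seq & causal
    if sliding_window is not None:
        allowed &= (pos[:, None] - pos[None, :]) < sliding_window
    mask = torch.zeros(total, total, dtype = dtype)
    mask[~allowed] = float("-inf")
    return mask[None, None]  # (1, 1, total, total), broadcastable over heads

mask = packed_sliding_window_mask([5, 3], sliding_window = 3)
assert mask.shape == (1, 1, 8, 8)
assert mask[0, 0, 4, 1] == float("-inf")  # distance 3 >= window
assert mask[0, 0, 4, 2] == 0.0            # distance 2 < window
assert mask[0, 0, 0, 6] == float("-inf")  # cross-segment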
if not hasattr(self, "_flash_attention_softmax_scale"): self._flash_attention_softmax_scale = 1.0 / ( self.config.query_pre_attn_scalar**0.5 ) - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - A = flash_attn_func( - Q, - K, - V, - causal = True, - softcap = self.config.attn_logit_softcapping, - softmax_scale = self._flash_attention_softmax_scale, - window_size = window, + use_varlen = seq_info is not None and past_key_value is None + + attention_config = AttentionConfig( + backend = select_attention_backend(use_varlen), + n_kv_heads = n_kv_heads, + n_groups = n_groups, + flash_dense_kwargs = { + "causal": True, + "softcap": self.config.attn_logit_softcapping, + "softmax_scale": self._flash_attention_softmax_scale, + "window_size": window, + }, + flash_varlen_kwargs = { + "dropout_p": 0.0, + "softmax_scale": self._flash_attention_softmax_scale, + "causal": True, + "softcap": self.config.attn_logit_softcapping, + "window_size": window, + }, + ) + + context = AttentionContext( + bsz = bsz, + q_len = q_len, + kv_seq_len = kv_seq_len, + n_heads = n_heads, + head_dim = head_dim, + requires_grad = hidden_states.requires_grad, + seq_info = seq_info, + attention_mask = attention_mask, + causal_mask = causal_mask, + sliding_window = sliding_window, ) + + A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V) A = A.reshape(bsz, q_len, n_heads * head_dim) else: fx = ( @@ -192,6 +242,7 @@ def Gemma2DecoderLayer_fast_forward( use_cache = use_cache, padding_mask = padding_mask, _flag_for_generation = self._flag_for_generation, + **kwargs, ) hidden_states = fast_rms_layernorm_inference_gemma( self.post_attention_layernorm, hidden_states, out_weight @@ -222,6 +273,7 @@ def Gemma2DecoderLayer_fast_forward( output_attentions = output_attentions, use_cache = use_cache, padding_mask = padding_mask, + **kwargs, ) hidden_states = fast_rms_layernorm( self.post_attention_layernorm, hidden_states, gemma = True diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 0a816e399..e513927a6 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -17,6 +17,14 @@ from ._utils import __version__ from unsloth_zoo.utils import _get_dtype from unsloth_zoo.hf_utils import dtype_from_config +from ..utils.packing import get_packed_info_from_kwargs +from ..utils.attention_dispatch import ( + AttentionConfig, + AttentionContext, + run_attention, + select_attention_backend, + SDPA, +) from .llama import ( LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, @@ -96,6 +104,7 @@ def GraniteAttention_fast_forward( Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, Q.device) kv_seq_len = K.shape[-2] if past_key_value is not None: @@ -103,10 +112,14 @@ def GraniteAttention_fast_forward( assert position_embeddings is not None cos, sin = position_embeddings - if position_ids is None: - Q, K = fast_rope_embedding(Q, K, cos, sin) + rope_position_ids = ( + position_ids if position_ids is not None else kwargs.get("position_ids") + ) + if rope_position_ids is not None: + # Useful for LongRoPE + Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids) else: - Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) + Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -114,69 +127,59 @@ 
def GraniteAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None: - # Xformers memory efficient attention - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - K_M = V_M = bsz * kv_seq_len - Q_M = bsz * q_len - - # Group query attention - K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - if hidden_states.requires_grad: - K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) - V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) - else: - # Xformers does support the forward pass though - Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) - - A = xformers_attention( - Q, K, V, attn_bias = causal_mask, scale = self.scaling, p = dropout_p - ) - A = A.view(bsz, q_len, n_heads, head_dim) - - elif HAS_FLASH_ATTENTION and attention_mask is None: - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - window = (kv_seq_len, kv_seq_len) - A = flash_attn_func( - Q, - K, - V, - causal = True, - window_size = window, - softmax_scale = self.scaling, - dropout_p = dropout_p, - ) - else: - # Grouped query attention - # if n_groups != 1: - K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) - V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) - # pass - # Must be contiguous or else results are False! - # https://github.com/pytorch/pytorch/issues/112577 - Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() - # Needs (batch_size, n_heads, seq_len, head_dim) - # is_casual and attention_mask must not be both set! 
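# The varlen flash path selected by select_attention_backend consumes cumulative
# sequence lengths instead of a dense mask. A hedged sketch of the metadata that
# get_packed_info_from_kwargs is expected to derive from packed_seq_lengths,
# mirroring the _make_seq_info helper in the tests above:
import torch

def make_seq_info(packed_seq_lengths):
    lengths = torch.as_tensor(packed_seq_lengths, dtype = torch.int32)
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype = torch.int32),
            torch.cumsum(lengths, dim = 0, dtype = torch.int32),
        ]
    )
    max_seqlen = int(lengths.max().item())
    return lengths, cu_seqlens, max_seqlen

lengths, cu_seqlens, max_seqlen = make_seq_info([3, 2])
# cu_seqlens == tensor([0, 3, 5], dtype=torch.int32); flash_attn_varlen_func takes
# these as cu_seqlens_q / cu_seqlens_k together with max_seqlen_q / max_seqlen_k.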
- A = scaled_dot_product_attention( - Q, - K, - V, - attn_mask = attention_mask, - scale = self.scaling, - is_causal = False, - dropout_p = dropout_p, - ) - # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2).contiguous() + use_varlen = ( + attention_mask is None and seq_info is not None and past_key_value is None + ) + + backend = ( + SDPA if attention_mask is not None else select_attention_backend(use_varlen) + ) + + window = (kv_seq_len, kv_seq_len) + softmax_scale = getattr(self, "scaling", None) + attention_config = AttentionConfig( + backend = backend, + n_kv_heads = n_kv_heads, + n_groups = n_groups, + flash_dense_kwargs = { + "causal": True, + "softmax_scale": softmax_scale, + "dropout_p": dropout_p, + "window_size": window, + }, + flash_varlen_kwargs = { + "dropout_p": 0.0, + "softmax_scale": softmax_scale, + "causal": True, + }, + sdpa_kwargs = { + k: v + for k, v in { + "attn_mask": attention_mask, + "scale": softmax_scale, + "dropout_p": dropout_p, + }.items() + if v is not None + }, + xformers_kwargs = { + "scale": softmax_scale, + "p": dropout_p, + }, + ) + + context = AttentionContext( + bsz = bsz, + q_len = q_len, + kv_seq_len = kv_seq_len, + n_heads = n_heads, + head_dim = head_dim, + requires_grad = hidden_states.requires_grad, + seq_info = seq_info, + attention_mask = attention_mask, + causal_mask = causal_mask, + ) + + A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V) attn_output = A.reshape(bsz, q_len, n_heads * head_dim) attn_output = self.apply_o(self, attn_output) @@ -222,6 +225,7 @@ def GraniteDecoderLayer_fast_forward( padding_mask = padding_mask, position_embeddings = position_embeddings, _flag_for_generation = self._flag_for_generation, + **kwargs, ) hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) @@ -245,6 +249,7 @@ def GraniteDecoderLayer_fast_forward( use_cache = use_cache, padding_mask = padding_mask, position_embeddings = position_embeddings, + **kwargs, ) hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3af9cba66..5b54526f8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -16,12 +16,23 @@ import gc import math import functools -from typing import Any, Dict, Optional, Tuple, List, Union +from typing import Optional, Tuple, List, Union + from ._utils import * from ._utils import patch_unsloth_smart_gradient_checkpointing from ._utils import __version__, importlib_version from ._utils import move_to_device from ._utils import _prepare_model_for_qat +from ..utils.packing import ( + get_packed_info_from_kwargs, + mask_packed_sequence_boundaries, +) +from ..utils.attention_dispatch import ( + AttentionConfig, + AttentionContext, + run_attention, + select_attention_backend, +) from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version from unsloth_zoo.utils import Version, _get_dtype @@ -58,9 +69,6 @@ ) from ..kernels import * from ..tokenizer_utils import * - -if HAS_FLASH_ATTENTION: - from flash_attn import flash_attn_func from .vision import FastBaseModel # Final patching code @@ -561,6 +569,7 @@ def LlamaAttention_fast_forward( Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, Q.device) kv_seq_len = 
K.shape[-2] if past_key_value is not None: @@ -569,23 +578,20 @@ def LlamaAttention_fast_forward( if position_embeddings: cos, sin = position_embeddings else: - # Extend RoPE dynamically to fit in VRA rotary_emb = self.rotary_emb rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) - - # if position_ids is None: - # # Useful for LongRoPE - # cos, sin = rotary_emb.get_cached(kv_seq_len, device = Q.device) - # else: - # cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index) + rope_position_ids = position_ids + if rope_position_ids is None and seq_info is not None: + rope_position_ids = kwargs.get("position_ids") + # Q, K = ( # fast_rope_embedding(Q, K, cos, sin) - # if position_ids is None - # else inplace_rope_embedding(Q, K, cos, sin, position_ids) + # if rope_position_ids is None + # else inplace_rope_embedding(Q, K, cos, sin, rope_position_ids) # ) - Q, K = fast_rope_embedding(Q, K, cos, sin) + Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -593,76 +599,28 @@ def LlamaAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None: - # Xformers memory efficient attention - # Also has Flash Attention v2 dispatching - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - - # Group query attention - if n_groups != 1: - K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - if hidden_states.requires_grad: - K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) - V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) - else: - Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) - A = xformers_attention(Q, K, V, attn_bias = causal_mask) - A = A.view(bsz, q_len, n_heads, head_dim) - - elif HAS_FLASH_ATTENTION and attention_mask is None: - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - A = flash_attn_func(Q, K, V, causal = True) - else: - # when qlen==vlen and attn_mask is None, we should use causal attention - Q_len = Q.shape[-2] - K_len = K.shape[-2] - if attention_mask is None and Q_len == K_len: - is_causal = True - else: - is_causal = False - # Grouped query attention - if SDPA_HAS_GQA: - # Needs (batch_size, n_heads, seq_len, head_dim) - # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention( - Q, - K, - V, - attn_mask = attention_mask, - is_causal = is_causal, - enable_gqa = n_groups != 1, - ) - # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2) # .contiguous() - else: - if n_groups != 1: - K = K[:, :, None, :, :].expand( - bsz, n_kv_heads, n_groups, kv_seq_len, head_dim - ) - V = V[:, :, None, :, :].expand( - bsz, n_kv_heads, n_groups, kv_seq_len, head_dim - ) - K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) - V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) - pass - # Must be contiguous or else results are False! - # https://github.com/pytorch/pytorch/issues/112577 - Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() - # Needs (batch_size, n_heads, seq_len, head_dim) - # is_casual and attention_mask must not be both set! 
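# Packed rows concatenate several examples, so after shifting the labels the
# position at the end of each inner example would otherwise supervise the first
# token of the next one. Later in this file the loss path calls
# mask_packed_sequence_boundaries to ignore those positions; a hedged sketch
# mirroring the behaviour asserted by tests/utils/test_packing.py:
import torch

def mask_boundaries(shift_labels, packed_seq_lengths, ignore_index = -100):
    # shift_labels: (rows, cols) labels already shifted left by one position;
    # packed_seq_lengths: per-example lengths across the whole packed batch.
    flat = shift_labels.view(-1)
    ends = torch.cumsum(packed_seq_lengths.to(torch.long), dim = 0) - 1
    flat[ends] = ignore_index
    return shift_labels

labels = torch.arange(6, dtype = torch.long).view(1, 6)
mask_boundaries(labels, torch.tensor([2, 1, 3]))
# positions 1, 2 and 5 are now -100, exactly as the unit test expects.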
- A = scaled_dot_product_attention( - Q, K, V, attn_mask = attention_mask, is_causal = is_causal - ) - # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2).contiguous() - pass + use_varlen = seq_info is not None and past_key_value is None + backend = select_attention_backend(use_varlen) + config = AttentionConfig( + backend = backend, + n_kv_heads = n_kv_heads, + n_groups = n_groups, + flash_dense_kwargs = {"causal": True}, + flash_varlen_kwargs = {"dropout_p": 0.0, "causal": True}, + ) + context = AttentionContext( + bsz = bsz, + q_len = q_len, + kv_seq_len = kv_seq_len, + n_heads = n_heads, + head_dim = head_dim, + requires_grad = hidden_states.requires_grad, + seq_info = seq_info, + attention_mask = attention_mask, + causal_mask = causal_mask, + ) + + A = run_attention(config = config, context = context, Q = Q, K = K, V = V) attn_output = A.reshape(bsz, q_len, n_heads * head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None @@ -712,6 +670,7 @@ def LlamaDecoderLayer_fast_forward( use_cache = use_cache, padding_mask = padding_mask, position_embeddings = position_embeddings, + **kwargs, ) hidden_states += residual @@ -735,6 +694,7 @@ def LlamaDecoderLayer_fast_forward( use_cache = use_cache, padding_mask = padding_mask, position_embeddings = position_embeddings, + **kwargs, ) hidden_states = residual + hidden_states @@ -812,8 +772,11 @@ def LlamaModel_fast_forward( seq_length_with_past = seq_length - # Fix out of bounds tokenization - if hasattr(self, "max_seq_length"): + # Fix out of bounds tokenization unless we were given packed metadata + allow_overlength = getattr(self, "_unsloth_allow_packed_overlength", False) or ( + "packed_seq_lengths" in kwargs + ) + if hasattr(self, "max_seq_length") and not allow_overlength: if seq_length > self.max_seq_length: shape = input_ids.shape if input_ids is not None else inputs_embeds.shape logger.warning_once( @@ -1065,10 +1028,12 @@ def LlamaModel_fast_forward( mask = causal_mask if IS_GEMMA2: - if idx % 2 == 0: + use_sliding_window = idx % 2 == 0 + if use_sliding_window: mask = self.SWA_mask if use_static_mask else dynamic_SWA_mask else: mask = self.GA_mask if use_static_mask else dynamic_GA_mask + kwargs["use_sliding_window"] = use_sliding_window if gradient_checkpointing and not isinstance( decoder_layer, GradientCheckpointingLayer @@ -1082,6 +1047,7 @@ def custom_forward(*inputs): output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings, + **kwargs, ) return custom_forward @@ -1108,6 +1074,7 @@ def custom_forward(*inputs): use_cache = use_cache, padding_mask = padding_mask, position_embeddings = position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] @@ -1299,6 +1266,7 @@ def _CausalLM_fast_forward( past_key_values, position_ids = position_ids, attention_mask = attention_mask, + **kwargs, ) else: causal_mask = ( @@ -1331,6 +1299,7 @@ def _CausalLM_fast_forward( output_attentions = output_attentions, output_hidden_states = output_hidden_states, return_dict = return_dict, + **kwargs, ) hidden_states = outputs[0] @@ -1439,6 +1408,10 @@ def _CausalLM_fast_forward( shift_labels = torch.empty_like(labels) shift_labels[..., :-1] = labels[..., 1:] shift_labels[..., -1] = -100 + mask_packed_sequence_boundaries( + shift_labels, + kwargs.get("packed_seq_lengths"), + ) # shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) n_items = kwargs.get("num_items_in_batch", None) if n_items is None: @@ -1996,8 +1969,8 @@ def 
unsloth_fast_generate( > self.config.max_position_embeddings ): raise ValueError( - f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\n' - 'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.' + f"Unsloth: input length {kwargs['input_ids'].shape[-1]} + max_new_tokens {kwargs['max_new_tokens']} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\n" + "You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`." ) # Must patch accelerate for Xformers @@ -2235,8 +2208,6 @@ def from_pretrained( model_config.model_name = model_name model_max_seq_length = model_config.max_position_embeddings - verify_fp8_support_if_applicable(model_config) - # Check if RoPE Scaling is even allowed model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__] IS_FALCON_H1 = model_config.model_type.startswith("falcon_h1") diff --git a/unsloth/models/loader_utils.py b/unsloth/models/loader_utils.py index 631f614cb..d587d0016 100644 --- a/unsloth/models/loader_utils.py +++ b/unsloth/models/loader_utils.py @@ -13,6 +13,9 @@ # limitations under the License. from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit +from ..device_type import DEVICE_TYPE_TORCH +import os +import torch # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! from packaging.version import Version @@ -21,6 +24,9 @@ transformers_version = Version(transformers_version) SUPPORTS_FOURBIT = transformers_version >= Version("4.37") +LOCAL_RANK_KEYS = ("LOCAL_RANK", "RANK") +WORLD_SIZE_KEYS = ("WORLD_SIZE",) + BAD_MAPPINGS = { "unsloth/Qwen3-32B-unsloth-bnb-4bit".lower(): "unsloth/Qwen3-32B-bnb-4bit".lower(), # 32B dynamic quant is way too big "unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit".lower(): "unsloth/Qwen3-30B-A3B".lower(), # HF loads MoEs too slowly @@ -30,6 +36,50 @@ } +def _get_env_int(keys): + for key in keys: + value = os.environ.get(key) + if value is None: + continue + try: + return int(value) + except ValueError: + continue + return None + + +def _infer_distributed_ranks(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + try: + return torch.distributed.get_rank(), torch.distributed.get_world_size() + except Exception: + pass + return _get_env_int(LOCAL_RANK_KEYS), _get_env_int(WORLD_SIZE_KEYS) + + +def is_distributed(): + rank, world_size = _infer_distributed_ranks() + return (world_size or 1) > 1 or (rank is not None and rank > 0) + + +def prepare_device_map(): + rank, world_size = _infer_distributed_ranks() + distributed = (world_size or 1) > 1 or (rank is not None and rank > 0) + if not distributed: + return None, False + + local_rank = 0 if rank is None else rank + device_map = {"": f"{DEVICE_TYPE_TORCH}:{local_rank}"} + try: + if DEVICE_TYPE_TORCH == "cuda": + torch.cuda.set_device(local_rank) + elif DEVICE_TYPE_TORCH == "xpu" and hasattr(torch, "xpu"): + torch.xpu.set_device(local_rank) + except Exception: + pass + return device_map, True + + def __get_model_name( model_name, load_in_4bit = True, diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 1945e6a74..0f7ea3c0b 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -17,6 +17,16 @@ from ._utils import __version__ from unsloth_zoo.utils import _get_dtype from 
unsloth_zoo.hf_utils import dtype_from_config +from ..utils.packing import ( + get_packed_info_from_kwargs, + mask_packed_sequence_boundaries, +) +from ..utils.attention_dispatch import ( + AttentionConfig, + AttentionContext, + run_attention, + select_attention_backend, +) from .llama import ( LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, @@ -76,6 +86,7 @@ def MistralAttention_fast_forward( Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, Q.device) kv_seq_len = K.shape[-2] if past_key_value is not None: @@ -83,12 +94,13 @@ def MistralAttention_fast_forward( # Extend RoPE dynamically to fit in VRAM self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) - cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index) - if position_ids is None: - Q, K = fast_rope_embedding(Q, K, cos, sin) - else: - Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) + + rope_position_ids = ( + position_ids if position_ids is not None else kwargs.get("position_ids") + ) + # Useful for LongRoPE + Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -96,68 +108,38 @@ def MistralAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None: - # Xformers memory efficient attention - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - K_M = V_M = bsz * kv_seq_len - Q_M = bsz * q_len - - has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask) - - # Group query attention - K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - if hidden_states.requires_grad: - K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) - V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) - - if has_swa: - Q = Q.view(1, Q_M, n_heads, head_dim) - K = K.view(1, K_M, n_heads, head_dim) - V = V.view(1, V_M, n_heads, head_dim) - else: - # Xformers does support the forward pass though - Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) - - if has_swa: - Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim) - K = K.view(1, K_M, n_kv_heads, n_groups, head_dim) - V = V.view(1, V_M, n_kv_heads, n_groups, head_dim) - - A = xformers_attention(Q, K, V, attn_bias = causal_mask) - A = A.view(bsz, q_len, n_heads, head_dim) - - elif HAS_FLASH_ATTENTION and attention_mask is None: - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - sw = getattr(self.config, "sliding_window", None) - sw = kv_seq_len if (sw is None or sw == "null") else sw - window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw) - A = flash_attn_func(Q, K, V, causal = True, window_size = window) - else: - # Grouped query attention - # if n_groups != 1: - K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) - V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) - # pass - # Must be contiguous or else results are False! 
- # https://github.com/pytorch/pytorch/issues/112577 - Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() - # Needs (batch_size, n_heads, seq_len, head_dim) - # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention( - Q, K, V, attn_mask = attention_mask, is_causal = False - ) - # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2).contiguous() + sw_cfg = getattr(self.config, "sliding_window", None) + sw = kv_seq_len if (sw_cfg is None or sw_cfg == "null") else sw_cfg + window_size = (-1, -1) if (kv_seq_len <= sw) else (sw, sw) + use_varlen = ( + seq_info is not None and past_key_value is None and window_size == (-1, -1) + ) + backend = select_attention_backend(use_varlen) + attention_config = AttentionConfig( + backend = backend, + n_kv_heads = n_kv_heads, + n_groups = n_groups, + flash_dense_kwargs = {"causal": True, "window_size": window_size}, + flash_varlen_kwargs = { + "dropout_p": 0.0, + "causal": True, + "softmax_scale": getattr(self, "softmax_scale", None), + }, + ) + context = AttentionContext( + bsz = bsz, + q_len = q_len, + kv_seq_len = kv_seq_len, + n_heads = n_heads, + head_dim = head_dim, + requires_grad = hidden_states.requires_grad, + seq_info = seq_info, + attention_mask = attention_mask, + causal_mask = causal_mask, + ) + + A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V) attn_output = A.reshape(bsz, q_len, n_heads * head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None @@ -376,6 +358,10 @@ def MistralForCausalLM_fast_forward( shift_labels = torch.empty_like(labels) shift_labels[..., :-1] = labels[..., 1:] shift_labels[..., -1] = -100 + mask_packed_sequence_boundaries( + shift_labels, + kwargs.get("packed_seq_lengths"), + ) loss = fast_cross_entropy_loss( logits = shift_logits, labels = shift_labels, diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py index 3d905fe06..d9f54f267 100644 --- a/unsloth/models/qwen3.py +++ b/unsloth/models/qwen3.py @@ -16,6 +16,13 @@ import os from ._utils import __version__ from unsloth_zoo.utils import Version, _get_dtype +from ..utils.packing import get_packed_info_from_kwargs +from ..utils.attention_dispatch import ( + AttentionConfig, + AttentionContext, + run_attention, + select_attention_backend, +) from .llama import ( LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, @@ -95,6 +102,7 @@ def Qwen3Attention_fast_forward( bsz, q_len, n_kv_heads, head_dim ) # .transpose(1, 2) # we will transpose after normalisation V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + seq_info = get_packed_info_from_kwargs(kwargs, bsz * q_len, hidden_states.device) # Qwen3 has QKNorm. This seems to be the only difference from Qwen2. # Note that using fast_layernorm_compiled causes issues as the dimensions don't match up. 
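The `seq_info` tuple extracted in the hunk above feeds the varlen FlashAttention path introduced later in this diff. As a point of reference, here is a minimal, standalone sketch (toy lengths, plain torch, no Unsloth imports) of the convention `get_packed_info_from_kwargs` in `unsloth/utils/packing.py` follows: `packed_seq_lengths` holds one entry per packed sample, their sum must equal `bsz * q_len`, and the helper derives the `(cu_seqlens, max_seqlen)` pair that `flash_attn_varlen_func` consumes.

import torch

# Toy example: two samples of 5 and 3 tokens packed into a single row of 8 tokens.
packed_seq_lengths = torch.tensor([5, 3], dtype = torch.int32)
total_tokens = int(packed_seq_lengths.sum())   # must equal bsz * q_len (here 1 * 8)

# Cumulative sequence boundaries plus the longest sample: the shape of seq_info
# that the varlen FlashAttention kernel expects.
cu_seqlens = torch.cat([
    torch.zeros(1, dtype = torch.int32),
    torch.cumsum(packed_seq_lengths, dim = 0, dtype = torch.int32),
])
max_seqlen = int(packed_seq_lengths.max())

assert cu_seqlens.tolist() == [0, 5, 8]
assert (total_tokens, max_seqlen) == (8, 5)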
@@ -110,20 +118,19 @@ def Qwen3Attention_fast_forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] + # Extend RoPE dynamically to fit in VRAM if position_embeddings: cos, sin = position_embeddings else: - # Extend RoPE dynamically to fit in VRA rotary_emb = self.rotary_emb rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) - device_index = Q.device.index + cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index) - if position_ids is None: - # Useful for LongRoPE - cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) - else: - cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) - Q, K = fast_rope_embedding(Q, K, cos, sin) + rope_position_ids = ( + position_ids if position_ids is not None else kwargs.get("position_ids") + ) + # Useful for LongRoPE + Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -131,74 +138,32 @@ def Qwen3Attention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None: - # Xformers memory efficient attention - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - K_M = V_M = bsz * kv_seq_len - Q_M = bsz * q_len - - has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask) - - # Group query attention - K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) - K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) - if hidden_states.requires_grad: - K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) - V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) - - if has_swa: - Q = Q.view(1, Q_M, n_heads, head_dim) - K = K.view(1, K_M, n_heads, head_dim) - V = V.view(1, V_M, n_heads, head_dim) - else: - # Xformers does support the forward pass though - Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) - - if has_swa: - Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim) - K = K.view(1, K_M, n_kv_heads, n_groups, head_dim) - V = V.view(1, V_M, n_kv_heads, n_groups, head_dim) - - A = xformers_attention(Q, K, V, attn_bias = causal_mask) - A = A.view(bsz, q_len, n_heads, head_dim) - - elif HAS_FLASH_ATTENTION and attention_mask is None: - Q = Q.transpose(1, 2) - K = K.transpose(1, 2) - V = V.transpose(1, 2) - sw = kv_seq_len - window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw) - A = flash_attn_func(Q, K, V, causal = True, window_size = window) - else: - # Grouped query attention - # if n_groups != 1: - K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) - V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) - # pass - # Must be contiguous or else results are False! - # https://github.com/pytorch/pytorch/issues/112577 - Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() - # Needs (batch_size, n_heads, seq_len, head_dim) - # is_casual and attention_mask must not be both set! 
- # when qlen==vlen and attn_mask is None, we should use causal attention - Q_len = Q.shape[-2] - K_len = K.shape[-2] - if attention_mask is None and Q_len == K_len: - is_causal = True - else: - is_causal = False + use_varlen = seq_info is not None and past_key_value is None + backend = select_attention_backend(use_varlen) + attention_config = AttentionConfig( + backend = backend, + n_kv_heads = n_kv_heads, + n_groups = n_groups, + flash_dense_kwargs = {"causal": True}, + flash_varlen_kwargs = { + "dropout_p": 0.0, + "causal": True, + "softmax_scale": getattr(self, "softmax_scale", None), + }, + ) + context = AttentionContext( + bsz = bsz, + q_len = q_len, + kv_seq_len = kv_seq_len, + n_heads = n_heads, + head_dim = head_dim, + requires_grad = hidden_states.requires_grad, + seq_info = seq_info, + attention_mask = attention_mask, + causal_mask = causal_mask, + ) - A = scaled_dot_product_attention( - Q, K, V, attn_mask = attention_mask, is_causal = is_causal - ) - # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2).contiguous() + A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V) attn_output = A.reshape(bsz, q_len, n_heads * head_dim) attn_output = self.apply_o(self, attn_output) diff --git a/unsloth/utils/__init__.py b/unsloth/utils/__init__.py index e69de29bb..7870a0879 100644 --- a/unsloth/utils/__init__.py +++ b/unsloth/utils/__init__.py @@ -0,0 +1,24 @@ +from .packing import configure_sample_packing, enable_sample_packing +from .attention_dispatch import ( + AttentionConfig, + AttentionContext, + FLASH_DENSE, + FLASH_VARLEN, + SDPA, + XFORMERS, + run_attention, + select_attention_backend, +) + +__all__ = [ + "configure_sample_packing", + "enable_sample_packing", + "AttentionConfig", + "AttentionContext", + "FLASH_VARLEN", + "FLASH_DENSE", + "XFORMERS", + "SDPA", + "run_attention", + "select_attention_backend", +] diff --git a/unsloth/utils/attention_dispatch.py b/unsloth/utils/attention_dispatch.py new file mode 100644 index 000000000..45aef551e --- /dev/null +++ b/unsloth/utils/attention_dispatch.py @@ -0,0 +1,271 @@ +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . 
+ +"""Shared helpers for attention backend selection and execution.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple + +from torch import Tensor +from torch.nn.functional import scaled_dot_product_attention + +from ..models._utils import * +from ..utils.packing import ( + build_sdpa_packed_attention_mask, + build_xformers_block_causal_mask, +) + +if HAS_FLASH_ATTENTION: + from flash_attn import flash_attn_func, flash_attn_varlen_func +HAS_XFORMERS = xformers is not None +BlockDiagonalCausalMask = None +if HAS_XFORMERS: + BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask +SDPA_HAS_GQA = "enable_gqa" in (scaled_dot_product_attention.__doc__ or "") + +FLASH_VARLEN = "flash_varlen" +FLASH_DENSE = "flash_dense" +XFORMERS = "xformers" +SDPA = "sdpa" + + +XFORMERS_BLOCK_DIAG_CLS = ( + xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None +) + + +@dataclass +class AttentionConfig: + """ + Per-layer attention metadata. + + NOTE(djsaunde): I had originally intended this to be populated once per layer, but + we're currently constructing it on every forward pass since it can possibly be + invalid from one forward pass to the next (e.g., switching from training to + inference). For now, I'm keeping separate from AttentionContext for the sake of + better grouping of params. + """ + + backend: str + n_kv_heads: int + n_groups: int + flash_dense_kwargs: Optional[dict[str, Any]] = None + flash_varlen_kwargs: Optional[dict[str, Any]] = None + sdpa_kwargs: Optional[dict[str, Any]] = None + xformers_kwargs: Optional[dict[str, Any]] = None + + +@dataclass +class AttentionContext: + """Per-call info required to run attention.""" + + bsz: int + q_len: int + kv_seq_len: int + n_heads: int + head_dim: int + requires_grad: bool + seq_info: Optional[Tuple[Tensor, Tensor, int]] + attention_mask: Optional[Tensor] + causal_mask: Optional[Any] + sliding_window: Optional[int] = None + + +def select_attention_backend(use_varlen: bool = False) -> str: + """Return attention backend based on availability / priority order.""" + + if HAS_FLASH_ATTENTION: + if use_varlen: + return FLASH_VARLEN + else: + return FLASH_DENSE + if HAS_XFORMERS: + return XFORMERS + return SDPA + + +def run_attention( + *, + config: AttentionConfig, + context: AttentionContext, + Q: Tensor, + K: Tensor, + V: Tensor, +) -> Tensor: + """Run attention using config / context info.""" + + backend = config.backend + if backend == FLASH_VARLEN and context.seq_info is None: + backend = FLASH_DENSE if HAS_FLASH_ATTENTION else SDPA + flash_dense_kwargs = config.flash_dense_kwargs or {} + flash_varlen_kwargs = config.flash_varlen_kwargs or {} + sdpa_kwargs = config.sdpa_kwargs or {} + xformers_kwargs = config.xformers_kwargs or {} + + bsz = context.bsz + n_heads = context.n_heads + q_len = context.q_len + head_dim = context.head_dim + kv_seq_len = context.kv_seq_len + requires_grad = context.requires_grad + sliding_window = context.sliding_window + + if backend == FLASH_VARLEN: + Q_f = Q.transpose(1, 2).reshape(bsz * q_len, n_heads, head_dim) + K_f = K.transpose(1, 2).reshape(bsz * q_len, config.n_kv_heads, head_dim) + V_f = V.transpose(1, 2).reshape(bsz * q_len, config.n_kv_heads, head_dim) + _, cu_seqlens, max_seqlen = context.seq_info + return flash_attn_varlen_func( + Q_f, + K_f, + V_f, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + **flash_varlen_kwargs, + ).view(bsz, q_len, n_heads, head_dim) + elif backend == FLASH_DENSE: + Q_t = 
Q.transpose(1, 2) + K_t = K.transpose(1, 2) + V_t = V.transpose(1, 2) + return flash_attn_func(Q_t, K_t, V_t, **flash_dense_kwargs).reshape( + bsz, q_len, n_heads, head_dim + ) + elif backend == XFORMERS: + attn_bias = build_xformers_block_causal_mask( + context.seq_info, + sliding_window = sliding_window, + base_mask = context.causal_mask, + ) + + Q_t = Q.transpose(1, 2) + K_t = K.transpose(1, 2) + V_t = V.transpose(1, 2) + + K_mod = K_t + V_mod = V_t + Q_mod = Q_t + + if config.n_groups != 1: + K_mod = K_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim) + V_mod = V_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim) + K_mod = K_mod.expand( + bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim + ) + V_mod = V_mod.expand( + bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim + ) + + if requires_grad: + K_mod = K_mod.reshape(bsz, kv_seq_len, n_heads, head_dim) + V_mod = V_mod.reshape(bsz, kv_seq_len, n_heads, head_dim) + else: + Q_mod = Q_t.view( + bsz, q_len, config.n_kv_heads, config.n_groups, head_dim + ) + + has_block = XFORMERS_BLOCK_DIAG_CLS is not None and isinstance( + attn_bias, XFORMERS_BLOCK_DIAG_CLS + ) + + if config.n_groups != 1 and not requires_grad and has_block: + Q_mod = Q_mod.view( + 1, bsz * q_len, config.n_kv_heads, config.n_groups, head_dim + ) + K_mod = K_mod.view( + 1, bsz * kv_seq_len, config.n_kv_heads, config.n_groups, head_dim + ) + V_mod = V_mod.view( + 1, bsz * kv_seq_len, config.n_kv_heads, config.n_groups, head_dim + ) + elif config.n_groups != 1 and requires_grad and has_block: + Q_mod = Q_mod.view(1, bsz * q_len, n_heads, head_dim) + K_mod = K_mod.view(1, bsz * kv_seq_len, n_heads, head_dim) + V_mod = V_mod.view(1, bsz * kv_seq_len, n_heads, head_dim) + + out = xformers_attention( + Q_mod, + K_mod, + V_mod, + attn_bias = attn_bias, + **xformers_kwargs, + ) + + if config.n_groups != 1 and not requires_grad: + if has_block: + out = out.view(bsz, q_len, config.n_kv_heads, config.n_groups, head_dim) + else: + out = out.view(bsz, q_len, config.n_kv_heads, config.n_groups, head_dim) + out = out.reshape(bsz, q_len, n_heads, head_dim) + else: + if has_block: + out = out.view(bsz, q_len, n_heads, head_dim) + else: + out = out.view(bsz, q_len, n_heads, head_dim) + return out + else: + local_mask = context.attention_mask + is_causal_local = False + if context.seq_info is not None and local_mask is None: + local_mask = build_sdpa_packed_attention_mask( + context.seq_info, + dtype = Q.dtype, + device = Q.device, + sliding_window = sliding_window, + ) + else: + q_len_local = Q.shape[-2] + k_len_local = K.shape[-2] + is_causal_local = local_mask is None and q_len_local == k_len_local + + kwargs = dict(sdpa_kwargs) + kwargs.setdefault("attn_mask", local_mask) + kwargs.setdefault("is_causal", is_causal_local) + + if SDPA_HAS_GQA: + kwargs.setdefault("enable_gqa", config.n_groups != 1) + out = scaled_dot_product_attention(Q, K, V, **kwargs) + return out.transpose(1, 2) + + K_mod = K + V_mod = V + if config.n_groups != 1: + K_mod = K[:, :, None, :, :].expand( + bsz, config.n_kv_heads, config.n_groups, kv_seq_len, head_dim + ) + V_mod = V[:, :, None, :, :].expand( + bsz, config.n_kv_heads, config.n_groups, kv_seq_len, head_dim + ) + K_mod = K_mod.reshape(bsz, n_heads, kv_seq_len, head_dim) + V_mod = V_mod.reshape(bsz, n_heads, kv_seq_len, head_dim) + + out = scaled_dot_product_attention( + Q.contiguous(), + K_mod.contiguous(), + V_mod.contiguous(), + **kwargs, + ) + return out.transpose(1, 2).contiguous() + + +__all__ = [ + "AttentionConfig", + 
"AttentionContext", + "select_attention_backend", + "run_attention", +] diff --git a/unsloth/utils/packing.py b/unsloth/utils/packing.py new file mode 100644 index 000000000..26f8542f4 --- /dev/null +++ b/unsloth/utils/packing.py @@ -0,0 +1,253 @@ +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +"""Utilities for enabling packed (padding-free) batches across Unsloth.""" + +from __future__ import annotations + +import logging +from typing import Any, Iterable, Optional, Sequence, Tuple + +import torch + +try: + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask as _XFormersBlockMask, + ) +except Exception: + try: + from xformers.attn_bias import BlockDiagonalCausalMask as _XFormersBlockMask + except Exception: + _XFormersBlockMask = None + + +class _TrlPackingWarningFilter(logging.Filter): + _NEEDLES = ( + "Padding-free training is enabled, but the attention implementation is not set to 'flash_attention_2'", + "You are using packing, but the attention implementation is not set to 'flash_attention_2' or 'kernels-community/vllm-flash-attn3'", + ) + + def filter(self, record: logging.LogRecord) -> bool: # pragma: no cover - trivial + message = record.getMessage() + return not any(needle in message for needle in self._NEEDLES) + + +_TRL_FILTER_INSTALLED = False + + +def _ensure_trl_warning_filter(): + global _TRL_FILTER_INSTALLED + if _TRL_FILTER_INSTALLED: + return + logging.getLogger("trl.trainer.sft_trainer").addFilter(_TrlPackingWarningFilter()) + _TRL_FILTER_INSTALLED = True + + +def configure_sample_packing(config): + """Mutate an ``SFTConfig`` so TRL prepares packed batches.""" + _ensure_trl_warning_filter() + setattr(config, "packing", True) + setattr(config, "padding_free", True) + setattr(config, "remove_unused_columns", False) + + +def enable_sample_packing(model, trainer): + """Enable runtime support for packed batches on an existing trainer.""" + + def _mark_allow_overlength(module): + if hasattr(module, "max_seq_length"): + setattr(module, "_unsloth_allow_packed_overlength", True) + for child in module.children(): + _mark_allow_overlength(child) + + _mark_allow_overlength(model) + + collator = getattr(trainer, "data_collator", None) + if ( + collator is None + or not hasattr(collator, "torch_call") + or getattr(collator, "_unsloth_packing_wrapped", False) + ): + return + + if hasattr(collator, "return_position_ids"): + collator.return_position_ids = True + + original_torch_call = collator.torch_call + + def torch_call_with_lengths(examples: Sequence[dict]): + batch = original_torch_call(examples) + if examples and isinstance(examples[0], dict): + seq_lengths: list[int] = [] + for example in examples: + seq_lengths.extend(example["seq_lengths"]) + if seq_lengths: + batch["packed_seq_lengths"] = torch.tensor( + seq_lengths, dtype = torch.int32 + ) + return batch + + collator.torch_call = 
torch_call_with_lengths + collator._unsloth_packing_wrapped = True + + +def get_packed_info_from_kwargs( + kwargs: dict, + total_tokens: int, + device: torch.device, +) -> Optional[Tuple[torch.Tensor, torch.Tensor, int]]: + """Extract packed sequence information from attention kwargs.""" + + seq_lengths = kwargs.get("packed_seq_lengths") + if seq_lengths is None: + return None + + if isinstance(seq_lengths, torch.Tensor): + lengths = seq_lengths.to(device = device, dtype = torch.int32) + else: + lengths = torch.tensor(seq_lengths, device = device, dtype = torch.int32) + + if lengths.ndim > 1: + lengths = lengths.reshape(-1) + + if lengths.numel() == 0: + return None + + if int(lengths.sum().item()) != total_tokens: + return None + + cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype = torch.int32, device = device), + torch.cumsum(lengths, dim = 0, dtype = torch.int32), + ] + ) + max_seqlen = int(lengths.max().item()) + return lengths, cu_seqlens, max_seqlen + + +def build_xformers_block_causal_mask( + seq_info: Optional[Tuple[torch.Tensor, torch.Tensor, int]], + *, + sliding_window: Optional[int] = None, + base_mask: Optional[Any] = None, +): + if _XFormersBlockMask is None: + return None + if seq_info is not None: + seq_lengths, _, _ = seq_info + lengths = seq_lengths.to("cpu", torch.int32).tolist() + if not lengths: + return None + mask = _XFormersBlockMask.from_seqlens(lengths) + else: + mask = base_mask + + if ( + sliding_window is not None + and sliding_window > 0 + and mask is not None + and hasattr(mask, "make_local_attention") + ): + mask = mask.make_local_attention(window_size = sliding_window) + return mask + + +def build_sdpa_packed_attention_mask( + seq_info: Tuple[torch.Tensor, torch.Tensor, int], + *, + dtype: torch.dtype, + device: torch.device, + sliding_window: Optional[int] = None, +) -> torch.Tensor: + seq_lengths, _, _ = seq_info + total_tokens = int(seq_lengths.sum().item()) + mask = torch.full( + (total_tokens, total_tokens), + float("-inf"), + dtype = dtype, + device = device, + ) + offset = 0 + for length in seq_lengths.tolist(): + length = int(length) + if length <= 0: + continue + block = torch.zeros((length, length), dtype = dtype, device = device) + upper = torch.triu( + torch.ones((length, length), device = device), diagonal = 1 + ).bool() + block = block.masked_fill(upper, float("-inf")) + if ( + sliding_window is not None + and sliding_window > 0 + and length > sliding_window + ): + idx = torch.arange(length, device = device) + dist = idx.unsqueeze(1) - idx.unsqueeze(0) + window_mask = dist >= sliding_window + block = block.masked_fill(window_mask, float("-inf")) + mask[offset : offset + length, offset : offset + length] = block + offset += length + return mask.unsqueeze(0).unsqueeze(0) + + +def _normalize_packed_lengths( + seq_lengths: Any, + *, + device: torch.device, +) -> Optional[torch.Tensor]: + if seq_lengths is None: + return None + if isinstance(seq_lengths, torch.Tensor): + lengths = seq_lengths.to(device = device, dtype = torch.int64) + else: + lengths = torch.tensor(seq_lengths, device = device, dtype = torch.int64) + if lengths.ndim != 1: + lengths = lengths.reshape(-1) + if lengths.numel() == 0: + return None + return lengths + + +def mask_packed_sequence_boundaries( + shift_labels: torch.Tensor, + seq_lengths: Any, + *, + ignore_index: int = -100, +) -> bool: + """Mark final token of every packed sample so CE ignores boundary predictions.""" + + lengths = _normalize_packed_lengths(seq_lengths, device = shift_labels.device) + if lengths is None: 
+ return False + + flat = shift_labels.reshape(-1) + total_tokens = flat.shape[0] + boundary_positions = torch.cumsum(lengths, dim = 0) - 1 + valid = boundary_positions < total_tokens + if not torch.all(valid): + boundary_positions = boundary_positions[valid] + if boundary_positions.numel() == 0: + return False + flat[boundary_positions] = ignore_index + return True + + +__all__ = [ + "configure_sample_packing", + "enable_sample_packing", + "mask_packed_sequence_boundaries", +]
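To make the intended behaviour of the packing helpers above concrete, here is a short usage sketch with toy CPU data (it assumes the `unsloth.utils.packing` module added in this diff is importable, mirroring the import style of the new unit tests): the SDPA mask is block-causal, so packed samples cannot attend across their boundaries, and `mask_packed_sequence_boundaries` sets the last shifted label of every packed sample to -100 so cross-entropy never scores a prediction that crosses a sample boundary.

import torch

from unsloth.utils import packing as packing_utils

# Two samples of 4 and 2 tokens packed into one row of 6 tokens.
kwargs = {"packed_seq_lengths": torch.tensor([4, 2], dtype = torch.int32)}
seq_info = packing_utils.get_packed_info_from_kwargs(kwargs, 6, torch.device("cpu"))
lengths, cu_seqlens, max_seqlen = seq_info
assert cu_seqlens.tolist() == [0, 4, 6] and max_seqlen == 4

# Block-causal SDPA mask: token 4 (first token of sample 2) cannot see sample 1.
mask = packing_utils.build_sdpa_packed_attention_mask(
    seq_info, dtype = torch.float32, device = torch.device("cpu")
)
assert mask.shape == (1, 1, 6, 6)
assert mask[0, 0, 4, 0].item() == float("-inf")   # cross-sample attention blocked
assert mask[0, 0, 3, 0].item() == 0.0             # causal attention inside sample 1

# Boundary masking for the loss: the final label of each packed sample becomes -100.
shift_labels = torch.arange(6).unsqueeze(0)        # toy shifted labels, shape (1, 6)
changed = packing_utils.mask_packed_sequence_boundaries(
    shift_labels, kwargs["packed_seq_lengths"]
)
assert changed
assert shift_labels[0, 3].item() == -100 and shift_labels[0, 5].item() == -100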