Commit fdebcef

implement (sdpa, xformers, fa2) sample packing
1 parent 5314c21 commit fdebcef

File tree

12 files changed: +1087 −152 lines changed

tests/utils/test_packing.py

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
from unsloth import FastLanguageModel
from unsloth.utils.packing import configure_sample_packing, enable_sample_packing

from contextlib import ExitStack
from types import SimpleNamespace
from unittest.mock import patch

import pytest
import torch
from datasets import Dataset
from trl import SFTConfig, SFTTrainer


def _build_packed_training_setup(tmp_path, device):
    dtype = None
    if device.type == "cuda":
        if torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16
        else:
            dtype = torch.float16

    try:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name="hf-internal-testing/tiny-random-LlamaForCausalLM",
            max_seq_length=64,
            load_in_4bit=False,
            dtype=dtype,
        )
    except OSError as exc:  # pragma: no cover - offline CI
        pytest.skip(f"Requires access to tiny llama checkpoint: {exc}")

    model.to(device)

    dataset = Dataset.from_dict(
        {
            "text": [
                "Hello world!",
                "Short sample.",
                "This is a slightly longer packed example to test batching.",
                "Another response to include in the batch.",
            ]
        }
    )

    training_args = SFTConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        dataset_text_field="text",
        max_length=64,
        logging_steps=1,
        max_steps=1,
        fp16=device.type == "cuda" and not torch.cuda.is_bf16_supported(),
        bf16=device.type == "cuda" and torch.cuda.is_bf16_supported(),
        dataset_num_proc=1,
        output_dir=str(tmp_path),
    )
    configure_sample_packing(training_args)

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        args=training_args,
    )

    enable_sample_packing(model, trainer)

    dataloader = trainer.get_train_dataloader()
    batch = next(iter(dataloader))

    model_device = next(model.parameters()).device

    for key, value in list(batch.items()):
        if torch.is_tensor(value):
            batch[key] = value.to(model_device)

    from unsloth.models import llama as llama_mod

    return model, batch, trainer, llama_mod


def _trim_batch_to_total_tokens(data, total_tokens):
    def _trim_tensor(t: torch.Tensor):
        if t.ndim >= 2 and t.size(1) > total_tokens:
            return t[:, :total_tokens].contiguous()
        return t

    trimmed = {}
    for key, value in data.items():
        if torch.is_tensor(value):
            trimmed[key] = _trim_tensor(value)
        else:
            trimmed[key] = value
    return trimmed


def test_configure_sample_packing():
    config = SimpleNamespace()
    configure_sample_packing(config)

    assert config.packing is True
    assert config.padding_free is True
    assert config.remove_unused_columns is False


class _DummyChild(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.max_seq_length = 8


class _DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.max_seq_length = 16
        self.child = _DummyChild()
        self.config = SimpleNamespace(_attn_implementation="sdpa")
        self.generation_config = SimpleNamespace(attn_implementation="sdpa")


class _DummyCollator:
    def __init__(self):
        self.padding_free = False
        self.return_position_ids = False

    def torch_call(self, examples):
        return {"attention_mask": "mask", "batch": examples}


class _DummyTrainer:
    def __init__(self):
        self.args = SimpleNamespace(remove_unused_columns=True)
        self.data_collator = _DummyCollator()


def test_enable_sample_packing():
    model = _DummyModel()
    trainer = _DummyTrainer()

    enable_sample_packing(model, trainer)

    # model hierarchy should now allow packed overlength inputs
    assert getattr(model, "_unsloth_allow_packed_overlength") is True
    assert getattr(model.child, "_unsloth_allow_packed_overlength") is True

    # trainer args are updated to keep the packed metadata
    assert trainer.args.remove_unused_columns is False

    collator = trainer.data_collator
    assert collator.padding_free is True
    assert collator.return_position_ids is True
    assert getattr(collator, "_unsloth_packing_wrapped") is True

    examples = [
        {"seq_lengths": [2, 1]},
        {"seq_lengths": [3]},
    ]
    batch = collator.torch_call(examples)

    # packed lengths are aggregated into a single tensor
    assert "packed_seq_lengths" in batch
    assert torch.equal(
        batch["packed_seq_lengths"],
        torch.tensor([2, 1, 3], dtype=torch.int32),
    )

    # attention_mask is dropped when return_position_ids is set
    assert "attention_mask" not in batch
    assert batch["batch"] == examples


def test_packing_sdpa(tmp_path):
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model, batch, trainer, llama_mod = _build_packed_training_setup(tmp_path, device)

    assert "packed_seq_lengths" in batch
    assert "attention_mask" not in batch
    assert batch["packed_seq_lengths"].dtype == torch.int32

    total_tokens = batch["input_ids"].size(-1)
    assert int(batch["packed_seq_lengths"].sum().item()) == total_tokens

    packed_tokens = int(batch["packed_seq_lengths"].sum().item())
    inputs = _trim_batch_to_total_tokens(batch, packed_tokens)

    seq_info = llama_mod.get_packed_info_from_kwargs(
        {"packed_seq_lengths": batch["packed_seq_lengths"]},
        inputs["input_ids"].shape[0] * inputs["input_ids"].shape[1],
        inputs["input_ids"].device,
    )
    assert seq_info is not None

    original_mask = llama_mod.build_sdpa_packed_attention_mask
    mask_calls = []

    def _capture_mask(seq_info, dtype, device):
        mask_calls.append(tuple(seq_info[0].tolist()))
        return original_mask(seq_info, dtype=dtype, device=device)

    with ExitStack() as stack:
        stack.enter_context(patch.object(llama_mod, "HAS_FLASH_ATTENTION", False))
        stack.enter_context(patch.object(llama_mod, "HAS_XFORMERS", False))
        stack.enter_context(
            patch.object(
                llama_mod,
                "build_sdpa_packed_attention_mask",
                side_effect=_capture_mask,
            )
        )
        with torch.no_grad():
            outputs = model(**inputs)

    assert mask_calls, "SDPA packed mask was not constructed"
    assert outputs.loss is not None

    if hasattr(trainer, "accelerator"):
        trainer.accelerator.free_memory()

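For context on what the SDPA path above exercises: test_packing_sdpa only patches build_sdpa_packed_attention_mask and asserts that it was called, so the sketch below illustrates the kind of block-diagonal additive causal mask such a helper is presumed to build from packed_seq_lengths. It is a minimal illustration under that assumption (a float mask usable as the attn_mask of torch.nn.functional.scaled_dot_product_attention), not the implementation added in this commit, and block_diagonal_causal_mask is a hypothetical name.

import torch

def block_diagonal_causal_mask(seq_lengths, dtype=torch.float32, device="cpu"):
    # Hypothetical sketch, not this commit's helper: causal attention within
    # each packed sample, fully masked across samples.
    total = int(sum(seq_lengths))
    neg_inf = torch.finfo(dtype).min
    # Start fully masked, then open a causal block per packed sample.
    mask = torch.full((total, total), neg_inf, dtype=dtype, device=device)
    offset = 0
    for length in seq_lengths:
        block = torch.triu(
            torch.full((length, length), neg_inf, dtype=dtype, device=device),
            diagonal=1,
        )  # 0 on and below the diagonal, masked above it
        mask[offset:offset + length, offset:offset + length] = block
        offset += length
    # Broadcastable additive mask shape expected by scaled_dot_product_attention.
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, total, total)

# With the lengths asserted in test_enable_sample_packing, [2, 1, 3] yields a
# (1, 1, 6, 6) mask containing three causal blocks along the diagonal.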