
Commit c274827

LGPL license headers

1 parent f503df4 · commit c274827

File tree

14 files changed: +226 -306 lines changed

Lines changed: 99 additions & 75 deletions
@@ -1,3 +1,18 @@
+# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
 """Unit tests for packed-attention mask helpers with sliding-window logic."""
 
 import math
@@ -9,11 +24,11 @@
 
 
 def _make_seq_info(lengths):
-    lengths = torch.tensor(lengths, dtype=torch.int32)
+    lengths = torch.tensor(lengths, dtype = torch.int32)
     cu = torch.cat(
         [
-            torch.zeros(1, dtype=torch.int32),
-            torch.cumsum(lengths, dim=0, dtype=torch.int32),
+            torch.zeros(1, dtype = torch.int32),
+            torch.cumsum(lengths, dim = 0, dtype = torch.int32),
         ]
     )
     max_len = int(lengths.max().item())
@@ -24,15 +39,15 @@ def test_sdpa_packed_attention_mask_sliding_window():
     seq_info = _make_seq_info([5, 3])
     mask = packing_utils.build_sdpa_packed_attention_mask(
         seq_info,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        sliding_window=3,
+        dtype = torch.float32,
+        device = torch.device("cpu"),
+        sliding_window = 3,
     )
 
     assert mask.shape == (1, 1, 8, 8)
 
     block_first = mask[0, 0, :5, :5]
-    upper = torch.triu(torch.ones_like(block_first), diagonal=1).bool()
+    upper = torch.triu(torch.ones_like(block_first), diagonal = 1).bool()
     assert torch.all(block_first[upper] == float("-inf"))
     assert block_first[3, 0].item() == float("-inf")
     assert block_first[4, 1].item() == float("-inf")
@@ -42,7 +57,7 @@ def test_sdpa_packed_attention_mask_sliding_window():
 
 def test_xformers_block_mask_sliding_window(monkeypatch):
     class _FakeMask:
-        def __init__(self, lengths, window=None):
+        def __init__(self, lengths, window = None):
             self.lengths = lengths
             self.window = window
 
@@ -51,14 +66,14 @@ def from_seqlens(cls, lengths):
             return cls(tuple(lengths))
 
         def make_local_attention(self, window_size):
-            return _FakeMask(self.lengths, window=window_size)
+            return _FakeMask(self.lengths, window = window_size)
 
-    monkeypatch.setattr(packing_utils, "_XFormersBlockMask", _FakeMask, raising=False)
+    monkeypatch.setattr(packing_utils, "_XFormersBlockMask", _FakeMask, raising = False)
 
     seq_info = _make_seq_info([4, 4])
     mask = packing_utils.build_xformers_block_causal_mask(
         seq_info,
-        sliding_window=2,
+        sliding_window = 2,
     )
 
     assert isinstance(mask, _FakeMask)
@@ -72,13 +87,13 @@ def test_run_attention_sdpa_passes_sliding_window(monkeypatch):
     original_builder = attention_dispatch.build_sdpa_packed_attention_mask
     captured = {}
 
-    def _capture_builder(seq_info_arg, *, dtype, device, sliding_window=None):
+    def _capture_builder(seq_info_arg, *, dtype, device, sliding_window = None):
         captured["window"] = sliding_window
         return original_builder(
             seq_info_arg,
-            dtype=dtype,
-            device=device,
-            sliding_window=sliding_window,
+            dtype = dtype,
+            device = device,
+            sliding_window = sliding_window,
         )
 
     monkeypatch.setattr(
@@ -94,34 +109,34 @@ def _fake_sdpa(Q, K, V, **kwargs):
     monkeypatch.setattr(attention_dispatch, "scaled_dot_product_attention", _fake_sdpa)
 
     config = attention_dispatch.AttentionConfig(
-        backend=attention_dispatch.SDPA,
-        n_kv_heads=1,
-        n_groups=1,
+        backend = attention_dispatch.SDPA,
+        n_kv_heads = 1,
+        n_groups = 1,
     )
 
     context = attention_dispatch.AttentionContext(
-        bsz=1,
-        q_len=5,
-        kv_seq_len=5,
-        n_heads=1,
-        head_dim=1,
-        requires_grad=False,
-        seq_info=seq_info,
-        attention_mask=None,
-        causal_mask=None,
-        sliding_window=sliding_window,
+        bsz = 1,
+        q_len = 5,
+        kv_seq_len = 5,
+        n_heads = 1,
+        head_dim = 1,
+        requires_grad = False,
+        seq_info = seq_info,
+        attention_mask = None,
+        causal_mask = None,
+        sliding_window = sliding_window,
     )
 
     Q = torch.zeros(1, 1, 5, 1)
     K = torch.zeros_like(Q)
     V = torch.zeros_like(Q)
 
     attention_dispatch.run_attention(
-        config=config,
-        context=context,
-        Q=Q,
-        K=K,
-        V=V,
+        config = config,
+        context = context,
+        Q = Q,
+        K = K,
+        V = V,
     )
 
     assert captured["window"] == sliding_window
@@ -139,48 +154,54 @@ class _FakeBias:
 
     captured = {}
 
-    def _fake_builder(seq_info_arg, *, sliding_window=None, base_mask=None):
+    def _fake_builder(seq_info_arg, *, sliding_window = None, base_mask = None):
         captured["window"] = sliding_window
         captured["base"] = base_mask
         return _FakeBias()
 
-    def _fake_attention(Q, K, V, attn_bias=None, **_):
+    def _fake_attention(Q, K, V, attn_bias = None, **_):
         captured["bias"] = attn_bias
         return torch.zeros_like(Q)
 
-    monkeypatch.setattr(attention_dispatch, "build_xformers_block_causal_mask", _fake_builder)
-    monkeypatch.setattr(attention_dispatch, "xformers_attention", _fake_attention, raising=False)
-    monkeypatch.setattr(attention_dispatch, "XFORMERS_BLOCK_DIAG_CLS", _FakeBias, raising=False)
+    monkeypatch.setattr(
+        attention_dispatch, "build_xformers_block_causal_mask", _fake_builder
+    )
+    monkeypatch.setattr(
+        attention_dispatch, "xformers_attention", _fake_attention, raising = False
+    )
+    monkeypatch.setattr(
+        attention_dispatch, "XFORMERS_BLOCK_DIAG_CLS", _FakeBias, raising = False
+    )
 
     config = attention_dispatch.AttentionConfig(
-        backend=attention_dispatch.XFORMERS,
-        n_kv_heads=1,
-        n_groups=1,
+        backend = attention_dispatch.XFORMERS,
+        n_kv_heads = 1,
+        n_groups = 1,
    )
 
     context = attention_dispatch.AttentionContext(
-        bsz=1,
-        q_len=4,
-        kv_seq_len=4,
-        n_heads=1,
-        head_dim=1,
-        requires_grad=False,
-        seq_info=seq_info,
-        attention_mask=None,
-        causal_mask=None,
-        sliding_window=sliding_window,
+        bsz = 1,
+        q_len = 4,
+        kv_seq_len = 4,
+        n_heads = 1,
+        head_dim = 1,
+        requires_grad = False,
+        seq_info = seq_info,
+        attention_mask = None,
+        causal_mask = None,
+        sliding_window = sliding_window,
     )
 
     Q = torch.zeros(1, 1, 4, 1)
     K = torch.zeros_like(Q)
     V = torch.zeros_like(Q)
 
     attention_dispatch.run_attention(
-        config=config,
-        context=context,
-        Q=Q,
-        K=K,
-        V=V,
+        config = config,
+        context = context,
+        Q = Q,
+        K = K,
+        V = V,
     )
 
     assert captured["window"] == sliding_window
@@ -207,10 +228,10 @@ def _fake_flash_varlen(Q, K, V, cu_q, cu_k, max_q, max_k, **kwargs):
     monkeypatch.setattr(attention_dispatch, "HAS_FLASH_ATTENTION", True)
 
     config = attention_dispatch.AttentionConfig(
-        backend=attention_dispatch.FLASH_VARLEN,
-        n_kv_heads=1,
-        n_groups=1,
-        flash_varlen_kwargs={
+        backend = attention_dispatch.FLASH_VARLEN,
+        n_kv_heads = 1,
+        n_groups = 1,
+        flash_varlen_kwargs = {
             "dropout_p": 0.0,
             "softmax_scale": 1.0,
             "causal": True,
@@ -220,29 +241,32 @@ def _fake_flash_varlen(Q, K, V, cu_q, cu_k, max_q, max_k, **kwargs):
     )
 
     context = attention_dispatch.AttentionContext(
-        bsz=1,
-        q_len=4,
-        kv_seq_len=4,
-        n_heads=1,
-        head_dim=2,
-        requires_grad=False,
-        seq_info=seq_info,
-        attention_mask=None,
-        causal_mask=None,
-        sliding_window=sliding_window,
+        bsz = 1,
+        q_len = 4,
+        kv_seq_len = 4,
+        n_heads = 1,
+        head_dim = 2,
+        requires_grad = False,
+        seq_info = seq_info,
+        attention_mask = None,
+        causal_mask = None,
+        sliding_window = sliding_window,
    )
 
     Q = torch.zeros(1, 1, 4, 2)
     K = torch.zeros_like(Q)
     V = torch.zeros_like(Q)
 
     attention_dispatch.run_attention(
-        config=config,
-        context=context,
-        Q=Q,
-        K=K,
-        V=V,
+        config = config,
+        context = context,
+        Q = Q,
+        K = K,
+        V = V,
    )
 
     assert captured["kwargs"]["softcap"] == softcap
     assert captured["kwargs"]["window_size"] == window_tuple
+
+
+"""Unit tests for packed-attention mask helpers with sliding-window logic."""
