# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Unit tests for packed-attention mask helpers with sliding-window logic."""

import math

import torch

# packing_utils and attention_dispatch come from the Unsloth source tree (their import
# lines are not shown in this excerpt).

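# These tests cover mask construction for "packed" batches, where several independent
# sequences are concatenated along one row: the causal mask is block-diagonal (one
# causal block per sequence), and an optional sliding window additionally masks keys
# that fall too far behind the query.


# Helper: build the packed-sequence metadata the mask builders consume, i.e. the
# cumulative sequence-length offsets (a leading zero plus the running sum of lengths)
# and the maximum individual sequence length.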
def _make_seq_info(lengths):
    lengths = torch.tensor(lengths, dtype=torch.int32)
    cu = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32),
            torch.cumsum(lengths, dim=0, dtype=torch.int32),
        ]
    )
    max_len = int(lengths.max().item())
@@ -24,15 +39,15 @@ def test_sdpa_packed_attention_mask_sliding_window():
    seq_info = _make_seq_info([5, 3])
    mask = packing_utils.build_sdpa_packed_attention_mask(
        seq_info,
        dtype=torch.float32,
        device=torch.device("cpu"),
        sliding_window=3,
    )

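    # Lengths [5, 3] pack into 8 tokens, so the mask covers the full (1, 1, 8, 8) grid.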
    assert mask.shape == (1, 1, 8, 8)

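    # Within the first sequence's 5x5 block, strictly-future keys are masked (causal),
    # and a sliding window of 3 additionally masks keys more than 2 positions behind
    # the query (e.g. query 3 vs key 0, query 4 vs key 1).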
    block_first = mask[0, 0, :5, :5]
    upper = torch.triu(torch.ones_like(block_first), diagonal=1).bool()
    assert torch.all(block_first[upper] == float("-inf"))
    assert block_first[3, 0].item() == float("-inf")
    assert block_first[4, 1].item() == float("-inf")
@@ -42,7 +57,7 @@ def test_sdpa_packed_attention_mask_sliding_window():

def test_xformers_block_mask_sliding_window(monkeypatch):
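    # _FakeMask stands in for the xformers block-diagonal mask class, letting the test
    # observe the window size forwarded to make_local_attention without needing xformers.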
    class _FakeMask:
        def __init__(self, lengths, window=None):
            self.lengths = lengths
            self.window = window

@@ -51,14 +66,14 @@ def from_seqlens(cls, lengths):
        @classmethod
        def from_seqlens(cls, lengths):
            return cls(tuple(lengths))

        def make_local_attention(self, window_size):
            return _FakeMask(self.lengths, window=window_size)

    monkeypatch.setattr(packing_utils, "_XFormersBlockMask", _FakeMask, raising=False)

    seq_info = _make_seq_info([4, 4])
    mask = packing_utils.build_xformers_block_causal_mask(
        seq_info,
        sliding_window=2,
    )

    assert isinstance(mask, _FakeMask)
@@ -72,13 +87,13 @@ def test_run_attention_sdpa_passes_sliding_window(monkeypatch):
    original_builder = attention_dispatch.build_sdpa_packed_attention_mask
    captured = {}

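    # Wrap the real SDPA mask builder so the sliding_window value that run_attention
    # forwards to it gets recorded.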
    def _capture_builder(seq_info_arg, *, dtype, device, sliding_window=None):
        captured["window"] = sliding_window
        return original_builder(
            seq_info_arg,
            dtype=dtype,
            device=device,
            sliding_window=sliding_window,
        )

    monkeypatch.setattr(
@@ -94,34 +109,34 @@ def _fake_sdpa(Q, K, V, **kwargs):
    monkeypatch.setattr(attention_dispatch, "scaled_dot_product_attention", _fake_sdpa)

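    # A minimal single-head config/context with tiny zero tensors is enough to drive
    # the SDPA dispatch path.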
    config = attention_dispatch.AttentionConfig(
        backend=attention_dispatch.SDPA,
        n_kv_heads=1,
        n_groups=1,
    )

    context = attention_dispatch.AttentionContext(
        bsz=1,
        q_len=5,
        kv_seq_len=5,
        n_heads=1,
        head_dim=1,
        requires_grad=False,
        seq_info=seq_info,
        attention_mask=None,
        causal_mask=None,
        sliding_window=sliding_window,
    )

    Q = torch.zeros(1, 1, 5, 1)
    K = torch.zeros_like(Q)
    V = torch.zeros_like(Q)

    attention_dispatch.run_attention(
        config=config,
        context=context,
        Q=Q,
        K=K,
        V=V,
    )

    assert captured["window"] == sliding_window
@@ -139,48 +154,54 @@ class _FakeBias:
    captured = {}

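    # The fakes below record what run_attention forwards on the xformers path: the
    # sliding window and base mask given to the bias builder, and the attn_bias that
    # reaches the attention kernel.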
    def _fake_builder(seq_info_arg, *, sliding_window=None, base_mask=None):
        captured["window"] = sliding_window
        captured["base"] = base_mask
        return _FakeBias()

    def _fake_attention(Q, K, V, attn_bias=None, **_):
        captured["bias"] = attn_bias
        return torch.zeros_like(Q)

    monkeypatch.setattr(
        attention_dispatch, "build_xformers_block_causal_mask", _fake_builder
    )
    monkeypatch.setattr(
        attention_dispatch, "xformers_attention", _fake_attention, raising=False
    )
    monkeypatch.setattr(
        attention_dispatch, "XFORMERS_BLOCK_DIAG_CLS", _FakeBias, raising=False
    )

    config = attention_dispatch.AttentionConfig(
        backend=attention_dispatch.XFORMERS,
        n_kv_heads=1,
        n_groups=1,
    )

    context = attention_dispatch.AttentionContext(
        bsz=1,
        q_len=4,
        kv_seq_len=4,
        n_heads=1,
        head_dim=1,
        requires_grad=False,
        seq_info=seq_info,
        attention_mask=None,
        causal_mask=None,
        sliding_window=sliding_window,
    )

    Q = torch.zeros(1, 1, 4, 1)
    K = torch.zeros_like(Q)
    V = torch.zeros_like(Q)

    attention_dispatch.run_attention(
        config=config,
        context=context,
        Q=Q,
        K=K,
        V=V,
    )

    assert captured["window"] == sliding_window
@@ -207,10 +228,10 @@ def _fake_flash_varlen(Q, K, V, cu_q, cu_k, max_q, max_k, **kwargs):
    monkeypatch.setattr(attention_dispatch, "HAS_FLASH_ATTENTION", True)

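    # flash_varlen_kwargs should reach the (faked) varlen kernel unchanged; the final
    # assertions check that softcap and window_size survive the dispatch.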
    config = attention_dispatch.AttentionConfig(
        backend=attention_dispatch.FLASH_VARLEN,
        n_kv_heads=1,
        n_groups=1,
        flash_varlen_kwargs={
            "dropout_p": 0.0,
            "softmax_scale": 1.0,
            "causal": True,
@@ -220,29 +241,32 @@ def _fake_flash_varlen(Q, K, V, cu_q, cu_k, max_q, max_k, **kwargs):
    )

    context = attention_dispatch.AttentionContext(
        bsz=1,
        q_len=4,
        kv_seq_len=4,
        n_heads=1,
        head_dim=2,
        requires_grad=False,
        seq_info=seq_info,
        attention_mask=None,
        causal_mask=None,
        sliding_window=sliding_window,
    )

    Q = torch.zeros(1, 1, 4, 2)
    K = torch.zeros_like(Q)
    V = torch.zeros_like(Q)

    attention_dispatch.run_attention(
        config=config,
        context=context,
        Q=Q,
        K=K,
        V=V,
    )

    assert captured["kwargs"]["softcap"] == softcap
    assert captured["kwargs"]["window_size"] == window_tuple