@@ -57,6 +57,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
5757 return draft_token_ids
5858
5959
def get_acceptance_sampler(
    posterior_threshold: float = 0.03,
    posterior_alpha: float = 0.9,
    disable_bonus_tokens: bool = False,
    strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
    """Construct a TypicalAcceptanceSampler with the given settings.

    All arguments are forwarded unchanged to the TypicalAcceptanceSampler
    constructor; their semantics are defined by that class.

    Args:
        posterior_threshold: Forwarded to the sampler's constructor
            (presumably the minimum posterior probability for accepting a
            draft token — confirm against TypicalAcceptanceSampler).
        posterior_alpha: Forwarded to the sampler's constructor.
        disable_bonus_tokens: Forwarded to the sampler's constructor.
        strict_mode: Forwarded to the sampler's constructor (the tests use
            it to enable extra input validation, e.g. vocab-range checks).

    Returns:
        A newly constructed TypicalAcceptanceSampler.
    """
    # Forward by keyword: two floats followed by two bools are easy to
    # transpose, and keyword binding stays correct even if the
    # constructor's parameter order changes.
    return TypicalAcceptanceSampler(
        posterior_threshold=posterior_threshold,
        posterior_alpha=posterior_alpha,
        disable_bonus_tokens=disable_bonus_tokens,
        strict_mode=strict_mode)
71+
72+
6073@pytest .mark .parametrize ("k" , list (range (1 , 6 )))
6174@pytest .mark .parametrize ("vocab_size" , [30_000 , 50_000 ])
6275@pytest .mark .parametrize ("batch_size" , list (range (1 , 32 )))
@@ -69,7 +82,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
6982 different combinations of k, vocab_size, batch_size and num devices.
7083 """
7184 torch .set_default_device (device )
72- typical_acceptance_sampler = TypicalAcceptanceSampler ()
85+ typical_acceptance_sampler = get_acceptance_sampler ()
7386 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
7487 target_probs = torch .rand (batch_size , k , vocab_size , dtype = torch .float32 )
7588 bonus_token_ids = torch .randint (low = 0 ,
@@ -81,7 +94,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
8194 size = (batch_size , k ),
8295 dtype = torch .int64 )
8396 # Verify that sampling succeeds for all cases.
84- typical_acceptance_sampler (target_probs , bonus_token_ids , draft_token_ids )
97+ typical_acceptance_sampler (target_probs ,
98+ bonus_token_ids ,
99+ draft_probs = None ,
100+ draft_token_ids = draft_token_ids )
85101
86102
87103@pytest .mark .parametrize ("above_or_below_vocab_range" , ["above" , "below" ])
@@ -99,7 +115,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
99115 batch_size = 5
100116 vocab_size = 30_000
101117 torch .set_default_device (device )
102- typical_acceptance_sampler = TypicalAcceptanceSampler (strict_mode = True )
118+ typical_acceptance_sampler = get_acceptance_sampler (strict_mode = True )
103119 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
104120 target_probs = torch .rand (batch_size , k , vocab_size , dtype = torch .float32 )
105121 bonus_token_ids = torch .randint (low = 0 ,
@@ -130,8 +146,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
130146 oob_token_ids [0 ][0 ] = rogue_token_id
131147
132148 with pytest .raises (AssertionError ):
133- typical_acceptance_sampler (target_probs , bonus_token_ids ,
134- draft_token_ids )
149+ typical_acceptance_sampler (target_probs ,
150+ bonus_token_ids ,
151+ draft_probs = None ,
152+ draft_token_ids = draft_token_ids )
135153
136154
137155@pytest .mark .parametrize ("seed" , list (range (10 )))
@@ -156,7 +174,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
156174 batch_size = 5
157175 vocab_size = 30_000
158176 torch .set_default_device (device )
159- typical_acceptance_sampler = TypicalAcceptanceSampler (
177+ typical_acceptance_sampler = get_acceptance_sampler (
160178 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
161179 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
162180 target_probs = torch .rand (batch_size , k , vocab_size , dtype = torch .float32 )
@@ -168,9 +186,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
168186 high = vocab_size ,
169187 size = (batch_size , 1 ),
170188 dtype = torch .int64 )
171- output_token_ids = typical_acceptance_sampler (target_probs ,
172- bonus_token_ids ,
173- draft_token_ids )
189+ output_token_ids = typical_acceptance_sampler (
190+ target_probs ,
191+ bonus_token_ids ,
192+ draft_probs = None ,
193+ draft_token_ids = draft_token_ids )
174194 # We are using a uniform target probability distribution.
175195 # For a uniform distribution the entropy is very high and it
176196 # should lead to all draft tokens being accepted. Verify that.
@@ -208,7 +228,7 @@ def test_temperature_zero_target_distribution(seed: int,
208228 vocab_size = 30_000
209229 torch .set_default_device (device )
210230
211- typical_acceptance_sampler = TypicalAcceptanceSampler (
231+ typical_acceptance_sampler = get_acceptance_sampler (
212232 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
213233 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
214234 # Simulate temperature 0 probability distribution for target probabilities
@@ -229,9 +249,11 @@ def test_temperature_zero_target_distribution(seed: int,
229249 # 1.0 tokens in the target distribution we will reject all of them and
230250 # fallback to the greedy sampling for selecting 1 token for each sequence.
231251 # Verify the same.
232- output_token_ids = typical_acceptance_sampler (target_probs ,
233- bonus_token_ids ,
234- draft_token_ids )
252+ output_token_ids = typical_acceptance_sampler (
253+ target_probs ,
254+ bonus_token_ids ,
255+ draft_probs = None ,
256+ draft_token_ids = draft_token_ids )
235257 assert output_token_ids .shape [0 ] == batch_size
236258 assert output_token_ids .shape [1 ] == (k + 1 )
237259 assert torch .all (output_token_ids [:, - 1 ] == - 1 )
@@ -266,7 +288,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
266288 batch_size = 4
267289 vocab_size = 30_000
268290 torch .set_default_device (device )
269- typical_acceptance_sampler = TypicalAcceptanceSampler (
291+ typical_acceptance_sampler = get_acceptance_sampler (
270292 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
271293 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
272294 # For sequences 0 and 2 set the distribution to a temperature
@@ -282,9 +304,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
282304 high = vocab_size ,
283305 size = (batch_size , 1 ),
284306 dtype = torch .int64 )
285- output_token_ids = typical_acceptance_sampler (target_probs ,
286- bonus_token_ids ,
287- draft_token_ids )
307+ output_token_ids = typical_acceptance_sampler (
308+ target_probs ,
309+ bonus_token_ids ,
310+ draft_probs = None ,
311+ draft_token_ids = draft_token_ids )
288312 # verify the shape of output_token_ids
289313 assert output_token_ids .shape [0 ] == batch_size
290314 assert output_token_ids .shape [1 ] == (k + 1 )
@@ -331,7 +355,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
331355 batch_size = 1
332356 vocab_size = 30_000
333357 torch .set_default_device (device )
334- typical_acceptance_sampler = TypicalAcceptanceSampler (
358+ typical_acceptance_sampler = get_acceptance_sampler (
335359 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
336360 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
337361 # Create a temperature zero target probability distribution and ensure
@@ -344,9 +368,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
344368 high = vocab_size ,
345369 size = (batch_size , 1 ),
346370 dtype = torch .int64 )
347- output_token_ids = typical_acceptance_sampler (target_probs ,
348- bonus_token_ids ,
349- draft_token_ids )
371+ output_token_ids = typical_acceptance_sampler (
372+ target_probs ,
373+ bonus_token_ids ,
374+ draft_probs = None ,
375+ draft_token_ids = draft_token_ids )
350376 assert output_token_ids .shape [0 ] == batch_size
351377 assert output_token_ids .shape [1 ] == (k + 1 )
352378 assert torch .all (output_token_ids [:, 0 :- 1 ] == draft_token_ids )
@@ -362,9 +388,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
362388 batch_size , k , vocab_size , zero_temperature_token_ids )
363389 draft_token_ids = torch .cat (
364390 (draft_token_ids [:, :2 ], draft_token_ids_to_replace [:, - 3 :]), dim = 1 )
365- output_token_ids = typical_acceptance_sampler (target_probs ,
366- bonus_token_ids ,
367- draft_token_ids )
391+ output_token_ids = typical_acceptance_sampler (
392+ target_probs ,
393+ bonus_token_ids ,
394+ draft_probs = None ,
395+ draft_token_ids = draft_token_ids )
368396 assert output_token_ids .shape [0 ] == batch_size
369397 assert output_token_ids .shape [1 ] == (k + 1 )
370398 assert torch .all (output_token_ids [:, :2 ] == draft_token_ids [:, :2 ])
@@ -389,7 +417,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
389417 batch_size = 1
390418 vocab_size = 30_000
391419 torch .set_default_device (device )
392- typical_acceptance_sampler = TypicalAcceptanceSampler (
420+ typical_acceptance_sampler = get_acceptance_sampler (
393421 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
394422 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
395423 # Simulate temperature 0 probability distribution for target
@@ -407,9 +435,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
407435 high = vocab_size ,
408436 size = (batch_size , 1 ),
409437 dtype = torch .int64 )
410- output_token_ids = typical_acceptance_sampler (target_probs ,
411- bonus_token_ids ,
412- draft_token_ids )
438+ output_token_ids = typical_acceptance_sampler (
439+ target_probs ,
440+ bonus_token_ids ,
441+ draft_probs = None ,
442+ draft_token_ids = draft_token_ids )
413443 assert output_token_ids .shape [0 ] == batch_size
414444 assert output_token_ids .shape [1 ] == (k + 1 )
415445 assert torch .all (output_token_ids [:, 1 :- 1 ] == - 1 )
@@ -423,9 +453,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
423453 posterior_threshold = 0.0 ,
424454 posterior_alpha = 0.0 )
425455 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
426- output_token_ids = typical_acceptance_sampler (target_probs ,
427- bonus_token_ids ,
428- draft_token_ids )
456+ output_token_ids = typical_acceptance_sampler (
457+ target_probs ,
458+ bonus_token_ids ,
459+ draft_probs = None ,
460+ draft_token_ids = draft_token_ids )
429461 assert output_token_ids .shape [0 ] == batch_size
430462 assert output_token_ids .shape [1 ] == (k + 1 )
431463 assert torch .all (output_token_ids [:, 0 :- 1 ] == draft_token_ids )
@@ -456,7 +488,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
456488 batch_size = 5
457489 vocab_size = 30_000
458490 torch .set_default_device (device )
459- typical_acceptance_sampler = TypicalAcceptanceSampler (
491+ typical_acceptance_sampler = get_acceptance_sampler (
460492 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
461493 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
462494 target_probs = torch .rand (batch_size , k , vocab_size , dtype = torch .float32 )
0 commit comments