66import torch
77import torch .nn .functional as F
88
9+ from tests .v1 .sample .utils import create_allowed_token_ids
910from vllm .platforms import current_platform
1011from vllm .v1 .sample .logits_processor import LogitsProcessors
1112from vllm .v1 .sample .metadata import SamplingMetadata
@@ -21,7 +22,9 @@ def rejection_sampler():
2122
2223
2324def create_logits_tensor (
24- output_token_ids : list [list [int ]], vocab_size : int = 100
25+ output_token_ids : list [list [int ]],
26+ vocab_size : int = 100 ,
27+ token_idx_to_override : Optional [int ] = None ,
2528) -> torch .Tensor :
2629 """Helper function to create logits tensor that
2730 will produce desired token ids on argmax"""
@@ -33,15 +36,25 @@ def create_logits_tensor(
3336 for j , token_id in enumerate (tokens ):
3437 logits [start_loc + j , token_id ] = 100.0
3538 start_loc += len (tokens )
39+ if token_idx_to_override :
40+ logits [:, token_idx_to_override ] = 99.0
3641 return logits
3742
3843
3944def create_sampling_metadata (
4045 all_greedy : bool ,
46+ output_token_ids : Optional [list [list [int ]]] = None ,
47+ prompt_token_ids : Optional [torch .Tensor ] = None ,
48+ spec_token_ids : Optional [torch .Tensor ] = None ,
4149 temperature : Optional [torch .Tensor ] = None ,
4250 top_k : Optional [torch .Tensor ] = None ,
4351 top_p : Optional [torch .Tensor ] = None ,
4452 generators : Optional [dict [int , Any ]] = None ,
53+ frequency_penalties : Optional [list [float ]] = None ,
54+ presence_penalties : Optional [list [float ]] = None ,
55+ repetition_penalties : Optional [list [float ]] = None ,
56+ bad_words_token_ids : Optional [dict [int , list [list [int ]]]] = None ,
57+ allowed_token_ids_mask : Optional [torch .Tensor ] = None ,
4558) -> SamplingMetadata :
4659 """Create a v1 sampling metadata object with all_greedy set
4760 to the given value. Either all greedy or all random sampling
@@ -53,6 +66,21 @@ def create_sampling_metadata(
5366 else :
5467 assert temperature is not None
5568
69+ if any ([frequency_penalties , presence_penalties , repetition_penalties ]):
70+ no_penalties = False
71+
72+ assert output_token_ids
73+ assert len (output_token_ids ) > 0
74+
75+ frequency_penalties = torch .tensor (frequency_penalties , device = DEVICE )
76+ presence_penalties = torch .tensor (presence_penalties , device = DEVICE )
77+ repetition_penalties = torch .tensor (repetition_penalties , device = DEVICE )
78+ else :
79+ no_penalties = True
80+ frequency_penalties = torch .tensor ([])
81+ presence_penalties = torch .tensor ([])
82+ repetition_penalties = torch .tensor ([])
83+
5684 return SamplingMetadata (
5785 temperature = temperature ,
5886 all_greedy = all_greedy ,
@@ -61,14 +89,15 @@ def create_sampling_metadata(
6189 top_k = top_k ,
6290 generators = generators ,
6391 max_num_logprobs = 0 ,
64- no_penalties = False ,
65- prompt_token_ids = None ,
66- frequency_penalties = torch .tensor ([]),
67- presence_penalties = torch .tensor ([]),
68- repetition_penalties = torch .tensor ([]),
69- output_token_ids = [],
70- allowed_token_ids_mask = None ,
71- bad_words_token_ids = {},
92+ no_penalties = no_penalties ,
93+ prompt_token_ids = prompt_token_ids ,
94+ frequency_penalties = frequency_penalties ,
95+ presence_penalties = presence_penalties ,
96+ repetition_penalties = repetition_penalties ,
97+ output_token_ids = [] if output_token_ids is None else output_token_ids ,
98+ spec_token_ids = [] if spec_token_ids is None else spec_token_ids ,
99+ allowed_token_ids_mask = allowed_token_ids_mask ,
100+ bad_words_token_ids = {} if bad_words_token_ids is None else bad_words_token_ids ,
72101 logitsprocs = LogitsProcessors (),
73102 )
74103
@@ -611,3 +640,136 @@ def test_top_p(rejection_sampler, top_p):
611640 unmasked_indices = top_p_indices ,
612641 sampling_metadata = sampling_metadata ,
613642 )
643+
644+
645+ ########################### Tests for Logit Processors ###################
def test_frequency_penalties(rejection_sampler):
    """Check greedy rejection sampling when frequency penalties are active.

    Three requests are sampled together; the middle one has no draft tokens.
    Token 15 is passed as ``token_idx_to_override`` so it is the fallback
    candidate whenever the intended target token is rejected.
    """
    draft_tokens = [[1, 1, 1], [], [1, 1, 1]]
    # The last entry of every row is that request's bonus token (1, 7 and 1).
    target_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]]
    request_count = len(draft_tokens)

    target_logits = create_logits_tensor(target_tokens, token_idx_to_override=15)
    device = target_logits.device

    prompt_ids = torch.tensor(
        [[5, 6, 7], [6, 7, 8], [7, 8, 9]],
        device=DEVICE,
    )
    # Only the frequency penalty is non-neutral: presence stays at 0.0 and
    # repetition at 1.0 for every request.
    sampling_metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=draft_tokens,
        prompt_token_ids=prompt_ids,
        frequency_penalties=[1.5, 1.5, 0.7],
        presence_penalties=[0.0] * request_count,
        repetition_penalties=[1.0] * request_count,
    )

    bonus_ids = torch.tensor([row[-1] for row in target_tokens], device=device)
    spec_metadata = SpecDecodeMetadata.make_dummy(draft_tokens, device=device)

    sampled = rejection_sampler(
        spec_metadata,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus_ids,
        sampling_metadata=sampling_metadata,
    )

    # -1 entries pad the positions after a rejection (or after the bonus
    # token for the request that had no draft tokens).
    want = torch.tensor(
        [
            [1, 15, -1, -1],
            [7, -1, -1, -1],
            [1, 1, 15, -1],
        ],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(sampled, want)
681+
682+
def test_bad_words(rejection_sampler):
    """Check greedy rejection sampling when bad-word constraints are set.

    Token 2 is registered as a bad word for requests 0 and 1 only; request 2
    is left unconstrained. Token 15 is passed as ``token_idx_to_override`` so
    it is the fallback candidate when the intended target token is unusable.
    """
    draft_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
    target_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]

    target_logits = create_logits_tensor(target_tokens, token_idx_to_override=15)
    device = target_logits.device

    # Request index -> list of banned token sequences. Index 2 is omitted on
    # purpose so that no bad words apply to the last request.
    banned = {
        0: [[2]],
        1: [[2]],
    }
    sampling_metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=draft_tokens,
        bad_words_token_ids=banned,
    )

    bonus_ids = torch.tensor([row[-1] for row in target_tokens], device=device)
    spec_metadata = SpecDecodeMetadata.make_dummy(draft_tokens, device=device)

    sampled = rejection_sampler(
        spec_metadata,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus_ids,
        sampling_metadata=sampling_metadata,
    )

    # Expected rows: request 0 emits 15 instead of the drafted 2 and stops;
    # request 1 drafted 15 at that position and runs through to the bonus
    # token; request 2 is emitted unchanged.
    want = torch.tensor(
        [
            [1, 15, -1, -1],
            [1, 15, 3, 4],
            [1, 2, 3, 4],
        ],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(sampled, want)
727+
728+
def test_allowed_token_ids(rejection_sampler):
    """Check greedy rejection sampling with an allowed-token-ids mask.

    The mask comes from ``create_allowed_token_ids``. Token 15 is passed as
    ``token_idx_to_override`` so it is the token the sampler falls back to
    when a draft token is rejected.
    """
    draft_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]]
    target_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]]

    # Tokens that are NOT allowed, per request (as noted by the test author;
    # the mask helper itself is defined outside this file):
    #   request 0: 0-4
    #   request 1: 1-5
    #   request 2: 2-6
    allowed_count = 5

    target_logits = create_logits_tensor(target_tokens, token_idx_to_override=15)
    device = target_logits.device
    _, vocab_size = target_logits.shape

    allowed_mask = create_allowed_token_ids(
        batch_size=len(target_tokens),
        vocab_size=vocab_size,
        num_allowed_token_ids=allowed_count,
        device=device,
    )
    sampling_metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[], [], []],
        spec_token_ids=draft_tokens,
        allowed_token_ids_mask=allowed_mask,
    )

    bonus_ids = torch.tensor([row[-1] for row in target_tokens], device=device)
    spec_metadata = SpecDecodeMetadata.make_dummy(draft_tokens, device=device)

    sampled = rejection_sampler(
        spec_metadata,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus_ids,
        sampling_metadata=sampling_metadata,
    )

    # -1 entries pad the positions that follow a rejected draft token.
    want = torch.tensor(
        [
            [15, -1, -1, -1],
            [10, 5, 10, -1],
            [7, 10, 12, 5],
        ],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(sampled, want)
0 commit comments