@@ -1,13 +1,16 @@
 from itertools import cycle
-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 import pytest
 
 from vllm import LLM, SamplingParams
 from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import PromptLogprobs, SampleLogprobs
 
 from ...conftest import cleanup
-from ...models.utils import check_logprobs_close, check_outputs_equal
+from ...models.utils import (TokensTextLogprobs,
+                             TokensTextLogprobsPromptLogprobs,
+                             check_logprobs_close, check_outputs_equal)
 from ...utils import RemoteOpenAIServer
 
 PROMPTS = [
@@ -81,45 +84,77 @@ def get_output_from_llm_generator(
     return tokens, token_ids, acceptance_rate
 
 
-def run_logprob_correctness_test(vllm_runner,
-                                 common_llm_kwargs,
-                                 per_test_common_llm_kwargs,
-                                 baseline_llm_kwargs,
-                                 test_llm_kwargs,
-                                 batch_size: int,
-                                 max_output_len: int,
-                                 seed: Optional[int] = 0,
-                                 temperature: float = 0.0,
-                                 logprobs: int = 1):
-    org_args = {
-        **common_llm_kwargs,
-        **per_test_common_llm_kwargs,
-        **baseline_llm_kwargs,
-    }
-
-    sd_args = {
-        **common_llm_kwargs,
-        **per_test_common_llm_kwargs,
-        **test_llm_kwargs,
-    }
-
-    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
-
-    sampling_params = SamplingParams(temperature=temperature,
-                                     max_tokens=max_output_len,
-                                     seed=seed,
-                                     logprobs=logprobs)
-
-    with vllm_runner(**org_args) as vllm_model:
-        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-    with vllm_runner(**sd_args) as vllm_model:
-        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-    check_logprobs_close(outputs_0_lst=org_outputs,
-                         outputs_1_lst=sd_outputs,
-                         name_0="org",
-                         name_1="sd")
+def check_logprobs_correctness(
+    spec_outputs: Sequence[Union[TokensTextLogprobs,
+                                 TokensTextLogprobsPromptLogprobs]],
+    baseline_outputs: Sequence[Union[TokensTextLogprobs,
+                                     TokensTextLogprobsPromptLogprobs]],
+    disable_logprobs: bool = False,
+):
94+ """Compare sampled and prompt logprobs between baseline and spec decoding
95+ """
+    if not disable_logprobs:
+        return check_logprobs_close(
+            outputs_0_lst=baseline_outputs,
+            outputs_1_lst=spec_outputs,
+            name_0="org",
+            name_1="sd",
+        )
+
+    # Check correctness when disable_logprobs == True
+    for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
+        # Check generated token logprobs.
+        spec_logprobs = spec_output[2]
+        baseline_logprobs = baseline_output[2]
+        _check_logprobs_when_output_disabled(spec_logprobs,
+                                             baseline_logprobs,
+                                             is_prompt_logprobs=False)
+
+        # Check prompt logprobs too, if they exist
+        if len(baseline_output) == 4:
+            assert len(spec_output) == 4
+            spec_prompt_logprobs = spec_output[3]
+            baseline_prompt_logprobs = baseline_output[3]
+            _check_logprobs_when_output_disabled(spec_prompt_logprobs,
+                                                 baseline_prompt_logprobs,
+                                                 is_prompt_logprobs=True)
+
+
+def _check_logprobs_when_output_disabled(
+    spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+    baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+    is_prompt_logprobs: bool = False,
+):
+    # Prompt logprobs are optional
+    if is_prompt_logprobs and baseline_logprobs is None:
+        assert spec_logprobs is None
+        return
+
+    assert spec_logprobs is not None
+    assert baseline_logprobs is not None
+    assert len(spec_logprobs) == len(baseline_logprobs)
+
+    # For each generated position of the sequence.
+    for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
+            zip(spec_logprobs, baseline_logprobs)):
+
+        # First prompt logprob is expected to be None
+        if is_prompt_logprobs and baseline_pos_logprobs is None:
+            assert spec_pos_logprobs is None
+            assert pos == 0
+            continue
+
+        assert spec_pos_logprobs is not None
+        assert baseline_pos_logprobs is not None
+
+        # When logprobs are disabled, a single logprob is returned with dummy
+        # rank/score values, but its token id must appear in the baseline's.
+        assert len(spec_pos_logprobs) == 1
+        (spec_pos_logprob_token_id,
+         spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
+        assert spec_pos_logprob.rank == -1
+        assert spec_pos_logprob.logprob == 0.0
+        assert spec_pos_logprob_token_id in baseline_pos_logprobs
 
 
 def run_equality_correctness_test(
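For illustration (not part of the diff): a minimal sketch of the per-position logprob dicts the new helpers consume, assuming vllm.sequence.Logprob (a dataclass carrying logprob/rank fields) and made-up token ids. With logprobs disabled, spec decode emits one dummy entry per position (rank == -1, logprob == 0.0) whose token id must appear among the baseline's entries:

# Illustrative sketch only -- token ids and scores below are invented.
from vllm.sequence import Logprob

baseline_pos_logprobs = {
    464: Logprob(logprob=-0.12, rank=1),  # baseline top-1 token
    1212: Logprob(logprob=-2.31, rank=2),
}
# Dummy entry emitted when logprobs are disabled during spec decoding.
spec_pos_logprobs = {464: Logprob(logprob=0.0, rank=-1)}

# One-position sequences pass the disabled-logprobs checks above.
_check_logprobs_when_output_disabled([spec_pos_logprobs],
                                     [baseline_pos_logprobs],
                                     is_prompt_logprobs=False)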
@@ -135,7 +170,10 @@ def run_equality_correctness_test(
         disable_seed: bool = False,
         ignore_eos: bool = True,
         ensure_all_accepted: bool = False,
-        expected_acceptance_rate: Optional[float] = None):
+        expected_acceptance_rate: Optional[float] = None,
+        logprobs: Optional[int] = None,
+        prompt_logprobs: Optional[int] = None,
+        disable_logprobs: bool = False):
 
     org_args = {
         **common_llm_kwargs,
@@ -157,10 +195,12 @@ def run_equality_correctness_test(
     sampling_params = SamplingParams(temperature=temperature,
                                      max_tokens=max_output_len,
                                      seed=seed,
-                                     ignore_eos=ignore_eos)
+                                     ignore_eos=ignore_eos,
+                                     logprobs=logprobs,
+                                     prompt_logprobs=prompt_logprobs)
 
     with vllm_runner(**org_args) as vllm_model:
-        org_outputs = vllm_model.generate(prompts, sampling_params)
+        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
 
     with vllm_runner(**sd_args) as vllm_model:
         if ensure_all_accepted or expected_acceptance_rate is not None:
@@ -169,7 +209,7 @@ def run_equality_correctness_test(
                 'prometheus']
             stat_logger.local_interval = -100
 
-        sd_outputs = vllm_model.generate(prompts, sampling_params)
+        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
 
     if ensure_all_accepted or expected_acceptance_rate is not None:
         acceptance_rate = (stat_logger.metrics.
@@ -185,11 +225,16 @@ def run_equality_correctness_test(
     if expected_acceptance_rate is not None:
         assert acceptance_rate >= expected_acceptance_rate - 1e-2
 
-    check_outputs_equal(outputs_0_lst=org_outputs,
-                        outputs_1_lst=sd_outputs,
+    # Only compare the generated token ids and text, not the logprobs.
+    check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
+                        outputs_1_lst=[out[0:2] for out in sd_outputs],
                         name_0="org",
                         name_1="sd")
 
+    # Check logprobs if requested
+    if logprobs is not None or prompt_logprobs is not None:
+        check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
+
 
 def run_equality_correctness_test_tp(model,
                                      common_llm_kwargs,
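For illustration (not part of the diff): a hypothetical test invocation exercising the new logprob path. Model names, speculative kwargs, and the disable_logprobs_during_spec_decoding flag are placeholders from typical spec decode tests, not taken from this change:

# Hypothetical usage sketch -- all kwargs below are illustrative.
def test_spec_decode_logprobs(vllm_runner):
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs={"model_name": "JackFram/llama-68m"},
        per_test_common_llm_kwargs={},
        baseline_llm_kwargs={},
        test_llm_kwargs={
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
            # Exercise the dummy-logprob path checked above.
            "disable_logprobs_during_spec_decoding": True,
        },
        batch_size=2,
        max_output_len=32,
        logprobs=1,
        disable_logprobs=True,
    )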