 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
 from vllm.spec_decode.target_model_runner import TargetModelRunner
-from vllm.spec_decode.util import (create_sequence_group_output,
+from vllm.spec_decode.util import (Timer, create_sequence_group_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         typical_acceptance_sampler_posterior_threshold,
         typical_acceptance_sampler_posterior_alpha=speculative_config.
         typical_acceptance_sampler_posterior_alpha,
-        disable_logprobs=speculative_config.disable_logprobs)
+        disable_logprobs=speculative_config.disable_logprobs,
+        disable_log_stats=speculative_config.disable_log_stats,
+    )
 
     return spec_decode_worker
 
@@ -116,6 +118,7 @@ def create_worker(
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
+        disable_log_stats: bool,
     ) -> "SpecDecodeWorker":
 
         allow_zero_draft_token_step = True
@@ -171,6 +174,7 @@ def create_worker(
             proposer_worker,
             scorer_worker,
             disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
             disable_by_batch_size=disable_by_batch_size,
             spec_decode_sampler=spec_decode_sampler,
             allow_zero_draft_token_step=allow_zero_draft_token_step)
@@ -180,7 +184,8 @@ def __init__(
         proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
         spec_decode_sampler: SpecDecodeBaseSampler,
-        disable_logprobs: bool,
+        disable_logprobs: bool = False,
+        disable_log_stats: bool = False,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
         allow_zero_draft_token_step: Optional[bool] = True,
@@ -203,6 +208,8 @@ def __init__(
             disable_logprobs: If set to True, token log probabilities will
                 not be output in both the draft worker and the target worker.
                 If set to False, log probabilities will be output by both.
+            disable_log_stats: If set to True, disable periodic printing of
+                speculative stage times.
             disable_by_batch_size: If the batch size is larger than this,
                 disable speculative decoding for new incoming requests.
             metrics_collector: Helper class for collecting metrics; can be set
@@ -240,6 +247,7 @@ def __init__(
         # in the subsequent step.
         self.previous_hidden_states: Optional[HiddenStates] = None
         self._disable_logprobs = disable_logprobs
+        self._disable_log_stats = disable_log_stats
 
     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
@@ -525,28 +533,37 @@ def _run_speculative_decoding_step(
         execute_model_req.previous_hidden_states = self.previous_hidden_states
         self.previous_hidden_states = None
 
-        # Generate proposals using draft worker.
-        proposals = self.proposer_worker.get_spec_proposals(
-            execute_model_req, self._seq_with_bonus_token_in_last_step)
+        with Timer() as proposal_timer:
+            # Generate proposals using draft worker.
+            proposals = self.proposer_worker.get_spec_proposals(
+                execute_model_req, self._seq_with_bonus_token_in_last_step)
 
         if not self._allow_zero_draft_token_step and proposals.no_proposals:
             #TODO: Fix it #5814
             raise RuntimeError("Cannot handle cases where distributed draft "
                                "workers generate no tokens")
 
-        proposal_scores = self.scorer.score_proposals(
-            execute_model_req,
-            proposals,
-        )
-        accepted_token_ids, target_logprobs = self._verify_tokens(
-            execute_model_req.seq_group_metadata_list, proposal_scores,
-            proposals, execute_model_req.num_lookahead_slots)
+        with Timer() as scoring_timer:
+            proposal_scores = self.scorer.score_proposals(
+                execute_model_req,
+                proposals,
+            )
+
+        with Timer() as verification_timer:
+            accepted_token_ids, target_logprobs = self._verify_tokens(
+                execute_model_req.seq_group_metadata_list, proposal_scores,
+                proposals, execute_model_req.num_lookahead_slots)
+
+        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
+                       scoring_timer.elapsed_time_ms,
+                       verification_timer.elapsed_time_ms)
 
         return self._create_output_sampler_list(
             execute_model_req.seq_group_metadata_list,
             accepted_token_ids,
             target_logprobs=target_logprobs,
-            k=execute_model_req.num_lookahead_slots)
+            k=execute_model_req.num_lookahead_slots,
+            stage_times=stage_times)
 
     @nvtx_range("spec_decode_worker._verify_tokens")
     def _verify_tokens(
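The timing changes above only assume that Timer, newly imported from vllm.spec_decode.util, is a context manager exposing an elapsed_time_ms attribute once its with block exits. A minimal sketch of such a timer (the actual utility in the repository may differ in detail):

import time


class Timer:
    """Wall-clock timer usable as a context manager.

    Minimal sketch: only the elapsed_time_ms attribute read by the diff
    above needs to exist after the block finishes.
    """

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time_s = self.end_time - self.start_time
        self.elapsed_time_ms = self.elapsed_time_s * 1000

Note that stage_times reports the proposal stage per proposed token (elapsed time divided by num_lookahead_slots), while scoring and verification are reported per step.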
@@ -645,6 +662,7 @@ def _create_output_sampler_list(
         accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
         target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
         k: int,
+        stage_times: Tuple[float, float, float],
     ) -> List[SamplerOutput]:
         """Given the accepted token ids, create a list of SamplerOutput.
 
@@ -722,8 +740,30 @@ def _create_output_sampler_list(
         if maybe_rejsample_metrics is not None:
             sampler_output_list[
                 0].spec_decode_worker_metrics = maybe_rejsample_metrics
+
+        # Log time spent in each stage periodically.
+        # This is periodic because the rejection sampler emits metrics
+        # periodically.
+        self._maybe_log_stage_times(*stage_times)
+
         return sampler_output_list
 
+    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
+                               scoring_time_ms: float,
+                               verification_time_ms: float) -> None:
+        """Log the speculative stage times. If stat logging is disabled, do
+        nothing.
+        """
+        if self._disable_log_stats:
+            return
+
+        logger.info(
+            "SpecDecodeWorker stage times: "
+            "average_time_per_proposal_tok_ms=%.02f "
+            "scoring_time_ms=%.02f verification_time_ms=%.02f",
+            average_time_per_proposal_tok_ms, scoring_time_ms,
+            verification_time_ms)
+
     def _create_dummy_logprob_lists(
         self,
         batch_size: int,
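With stat logging left enabled (disable_log_stats defaults to False), _maybe_log_stage_times emits one line each time the rejection sampler reports metrics, using the format string added above. The output looks roughly like this, with illustrative numbers and a prefix that depends on the logger configuration:

SpecDecodeWorker stage times: average_time_per_proposal_tok_ms=1.42 scoring_time_ms=9.83 verification_time_ms=0.57

Setting disable_log_stats=True on the speculative config, which is threaded through create_spec_worker and create_worker in the earlier hunks, suppresses this output entirely.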