From 894ea25e5837bc7d6e5ffad73a689453e66b1722 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 30 Jul 2024 15:26:56 -0700 Subject: [PATCH 1/6] spec timers --- vllm/spec_decode/spec_decode_worker.py | 47 ++++++++++++++++++++------ 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 98960b88f719..0b60de8d3a25 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,6 +1,7 @@ from collections import defaultdict from functools import cached_property from typing import Any, Dict, List, Optional, Set, Tuple +import time import torch @@ -522,28 +523,36 @@ def _run_speculative_decoding_step( execute_model_req.previous_hidden_states = self.previous_hidden_states self.previous_hidden_states = None - # Generate proposals using draft worker. - proposals = self.proposer_worker.get_spec_proposals( - execute_model_req, self._seq_with_bonus_token_in_last_step) + with Timer() as proposal_timer: + # Generate proposals using draft worker. + proposals = self.proposer_worker.get_spec_proposals( + execute_model_req, self._seq_with_bonus_token_in_last_step) if not self._allow_zero_draft_token_step and proposals.no_proposals: #TODO: Fix it #5814 raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") + + with Timer() as scoring_timer: + proposal_scores = self.scorer.score_proposals( + execute_model_req, + proposals, + ) + + with Timer() as verification_timer: + accepted_token_ids, target_logprobs = self._verify_tokens( + execute_model_req.seq_group_metadata_list, proposal_scores, + proposals, execute_model_req.num_lookahead_slots) - proposal_scores = self.scorer.score_proposals( - execute_model_req, - proposals, - ) - accepted_token_ids, target_logprobs = self._verify_tokens( - execute_model_req.seq_group_metadata_list, proposal_scores, - proposals, execute_model_req.num_lookahead_slots) + stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots, + scoring_timer.elapsed_time_ms, verification_timer.elapsed_time_ms) return self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, accepted_token_ids, target_logprobs=target_logprobs, - k=execute_model_req.num_lookahead_slots) + k=execute_model_req.num_lookahead_slots, + stage_times=stage_times) @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( @@ -648,6 +657,7 @@ def _create_output_sampler_list( accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] k: int, + stage_times: Tuple[float, float, float], ) -> List[SamplerOutput]: """Given the accepted token ids, create a list of SamplerOutput. 
@@ -725,6 +735,11 @@ def _create_output_sampler_list( if maybe_rejsample_metrics is not None: sampler_output_list[ 0].spec_decode_worker_metrics = maybe_rejsample_metrics + + (average_time_per_proposal_tok_ms, scoring_time_ms, + verification_time_ms) = stage_times + logger.info( + f"SpecDecodeWorker stage times: {average_time_per_proposal_tok_ms=:.02f} {scoring_time_ms=:.02f} {verification_time_ms=:.02f}") return sampler_output_list def _create_dummy_logprob_lists( @@ -912,3 +927,13 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes)) return new_num_gpu_blocks + +class Timer: + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.end_time = time.time() + self.elapsed_time_s = self.end_time - self.start_time + self.elapsed_time_ms = self.elapsed_time_s * 1000 From f1f993c86e6f59027f9eef7b76f6d83f943e6995 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 30 Jul 2024 16:15:10 -0700 Subject: [PATCH 2/6] lint --- vllm/spec_decode/spec_decode_worker.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 0b60de8d3a25..8781c32ef3dd 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,7 +1,7 @@ +import time from collections import defaultdict from functools import cached_property from typing import Any, Dict, List, Optional, Set, Tuple -import time import torch @@ -532,20 +532,21 @@ def _run_speculative_decoding_step( #TODO: Fix it #5814 raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") - + with Timer() as scoring_timer: proposal_scores = self.scorer.score_proposals( execute_model_req, proposals, ) - + with Timer() as verification_timer: accepted_token_ids, target_logprobs = self._verify_tokens( execute_model_req.seq_group_metadata_list, proposal_scores, proposals, execute_model_req.num_lookahead_slots) stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots, - scoring_timer.elapsed_time_ms, verification_timer.elapsed_time_ms) + scoring_timer.elapsed_time_ms, + verification_timer.elapsed_time_ms) return self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, @@ -737,9 +738,13 @@ def _create_output_sampler_list( 0].spec_decode_worker_metrics = maybe_rejsample_metrics (average_time_per_proposal_tok_ms, scoring_time_ms, - verification_time_ms) = stage_times + verification_time_ms) = stage_times logger.info( - f"SpecDecodeWorker stage times: {average_time_per_proposal_tok_ms=:.02f} {scoring_time_ms=:.02f} {verification_time_ms=:.02f}") + "SpecDecodeWorker stage times: " + "average_time_per_proposal_tok_ms=%.02f " + "scoring_time_ms=%.02f verification_time_ms=%.02f", + average_time_per_proposal_tok_ms, scoring_time_ms, + verification_time_ms) return sampler_output_list def _create_dummy_logprob_lists( @@ -928,7 +933,9 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, return new_num_gpu_blocks + class Timer: + def __enter__(self): self.start_time = time.time() return self From a0fd1341ae61c79ed75ed84a539decba7f2d27d0 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 30 Jul 2024 16:28:17 -0700 Subject: [PATCH 3/6] PR feedback --- vllm/spec_decode/spec_decode_worker.py | 15 +-------------- vllm/spec_decode/util.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 
14 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 8781c32ef3dd..e20e1bccb1f1 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,4 +1,3 @@ -import time from collections import defaultdict from functools import cached_property from typing import Any, Dict, List, Optional, Set, Tuple @@ -28,7 +27,7 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.target_model_runner import TargetModelRunner -from vllm.spec_decode.util import (create_sequence_group_output, +from vllm.spec_decode.util import (Timer, create_sequence_group_output, get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) @@ -932,15 +931,3 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes)) return new_num_gpu_blocks - - -class Timer: - - def __enter__(self): - self.start_time = time.time() - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.end_time = time.time() - self.elapsed_time_s = self.end_time - self.start_time - self.elapsed_time_ms = self.elapsed_time_s * 1000 diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index ade546eef264..c6223a97dba1 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -1,3 +1,4 @@ +import time from contextlib import contextmanager from typing import Dict, List, Optional, Tuple @@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs): yield finally: torch.cuda.nvtx.range_pop() + + +class Timer: + """Basic timer context manager for measuring wall-clock time. + """ + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.end_time = time.time() + self.elapsed_time_s = self.end_time - self.start_time + self.elapsed_time_ms = self.elapsed_time_s * 1000 From 09878a6b56e07850b52d83d01e58ede3dc7409ad Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 30 Jul 2024 16:32:56 -0700 Subject: [PATCH 4/6] comment --- vllm/spec_decode/spec_decode_worker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index e20e1bccb1f1..fd76dbc76733 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -738,6 +738,10 @@ def _create_output_sampler_list( (average_time_per_proposal_tok_ms, scoring_time_ms, verification_time_ms) = stage_times + + # Log time spent in each stage periodically. + # This is periodic because the rejection sampler emits metrics + # periodically.
logger.info( "SpecDecodeWorker stage times: " "average_time_per_proposal_tok_ms=%.02f " From 7469d02e29ea14d23a0b75cceed604c388fd9919 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 4 Aug 2024 10:18:37 -0700 Subject: [PATCH 5/6] disable log stats --- vllm/config.py | 8 +++++- vllm/engine/arg_utils.py | 1 + vllm/spec_decode/spec_decode_worker.py | 37 +++++++++++++++++++------- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 92fde449b43f..e08f0b7729b7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -899,6 +899,7 @@ def maybe_create_spec_config( speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, + disable_log_stats: bool, speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], @@ -1087,7 +1088,8 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=\ typical_acceptance_sampler_posterior_alpha, - disable_logprobs=disable_logprobs + disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, ) @staticmethod @@ -1173,6 +1175,7 @@ def __init__( typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, + disable_log_stats: bool, ): """Create a SpeculativeConfig object. @@ -1205,6 +1208,8 @@ def __init__( sampling, target sampling, and after accepted tokens are determined. If set to False, log probabilities will be returned. + disable_log_stats: Whether to disable periodic printing of stage + times in speculative decoding. """ self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config @@ -1219,6 +1224,7 @@ def __init__( self.typical_acceptance_sampler_posterior_alpha = \ typical_acceptance_sampler_posterior_alpha self.disable_logprobs = disable_logprobs + self.disable_log_stats = disable_log_stats self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bad5be491721..74d075739685 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -792,6 +792,7 @@ def create_engine_config(self, ) -> EngineConfig: speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, + disable_log_stats=self.disable_log_stats, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, draft_token_acceptance_method=\ diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index fd76dbc76733..f61fed60640a 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=speculative_config. 
typical_acceptance_sampler_posterior_alpha, - disable_logprobs=speculative_config.disable_logprobs) + disable_logprobs=speculative_config.disable_logprobs, + disable_log_stats=speculative_config.disable_log_stats, + ) return spec_decode_worker @@ -116,6 +118,7 @@ def create_worker( typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, + disable_log_stats: bool, ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True @@ -171,6 +174,7 @@ def create_worker( proposer_worker, scorer_worker, disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, allow_zero_draft_token_step=allow_zero_draft_token_step) @@ -181,6 +185,7 @@ def __init__( scorer_worker: WorkerBase, spec_decode_sampler: SpecDecodeBaseSampler, disable_logprobs: bool, + disable_log_stats: bool, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, @@ -203,6 +208,8 @@ def __init__( disable_logprobs: If set to True, token log probabilities will not be output in both the draft worker and the target worker. If set to False, log probabilities will be output by both. + disable_log_stats: If set to True, disable periodic printing of + speculative stage times. disable_by_batch_size: If the batch size is larger than this, disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set @@ -237,6 +244,7 @@ def __init__( # in the subsequent step. self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs + self._disable_log_stats = disable_log_stats def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -736,20 +744,29 @@ def _create_output_sampler_list( sampler_output_list[ 0].spec_decode_worker_metrics = maybe_rejsample_metrics - (average_time_per_proposal_tok_ms, scoring_time_ms, - verification_time_ms) = stage_times - # Log time spent in each stage periodically. # This is periodic because the rejection sampler emits metrics # periodically. - logger.info( - "SpecDecodeWorker stage times: " - "average_time_per_proposal_tok_ms=%.02f " - "scoring_time_ms=%.02f verification_time_ms=%.02f", - average_time_per_proposal_tok_ms, scoring_time_ms, - verification_time_ms) + self._maybe_log_stage_times(*stage_times) + return sampler_output_list + def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float, + scoring_time_ms: float, + verification_time_ms: float) -> None: + """Log the speculative stage times. If stat logging is disabled, do + nothing. 
+ """ + if self._disable_log_stats: + return + + logger.info( + "SpecDecodeWorker stage times: " + "average_time_per_proposal_tok_ms=%.02f " + "scoring_time_ms=%.02f verification_time_ms=%.02f", + average_time_per_proposal_tok_ms, scoring_time_ms, + verification_time_ms) + def _create_dummy_logprob_lists( self, batch_size: int, From e691674d3f0f966597180ebccc6f6b861f6fa75c Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 4 Aug 2024 19:44:54 -0700 Subject: [PATCH 6/6] fix --- tests/spec_decode/test_spec_decode_worker.py | 68 ++++++++++++++------ vllm/spec_decode/spec_decode_worker.py | 4 +- 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 671c9bef294f..9ae1b4bc40f0 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int, target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + draft_worker, + target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector) exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) @@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + draft_worker, + target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() vocab_size = 32_000 @@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + worker = SpecDecodeWorker(draft_worker, + target_worker, + spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int, set_random_seed(1) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + worker = SpecDecodeWorker(draft_worker, + target_worker, + spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), False, - metrics_collector) + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector, + ) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), False, - metrics_collector) + 
proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector, + ) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str): spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - False, metrics_collector) + worker = SpecDecodeWorker( + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector, + ) worker.init_device() draft_worker.init_device.assert_called_once() @@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method): target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + worker = SpecDecodeWorker(proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + metrics_collector=metrics_collector) kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} worker.initialize_cache(**kwargs) @@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens(): seq_group_metadata_list=seq_group_metadata_list, accepted_token_ids=accepted_token_ids, target_logprobs=target_token_logprobs, - k=k) + k=k, + stage_times=(0, 0, 0)) # Verify that _seq_with_bonus_token_in_last_step contains the following: # 1. Sequence IDs that were already present in # _seq_with_bonus_token_in_last_step but were not part of the current diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index f61fed60640a..95934cf3a37e 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -184,8 +184,8 @@ def __init__( proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, spec_decode_sampler: SpecDecodeBaseSampler, - disable_logprobs: bool, - disable_log_stats: bool, + disable_logprobs: bool = False, + disable_log_stats: bool = False, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True,
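
Taken together, the series wraps each speculative decoding stage in a small wall-clock timer and normalizes proposal time per speculative token. The self-contained sketch below shows that pattern: the `Timer` class mirrors the helper this series adds to vllm/spec_decode/util.py, while `run_step` and its `time.sleep` calls are hypothetical stand-ins for the real proposal, scoring, and verification stages.

import time
from typing import Tuple


class Timer:
    """Context manager that records wall-clock time for a block.

    Mirrors the helper added to vllm/spec_decode/util.py. Note that
    time.time() can jump if the system clock is adjusted; for pure
    interval measurement, time.perf_counter() is a common alternative.
    """

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time_s = self.end_time - self.start_time
        self.elapsed_time_ms = self.elapsed_time_s * 1000


def run_step(num_lookahead_slots: int) -> Tuple[float, float, float]:
    """Illustrative stand-in for _run_speculative_decoding_step's timing."""
    with Timer() as proposal_timer:
        time.sleep(0.010)  # draft model proposes num_lookahead_slots tokens

    with Timer() as scoring_timer:
        time.sleep(0.020)  # target model scores the proposals

    with Timer() as verification_timer:
        time.sleep(0.005)  # accepted tokens are selected

    # Proposal time is reported per proposed token, as in the patch;
    # the other two stages are reported as raw per-step durations.
    return (proposal_timer.elapsed_time_ms / num_lookahead_slots,
            scoring_timer.elapsed_time_ms,
            verification_timer.elapsed_time_ms)


if __name__ == "__main__":
    per_tok_ms, scoring_ms, verify_ms = run_step(num_lookahead_slots=5)
    print(f"average_time_per_proposal_tok_ms={per_tok_ms:.02f} "
          f"scoring_time_ms={scoring_ms:.02f} "
          f"verification_time_ms={verify_ms:.02f}")

Because `__exit__` returns None, exceptions raised inside a timed block still propagate, so a failed stage never silently loses its error; the elapsed attributes are still populated because `__exit__` runs either way.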
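
For context on PATCH 5: the new logging keys off the existing engine-wide disable_log_stats setting rather than introducing a new flag, so the stage-time lines obey the same switch as the rest of vLLM's periodic stats. A minimal offline usage sketch follows, assuming the standard vllm.LLM entry point (which forwards extra keyword arguments to EngineArgs); the OPT draft/target pairing is illustrative only.

from vllm import LLM

# disable_log_stats=True suppresses the periodic
# "SpecDecodeWorker stage times: ..." lines added in this series,
# along with the engine's other periodic stats logging.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
    disable_log_stats=True,
)
outputs = llm.generate("The future of speculative decoding is")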