From 131e6797f25cd33150366675a1c2e72f6fbc9d58 Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Wed, 9 Apr 2025 23:20:59 +0000 Subject: [PATCH 1/4] add request level, per-step acceptance counts tracking for spec dec Signed-off-by: Bryan Lu --- vllm/outputs.py | 4 ++++ vllm/v1/core/sched/scheduler.py | 8 +++++++- vllm/v1/engine/__init__.py | 2 ++ vllm/v1/engine/llm_engine.py | 19 ++++++++++++++----- vllm/v1/engine/output_processor.py | 26 ++++++++++++++++---------- vllm/v1/engine/processor.py | 3 ++- vllm/v1/request.py | 4 +++- 7 files changed, 48 insertions(+), 18 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 014e8d5d8823..661dfe2870a5 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -43,6 +43,7 @@ class CompletionOutput: finish_reason: Optional[str] = None stop_reason: Union[int, str, None] = None lora_request: Optional[LoRARequest] = None + spec_token_acceptance_counts: Optional[list[int]] = None def finished(self) -> bool: return self.finish_reason is not None @@ -133,6 +134,9 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens + self.spec_token_acceptance_counts = [ + o.spec_token_acceptance_counts for o in outputs + ] def add(self, next_output: "RequestOutput") -> None: """Merge subsequent RequestOutput into this one""" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 488d32cb82cf..ab4b88eed22c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -601,6 +601,10 @@ def update_from_output( num_draft_tokens=len(scheduled_spec_token_ids), num_accepted_tokens=len(generated_token_ids) - 1) + for i in range(len(generated_token_ids)): + if request.spec_token_acceptance_counts is not None: + request.spec_token_acceptance_counts[i] += 1 + cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) # OPTIMIZATION: Avoid list(set) if the set is empty. @@ -662,7 +666,9 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events())) + events=request.take_events(), + spec_token_acceptance_counts=request. + spec_token_acceptance_counts)) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 1264e43c79d9..fa7a66951eac 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -60,6 +60,7 @@ class EngineCoreRequest( eos_token_id: Optional[int] arrival_time: float lora_request: Optional[LoRARequest] + spec_token_acceptance_counts: Optional[list[int]] class EngineCoreEventType(enum.IntEnum): @@ -102,6 +103,7 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None + spec_token_acceptance_counts: Optional[list[int]] = None @property def finished(self) -> bool: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 4c67186f7040..345ea9036eb9 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -183,11 +183,20 @@ def add_request( priority: int = 0, ) -> None: # Process raw inputs into the request. 
- request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) + num_spec_tokens = 0 + if self.vllm_config.speculative_config is not None: + num_spec_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens) + request = self.processor.process_inputs( + request_id, + prompt, + params, + arrival_time, + lora_request, + trace_headers, + prompt_adapter_request, + priority, + num_spec_tokens=num_spec_tokens) n = params.n if isinstance(params, SamplingParams) else 1 diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 70f072d3c939..b040f42f8e74 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -136,10 +136,9 @@ def from_new_request( ) def make_request_output( - self, - new_token_ids: list[int], - finish_reason: Optional[FinishReason], + self, new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], + spec_token_acceptance_counts: Optional[list[int]] ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -150,7 +149,10 @@ def make_request_output( return None completion_output = self._new_completion_output( - new_token_ids, finish_reason, stop_reason) + new_token_ids, + finish_reason, + stop_reason, + spec_token_acceptance_counts=spec_token_acceptance_counts) request_id = self.request_id if self.parent_req is None: @@ -186,10 +188,9 @@ def _new_request_output( ) def _new_completion_output( - self, - token_ids: list[int], - finish_reason: Optional[FinishReason], - stop_reason: Union[int, str, None], + self, token_ids: list[int], finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + spec_token_acceptance_counts: Optional[list[int]] ) -> CompletionOutput: finished = finish_reason is not None @@ -212,7 +213,8 @@ def _new_completion_output( logprobs=logprobs, cumulative_logprob=self.logprobs_processor.cumulative_logprob, finish_reason=str(finish_reason) if finished else None, - stop_reason=stop_reason if finished else None) + stop_reason=stop_reason if finished else None, + spec_token_acceptance_counts=spec_token_acceptance_counts) class OutputProcessor: @@ -337,7 +339,11 @@ def process_outputs( # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, finish_reason, stop_reason): + new_token_ids, + finish_reason, + stop_reason, + spec_token_acceptance_counts=engine_core_output. + spec_token_acceptance_counts): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7d1913ecebed..468401dfefea 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -177,6 +177,7 @@ def process_inputs( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + num_spec_tokens: int = 0, ) -> EngineCoreRequest: # TODO(woosuk): Support pooling models. 
@@ -292,7 +293,7 @@ def process_inputs( eos_token_id=eos_token_id, arrival_time=arrival_time, lora_request=lora_request, - ) + spec_token_acceptance_counts=[0] * (num_spec_tokens + 1)) def _validate_model_inputs(self, inputs: ProcessorInputs, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 6be72431dde5..1eed5427eddc 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -30,6 +30,7 @@ def __init__( arrival_time: float, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, + spec_token_acceptance_counts: Optional[list[int]] = None, ) -> None: self.request_id = request_id self.sampling_params = sampling_params @@ -53,6 +54,7 @@ def __init__( self._all_token_ids: list[int] = self.prompt_token_ids.copy() self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 + self.spec_token_acceptance_counts = spec_token_acceptance_counts # Multi-modal related self.mm_positions = multi_modal_placeholders or [] @@ -92,7 +94,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), - ) + spec_token_acceptance_counts=request.spec_token_acceptance_counts) def append_output_token_ids( self, From d21afbf21871f67a4e228f92bedc4c5ee26593df Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Sat, 12 Apr 2025 00:35:18 +0000 Subject: [PATCH 2/4] rebase Signed-off-by: Bryan Lu --- examples/offline_inference/eagle.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 453ae7b6f56f..bb9993448aa0 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -45,8 +45,12 @@ def main(): parser.add_argument("--enable_chunked_prefill", action='store_true') parser.add_argument("--max_num_batched_tokens", type=int, default=2048) parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--use_v1", type=str, default="1", help='1 or 0') args = parser.parse_args() + # TODO: remove this option once EAGLE in v1 is ready. 
+ os.environ["VLLM_USE_V1"] = args.use_v1 + model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm" @@ -94,10 +98,16 @@ def main(): # to account for the token from the target model that's always going to be # accepted acceptance_counts = [0] * (args.num_spec_tokens + 1) - for output in outputs: - for step, count in enumerate( - output.metrics.spec_token_acceptance_counts): - acceptance_counts[step] += count + if args.use_v1 == '1': + for output in outputs: + for step, count in enumerate( + output.spec_token_acceptance_counts[0]): + acceptance_counts[step] += count + else: + for output in outputs: + for step, count in enumerate( + output.metrics.spec_token_acceptance_counts): + acceptance_counts[step] += count print("-" * 50) print(f"mean acceptance length: \ From ddc1afd83b929b207d6636fccd0ac9c8c59ea2d4 Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Mon, 14 Apr 2025 06:39:03 +0000 Subject: [PATCH 3/4] update design Signed-off-by: Bryan Lu --- vllm/v1/core/sched/scheduler.py | 15 ++++++--------- vllm/v1/engine/__init__.py | 3 +-- vllm/v1/engine/llm_engine.py | 4 ++-- vllm/v1/engine/output_processor.py | 21 +++++++++++++++++++-- vllm/v1/engine/processor.py | 2 +- vllm/v1/request.py | 4 +--- vllm/v1/spec_decode/metrics.py | 11 ++++++++--- 7 files changed, 38 insertions(+), 22 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5cd7a6980bb3..ccddd3417431 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -608,11 +608,8 @@ def update_from_output( spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=len(scheduled_spec_token_ids), - num_accepted_tokens=len(generated_token_ids) - 1) - - for i in range(len(generated_token_ids)): - if request.spec_token_acceptance_counts is not None: - request.spec_token_acceptance_counts[i] += 1 + num_accepted_tokens=len(generated_token_ids) - 1, + request_id=req_id) cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) @@ -675,9 +672,7 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events(), - spec_token_acceptance_counts=request. - spec_token_acceptance_counts)) + events=request.take_events())) else: # Invariant: EngineCore returns no partial prefill outputs. 
assert not prompt_logprobs_tensors @@ -775,11 +770,13 @@ def make_spec_decoding_stats( spec_decoding_stats: Optional[SpecDecodingStats], num_draft_tokens: int, num_accepted_tokens: int, + request_id: str, ) -> Optional[SpecDecodingStats]: if not self.log_stats: return None if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats() spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + num_accepted_tokens=num_accepted_tokens, + request_id=request_id) return spec_decoding_stats diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index fa7a66951eac..33a6225009b2 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -60,7 +60,7 @@ class EngineCoreRequest( eos_token_id: Optional[int] arrival_time: float lora_request: Optional[LoRARequest] - spec_token_acceptance_counts: Optional[list[int]] + num_spec_tokens: int class EngineCoreEventType(enum.IntEnum): @@ -103,7 +103,6 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None - spec_token_acceptance_counts: Optional[list[int]] = None @property def finished(self) -> bool: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 345ea9036eb9..f3285735e57b 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -92,7 +92,7 @@ def __init__( asyncio_mode=False, vllm_config=vllm_config, executor_class=executor_class, - log_stats=False, # FIXME: implement + log_stats=True, # FIXME: implement ) if not multiprocess_mode: @@ -232,7 +232,7 @@ def step(self) -> list[RequestOutput]: # 2) Process EngineCoreOutputs. processed_outputs = self.output_processor.process_outputs( - outputs.outputs) + outputs.outputs, scheduler_stats=outputs.scheduler_stats) # 3) Abort any reqs that finished due to stop strings. 
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index b040f42f8e74..c5554bd20f0c 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -14,7 +14,7 @@ from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, - RequestStateStats) + RequestStateStats, SchedulerStats) class RequestOutputCollector: @@ -81,6 +81,7 @@ def __init__( arrival_time: float, queue: Optional[RequestOutputCollector], log_stats: bool, + num_spec_tokens: int = 0, ): self.request_id = request_id self.parent_req = parent_req @@ -99,6 +100,8 @@ def __init__( self.stats = RequestStateStats( arrival_time=arrival_time) if log_stats else None + self.spec_token_acceptance_counts = [0] * (num_spec_tokens + 1) + @classmethod def from_new_request( cls, @@ -133,6 +136,7 @@ def from_new_request( arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, + num_spec_tokens=request.num_spec_tokens, ) def make_request_output( @@ -282,6 +286,7 @@ def process_outputs( engine_core_outputs: list[EngineCoreOutput], engine_core_timestamp: Optional[float] = None, iteration_stats: Optional[IterationStats] = None, + scheduler_stats: Optional[SchedulerStats] = None, ) -> OutputProcessorOutput: """ Process the EngineCoreOutputs: @@ -320,6 +325,8 @@ def process_outputs( self._update_stats_from_output(req_state, engine_core_output, engine_core_timestamp, iteration_stats) + self._update_stats_from_scheduler(req_id, req_state, + scheduler_stats) new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason @@ -342,7 +349,7 @@ def process_outputs( new_token_ids, finish_reason, stop_reason, - spec_token_acceptance_counts=engine_core_output. + spec_token_acceptance_counts=req_state. spec_token_acceptance_counts): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). @@ -409,3 +416,13 @@ def _update_stats_from_finished(self, req_state: RequestState, ParentRequest.observe_finished_request( req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens) + + def _update_stats_from_scheduler( + self, req_id: str, req_state: RequestState, + scheduler_stats: Optional[SchedulerStats]): + if scheduler_stats is not None and \ + scheduler_stats.spec_decoding_stats is not None: + num_accepted_tokens = scheduler_stats. 
\ + spec_decoding_stats.per_request_stats.get(req_id, 0) + for i in range(num_accepted_tokens): + req_state.spec_token_acceptance_counts[i] += 1 diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index a362de74ed0a..6488ca6ed4d4 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -314,7 +314,7 @@ def process_inputs( eos_token_id=eos_token_id, arrival_time=arrival_time, lora_request=lora_request, - spec_token_acceptance_counts=[0] * (num_spec_tokens + 1)) + num_spec_tokens=num_spec_tokens) def _validate_model_inputs(self, inputs: ProcessorInputs, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 1eed5427eddc..6be72431dde5 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -30,7 +30,6 @@ def __init__( arrival_time: float, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, - spec_token_acceptance_counts: Optional[list[int]] = None, ) -> None: self.request_id = request_id self.sampling_params = sampling_params @@ -54,7 +53,6 @@ def __init__( self._all_token_ids: list[int] = self.prompt_token_ids.copy() self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 - self.spec_token_acceptance_counts = spec_token_acceptance_counts # Multi-modal related self.mm_positions = multi_modal_placeholders or [] @@ -94,7 +92,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), - spec_token_acceptance_counts=request.spec_token_acceptance_counts) + ) def append_output_token_ids( self, diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 7bb3c209d1dc..a44523c945c1 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass +from dataclasses import dataclass, field import numpy as np @@ -13,20 +13,25 @@ class SpecDecodingStats: num_draft_tokens: int = 0 num_accepted_tokens: int = 0 + per_request_stats: dict = field(default_factory=dict) def take(self): copied = SpecDecodingStats(self.num_draft_tokens, - self.num_accepted_tokens) + self.num_accepted_tokens, + self.per_request_stats) self.reset() return copied def reset(self): self.num_draft_tokens = 0 self.num_accepted_tokens = 0 + self.per_request_stats = {} - def observe(self, num_draft_tokens: int, num_accepted_tokens: int): + def observe(self, num_draft_tokens: int, num_accepted_tokens: int, + request_id: str): self.num_draft_tokens += num_draft_tokens self.num_accepted_tokens += num_accepted_tokens + self.per_request_stats[request_id] = num_accepted_tokens + 1 class SpecDecodingMetrics: From cbb96bca30ccf93b60bca749326e037011435891 Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Mon, 14 Apr 2025 06:50:33 +0000 Subject: [PATCH 4/4] minor Signed-off-by: Bryan Lu --- vllm/outputs.py | 4 ++++ vllm/v1/engine/output_processor.py | 8 +++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 661dfe2870a5..19d8fe08eb6c 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -33,6 +33,10 @@ class CompletionOutput: to stop, None if the completion finished for some other reason including encountering the EOS token. lora_request: The LoRA request that was used to generate the output. 
+ spec_token_acceptance_counts: A list tracking the total number of + accepted tokens at each speculation step for a request. Its length + is num_spec_tokens + 1 since there is always one token generated + by the target model. """ index: int diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index c5554bd20f0c..4d7b86ec951b 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -192,9 +192,11 @@ def _new_request_output( ) def _new_completion_output( - self, token_ids: list[int], finish_reason: Optional[FinishReason], - stop_reason: Union[int, str, None], - spec_token_acceptance_counts: Optional[list[int]] + self, + token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + spec_token_acceptance_counts: Optional[list[int]], ) -> CompletionOutput: finished = finish_reason is not None
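Usage sketch (not part of the patch series): after these changes, the per-request counts can be read directly off `RequestOutput.spec_token_acceptance_counts` on the V1 path, mirroring the updated `examples/offline_inference/eagle.py`. The model names, `speculative_config` shape, prompt, and `disable_log_stats` setting below are illustrative assumptions, not dictated by this series; only the `spec_token_acceptance_counts` field itself comes from the patches above.

# Sketch of consuming the new per-request acceptance counts (assumptions noted above).
import os

from vllm import LLM, SamplingParams

# The counts are populated on the V1 engine path (see the eagle.py change in patch 2).
os.environ["VLLM_USE_V1"] = "1"

num_spec_tokens = 2  # illustrative; must match the speculative config below

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    # Assumed config shape, mirroring examples/offline_inference/eagle.py.
    speculative_config={
        "method": "eagle",
        "model": "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm",
        "num_speculative_tokens": num_spec_tokens,
    },
    # After patch 3 the counts ride on SpecDecodingStats, so stats logging stays on.
    disable_log_stats=False,
)

outputs = llm.generate(["The future of speculative decoding is"],
                       SamplingParams(temperature=0, max_tokens=64))

# One slot per speculation step, plus slot 0 for the token the target model
# emits on every step (always accepted), hence length num_spec_tokens + 1.
acceptance_counts = [0] * (num_spec_tokens + 1)
for output in outputs:
    # RequestOutput.spec_token_acceptance_counts holds one list per completion;
    # [0] picks the first (n=1) completion, as in the updated eagle.py.
    for step, count in enumerate(output.spec_token_acceptance_counts[0]):
        acceptance_counts[step] += count

# Slot 0 equals the number of decode steps, so this ratio is the mean number
# of tokens produced per target-model forward pass.
print(f"mean acceptance length: "
      f"{sum(acceptance_counts) / acceptance_counts[0]:.2f}")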