From 16479565ca41857542ecbf6ec8da28c0a4dc4ad0 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Wed, 12 Mar 2025 14:10:58 -0400 Subject: [PATCH 1/3] [WIP][V1][Metrics] Speculative decoding metrics Fixes #13990, part of #10582 Omitting system efficiency for now. Signed-off-by: Mark McLoughlin --- vllm/v1/core/sched/scheduler.py | 19 +++++++- vllm/v1/engine/async_llm.py | 3 +- vllm/v1/metrics/loggers.py | 43 +++++++++++++++++- vllm/v1/metrics/stats.py | 4 ++ vllm/v1/spec_decode/metrics.py | 72 ++++++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 8 ++-- 6 files changed, 141 insertions(+), 8 deletions(-) create mode 100644 vllm/v1/spec_decode/metrics.py diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index aafa2f0a9f30..ec6d6071976a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -23,6 +23,7 @@ from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus +from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager logger = init_logger(__name__) @@ -567,6 +568,7 @@ def update_from_output( spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + spec_decoding_stats = SpecDecodingStats() if self.log_stats else None num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: list[Request] = [] @@ -599,6 +601,15 @@ def update_from_output( len(generated_token_ids)) request.num_computed_tokens -= num_tokens_rejected + if spec_decoding_stats is not None: + # FIXME: If a drafter proposes zero tokens, we should + # treat this as if num_spec_tokens were proposed and + # all rejected to allow fair comparisons between drafters + spec_decoding_stats.observe( + num_draft_tokens=len(scheduled_spec_token_ids), + num_accepted_tokens=len(generated_token_ids) - 1, + num_emitted_tokens=len(generated_token_ids)) + cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) # OPTIMIZATION: Avoid list(set) if the set is empty. 
@@ -672,7 +683,7 @@ def update_from_output( self.running = new_running engine_core_outputs = EngineCoreOutputs( outputs=outputs, - scheduler_stats=self.make_stats(), + scheduler_stats=self.make_stats(spec_decoding_stats), ) if self.include_finished_set: #TODO currently sending duplicates here, improve this @@ -739,7 +750,10 @@ def get_num_unscheduled_requests(self) -> int: def reset_prefix_cache(self) -> bool: return self.kv_cache_manager.reset_prefix_cache() - def make_stats(self) -> Optional[SchedulerStats]: + def make_stats( + self, + spec_decoding_stats: Optional[SpecDecodingStats] = None, + ) -> Optional[SchedulerStats]: if not self.log_stats: return None return SchedulerStats( @@ -747,4 +761,5 @@ def make_stats(self) -> Optional[SchedulerStats]: num_waiting_reqs=len(self.waiting), gpu_cache_usage=self.kv_cache_manager.usage, prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(), + spec_decoding_stats=spec_decoding_stats, ) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a8d86e70f6ab..4ac939ab5f1e 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -74,7 +74,8 @@ def __init__( for i in range(vllm_config.parallel_config.data_parallel_size): loggers: list[StatLoggerBase] = [] if logger.isEnabledFor(logging.INFO): - loggers.append(LoggingStatLogger(engine_index=i)) + loggers.append( + LoggingStatLogger(vllm_config, engine_index=i)) loggers.append( PrometheusStatLogger(vllm_config, engine_index=i)) self.stat_loggers.append(loggers) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 6ffd00ebd17a..efca6731367b 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -12,6 +12,7 @@ from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.spec_decode.metrics import SpecDecodingMetrics logger = init_logger(__name__) @@ -31,13 +32,15 @@ def log(self): # noqa class LoggingStatLogger(StatLoggerBase): - def __init__(self, engine_index: int = 0): + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index self._reset(time.monotonic()) self.last_scheduler_stats = SchedulerStats() # Prefix cache metrics. This cannot be reset. # TODO: Make the interval configurable. 
self.prefix_caching_metrics = PrefixCachingMetrics() + self.spec_decoding_metrics = SpecDecodingMetrics( + vllm_config.speculative_config) def _reset(self, now): self.last_log_time = now @@ -65,6 +68,10 @@ def record(self, scheduler_stats: SchedulerStats, self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_metrics.observe( + scheduler_stats.spec_decoding_stats) + self.last_scheduler_stats = scheduler_stats def log(self): @@ -94,6 +101,9 @@ def log(self): self.prefix_caching_metrics.hit_rate * 100, ) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_metrics.log() + class PrometheusStatLogger(StatLoggerBase): @@ -302,6 +312,29 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.labelname_running_lora_adapters, ]) + # + # Speculative Decoding metrics + # The acceptance rate can be calculated using a PromQL query: + # + # rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / + # rate(vllm:spec_decode_num_draft_tokens_total[$interval]) + # + self.counter_spec_decode_num_draft_tokens = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_draft_tokens_total", + documentation="Number of draft tokens.", + labelnames=labelnames).labels(*labelvalues) + self.counter_spec_decode_num_accepted_tokens = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames).labels(*labelvalues) + self.counter_spec_decode_num_emitted_tokens = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_emitted_tokens_total", + documentation="Number of emitted tokens.", + labelnames=labelnames).labels(*labelvalues) + # # Cache config info metric # @@ -338,6 +371,14 @@ def record(self, scheduler_stats: SchedulerStats, self.counter_gpu_prefix_cache_hits.inc( scheduler_stats.prefix_cache_stats.hits) + if scheduler_stats.spec_decoding_stats is not None: + self.counter_spec_decode_num_draft_tokens.inc( + scheduler_stats.spec_decoding_stats.num_draft_tokens) + self.counter_spec_decode_num_accepted_tokens.inc( + scheduler_stats.spec_decoding_stats.num_accepted_tokens) + self.counter_spec_decode_num_emitted_tokens.inc( + scheduler_stats.spec_decoding_stats.num_emitted_tokens) + if iteration_stats is None: return diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 6f3d34447426..fd949264885b 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -4,6 +4,8 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Optional +from vllm.v1.spec_decode.metrics import SpecDecodingStats + if TYPE_CHECKING: from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason from vllm.v1.engine.output_processor import RequestState @@ -35,6 +37,8 @@ class SchedulerStats: prefix_cache_stats: PrefixCacheStats = field( default_factory=PrefixCacheStats) + spec_decoding_stats: Optional[SpecDecodingStats] = None + @dataclass class LoRAStats: diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py new file mode 100644 index 000000000000..d419466aff41 --- /dev/null +++ b/vllm/v1/spec_decode/metrics.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + +import numpy as np + +from vllm.config import SpeculativeConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class SpecDecodingStats: + num_draft_tokens: int = 0 + 
num_accepted_tokens: int = 0 + num_emitted_tokens: int = 0 + + def take(self): + copied = SpecDecodingStats(self.num_draft_tokens, + self.num_accepted_tokens, + self.num_emitted_tokens) + self.reset() + return copied + + def reset(self): + self.num_draft_tokens = 0 + self.num_accepted_tokens = 0 + self.num_emitted_tokens = 0 + + def observe(self, num_draft_tokens: int, num_accepted_tokens: int, + num_emitted_tokens: int): + self.num_draft_tokens += num_draft_tokens + self.num_accepted_tokens += num_accepted_tokens + self.num_emitted_tokens += num_emitted_tokens + + +class SpecDecodingMetrics: + + def __init__(self, speculative_config: SpeculativeConfig): + self.num_spec_tokens = (speculative_config.num_speculative_tokens + if speculative_config is not None else 0) + self.reset() + + def reset(self): + self.num_draft_tokens: list[int] = [] + self.num_accepted_tokens: list[int] = [] + self.num_emitted_tokens: list[int] = [] + + def observe(self, spec_decoding_stats: SpecDecodingStats): + self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens) + self.num_accepted_tokens.append( + spec_decoding_stats.num_accepted_tokens) + self.num_emitted_tokens.append(spec_decoding_stats.num_emitted_tokens) + + def log(self): + num_draft_tokens = np.sum(self.num_draft_tokens) + num_accepted_tokens = np.sum(self.num_accepted_tokens) + num_emitted_tokens = np.sum(self.num_emitted_tokens) + + draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens + if num_draft_tokens > 0 else float("nan")) + + logger.info( + "Speculative metrics: " + "Draft acceptance rate: %.3f, " + "Number of speculative tokens: %d, " + "Number of accepted tokens: %d, " + "Number of draft tokens: %d, " + "Number of emitted tokens: %d.", draft_acceptance_rate, + num_accepted_tokens, num_draft_tokens, num_emitted_tokens) + self.reset() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 74f3124e3c77..ee2e830ce3a8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1154,20 +1154,20 @@ def generate_draft_token_ids( self, sampled_token_ids: list[list[int]], sampling_metadata: SamplingMetadata, - ) -> list[list[int]]: + ) -> list[Optional[list[int]]]: # TODO(woosuk): Optimize. - draft_token_ids: list[list[int]] = [] + draft_token_ids: list[Optional[list[int]]] = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) if not num_sampled_ids: # Skip speculative decoding. - draft_token_ids.append([]) + draft_token_ids.append(None) continue # Skip requests that require top-p, top-k, etc. req_id = self.input_batch.req_ids[i] if not is_spec_decode_supported(req_id, self.input_batch): - draft_token_ids.append([]) + draft_token_ids.append(None) continue # Add sampled_token_ids to token_ids_cpu. From 85ce056523cfb08cbe0cc80689188c74d3624c34 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 31 Mar 2025 07:13:57 -0400 Subject: [PATCH 2/3] [V1][Spec Decoding] Strip metrics back to acceptance rate Now just num_accepted_tokens, num_draft_tokens, and acceptance rate. 
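As an illustration of what remains after this change (the numbers below are made up), the periodic stats log gains a line roughly like:

  Speculative metrics: Draft acceptance rate: 0.680, Number of accepted tokens: 17, Number of draft tokens: 25,

and the same acceptance rate can still be derived from the two remaining Prometheus counters using the PromQL query noted in loggers.py:

  rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
    rate(vllm:spec_decode_num_draft_tokens_total[$interval])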
Signed-off-by: Mark McLoughlin --- vllm/v1/core/sched/scheduler.py | 6 +----- vllm/v1/engine/async_llm.py | 3 +-- vllm/v1/metrics/loggers.py | 12 ++---------- vllm/v1/spec_decode/metrics.py | 23 +++++------------------ vllm/v1/worker/gpu_model_runner.py | 8 ++++---- 5 files changed, 13 insertions(+), 39 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ec6d6071976a..8b9c59b9c212 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -602,13 +602,9 @@ def update_from_output( request.num_computed_tokens -= num_tokens_rejected if spec_decoding_stats is not None: - # FIXME: If a drafter proposes zero tokens, we should - # treat this as if num_spec_tokens were proposed and - # all rejected to allow fair comparisons between drafters spec_decoding_stats.observe( num_draft_tokens=len(scheduled_spec_token_ids), - num_accepted_tokens=len(generated_token_ids) - 1, - num_emitted_tokens=len(generated_token_ids)) + num_accepted_tokens=len(generated_token_ids) - 1) cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 4ac939ab5f1e..a8d86e70f6ab 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -74,8 +74,7 @@ def __init__( for i in range(vllm_config.parallel_config.data_parallel_size): loggers: list[StatLoggerBase] = [] if logger.isEnabledFor(logging.INFO): - loggers.append( - LoggingStatLogger(vllm_config, engine_index=i)) + loggers.append(LoggingStatLogger(engine_index=i)) loggers.append( PrometheusStatLogger(vllm_config, engine_index=i)) self.stat_loggers.append(loggers) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index efca6731367b..73883d9a735d 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -32,15 +32,14 @@ def log(self): # noqa class LoggingStatLogger(StatLoggerBase): - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + def __init__(self, engine_index: int = 0): self.engine_index = engine_index self._reset(time.monotonic()) self.last_scheduler_stats = SchedulerStats() # Prefix cache metrics. This cannot be reset. # TODO: Make the interval configurable. 
self.prefix_caching_metrics = PrefixCachingMetrics() - self.spec_decoding_metrics = SpecDecodingMetrics( - vllm_config.speculative_config) + self.spec_decoding_metrics = SpecDecodingMetrics() def _reset(self, now): self.last_log_time = now @@ -329,11 +328,6 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): name="vllm:spec_decode_num_accepted_tokens_total", documentation="Number of accepted tokens.", labelnames=labelnames).labels(*labelvalues) - self.counter_spec_decode_num_emitted_tokens = \ - prometheus_client.Counter( - name="vllm:spec_decode_num_emitted_tokens_total", - documentation="Number of emitted tokens.", - labelnames=labelnames).labels(*labelvalues) # # Cache config info metric @@ -376,8 +370,6 @@ def record(self, scheduler_stats: SchedulerStats, scheduler_stats.spec_decoding_stats.num_draft_tokens) self.counter_spec_decode_num_accepted_tokens.inc( scheduler_stats.spec_decoding_stats.num_accepted_tokens) - self.counter_spec_decode_num_emitted_tokens.inc( - scheduler_stats.spec_decoding_stats.num_emitted_tokens) if iteration_stats is None: return diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index d419466aff41..7fecbaeed4f7 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -4,7 +4,6 @@ import numpy as np -from vllm.config import SpeculativeConfig from vllm.logger import init_logger logger = init_logger(__name__) @@ -14,49 +13,39 @@ class SpecDecodingStats: num_draft_tokens: int = 0 num_accepted_tokens: int = 0 - num_emitted_tokens: int = 0 def take(self): copied = SpecDecodingStats(self.num_draft_tokens, - self.num_accepted_tokens, - self.num_emitted_tokens) + self.num_accepted_tokens) self.reset() return copied def reset(self): self.num_draft_tokens = 0 self.num_accepted_tokens = 0 - self.num_emitted_tokens = 0 - def observe(self, num_draft_tokens: int, num_accepted_tokens: int, - num_emitted_tokens: int): + def observe(self, num_draft_tokens: int, num_accepted_tokens: int): self.num_draft_tokens += num_draft_tokens self.num_accepted_tokens += num_accepted_tokens - self.num_emitted_tokens += num_emitted_tokens class SpecDecodingMetrics: - def __init__(self, speculative_config: SpeculativeConfig): - self.num_spec_tokens = (speculative_config.num_speculative_tokens - if speculative_config is not None else 0) + def __init__(self): self.reset() def reset(self): self.num_draft_tokens: list[int] = [] self.num_accepted_tokens: list[int] = [] - self.num_emitted_tokens: list[int] = [] def observe(self, spec_decoding_stats: SpecDecodingStats): self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens) self.num_accepted_tokens.append( spec_decoding_stats.num_accepted_tokens) - self.num_emitted_tokens.append(spec_decoding_stats.num_emitted_tokens) def log(self): num_draft_tokens = np.sum(self.num_draft_tokens) num_accepted_tokens = np.sum(self.num_accepted_tokens) - num_emitted_tokens = np.sum(self.num_emitted_tokens) draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens if num_draft_tokens > 0 else float("nan")) @@ -64,9 +53,7 @@ def log(self): logger.info( "Speculative metrics: " "Draft acceptance rate: %.3f, " - "Number of speculative tokens: %d, " "Number of accepted tokens: %d, " - "Number of draft tokens: %d, " - "Number of emitted tokens: %d.", draft_acceptance_rate, - num_accepted_tokens, num_draft_tokens, num_emitted_tokens) + "Number of draft tokens: %d, ", draft_acceptance_rate, + num_accepted_tokens, num_draft_tokens) self.reset() diff --git 
a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ee2e830ce3a8..74f3124e3c77 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1154,20 +1154,20 @@ def generate_draft_token_ids( self, sampled_token_ids: list[list[int]], sampling_metadata: SamplingMetadata, - ) -> list[Optional[list[int]]]: + ) -> list[list[int]]: # TODO(woosuk): Optimize. - draft_token_ids: list[Optional[list[int]]] = [] + draft_token_ids: list[list[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) if not num_sampled_ids: # Skip speculative decoding. - draft_token_ids.append(None) + draft_token_ids.append([]) continue # Skip requests that require top-p, top-k, etc. req_id = self.input_batch.req_ids[i] if not is_spec_decode_supported(req_id, self.input_batch): - draft_token_ids.append(None) + draft_token_ids.append([]) continue # Add sampled_token_ids to token_ids_cpu. From 840f4ce72814fa7a6ae7602668e34ef17f23689f Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 31 Mar 2025 09:20:47 -0400 Subject: [PATCH 3/3] [V1][Spec Decoding] Add scheduler test cases Signed-off-by: Mark McLoughlin --- tests/v1/core/test_scheduler.py | 95 +++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 24a51288cbb9..5770afa2ea70 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -600,3 +600,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], prompt_logprobs_dict={}, ) scheduler.update_from_output(scheduler_output1, model_runner_output) + + +# Note - these test cases mirror some of those in test_rejection_sampler.py +@pytest.mark.parametrize( + "spec_tokens,output_tokens,expected", + [ + ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match + ([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch + ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences + ([[1]], [[1, 2]], (1, 1)), # single token sequence + ([[]], [[5]], (0, 0)), # empty sequence + ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]], + (6, 3)), # multiple mismatches + ]) +def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): + """Test scheduling behavior with speculative decoding. + + This test verifies that: + 1. Speculated tokens get scheduled correctly + 2. 
Spec decoding stats properly count number of draft and accepted tokens + """ + scheduler = create_scheduler() + requests = create_requests(num_requests=len(spec_tokens), num_tokens=1) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + # Schedule a decode, which will also draft speculative tokens + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == len(requests) + assert output.total_num_scheduled_tokens == len(requests) + for i in range(len(requests)): + req_id = requests[i].request_id + assert output.num_scheduled_tokens[req_id] == 1 + assert req_id not in output.scheduled_spec_decode_tokens + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[0] for _ in range(len(requests))], + spec_token_ids=spec_tokens, + logprobs=None, + prompt_logprobs_dict={}, + ) + engine_core_outputs = scheduler.update_from_output(output, + model_runner_output) + + for i in range(len(requests)): + running_req = scheduler.running[i] + # The prompt token + assert running_req.num_computed_tokens == 1 + # The prompt token and the sampled token + assert running_req.num_tokens == 2 + # The prompt token, the sampled token, and the speculated tokens + assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i]) + + # No draft or accepted tokens counted yet + assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None + stats = engine_core_outputs.scheduler_stats.spec_decoding_stats + assert stats.num_draft_tokens == 0 + assert stats.num_accepted_tokens == 0 + + # Schedule the speculated tokens for validation + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 0 + # The sampled token and speculated tokens + assert output.total_num_scheduled_tokens == \ + len(requests) + sum(len(ids) for ids in spec_tokens) + for i in range(len(requests)): + req_id = requests[i].request_id + assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i]) + if spec_tokens[i]: + assert len(output.scheduled_spec_decode_tokens[req_id]) == \ + len(spec_tokens[i]) + else: + assert req_id not in output.scheduled_spec_decode_tokens + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=output_tokens, + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + engine_core_outputs = scheduler.update_from_output(output, + model_runner_output) + + assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None + stats = engine_core_outputs.scheduler_stats.spec_decoding_stats + assert stats.num_draft_tokens == expected[0] + assert stats.num_accepted_tokens == expected[1]
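
For reference, a minimal worked example of the accounting these cases exercise, using the "early mismatch" parametrization above (the names mirror the test; the arithmetic is only illustrative):

    spec_tokens = [[1, 2, 3]]   # drafter proposed three tokens
    output_tokens = [[1, 5]]    # target model emitted two tokens; 5 replaces the first rejected draft (2)
    num_draft_tokens = len(spec_tokens[0])           # 3 drafts scheduled for verification
    num_accepted_tokens = len(output_tokens[0]) - 1  # 1 accepted; the last emitted token is sampled
                                                     # by the target model, so it is not counted
    assert (num_draft_tokens, num_accepted_tokens) == (3, 1)  # the expected tuple for this case

This mirrors how Scheduler.update_from_output() records the stats, i.e.
num_draft_tokens=len(scheduled_spec_token_ids) and
num_accepted_tokens=len(generated_token_ids) - 1. The new cases can be run with,
e.g., pytest tests/v1/core/test_scheduler.py -k spec_decoding_stats.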