95 changes: 95 additions & 0 deletions tests/v1/core/test_scheduler.py
@@ -600,3 +600,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
prompt_logprobs_dict={},
)
scheduler.update_from_output(scheduler_output1, model_runner_output)


# Note - these test cases mirror some of those in test_rejection_sampler.py
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match
([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences
([[1]], [[1, 2]], (1, 1)), # single token sequence
([[]], [[5]], (0, 0)), # empty sequence
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
(6, 3)), # multiple mismatches
])
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
"""Test scheduling behavior with speculative decoding.

This test verifies that:
1. Speculated tokens get scheduled correctly
    2. Spec decoding stats properly count the number of draft and accepted tokens
"""
scheduler = create_scheduler()
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i

# Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.total_num_scheduled_tokens == len(requests)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1
assert req_id not in output.scheduled_spec_decode_tokens

model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)

for i in range(len(requests)):
running_req = scheduler.running[i]
# The prompt token
assert running_req.num_computed_tokens == 1
# The prompt token and the sampled token
assert running_req.num_tokens == 2
# The prompt token, the sampled token, and the speculated tokens
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])

# No draft or accepted tokens counted yet
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
assert stats.num_draft_tokens == 0
assert stats.num_accepted_tokens == 0

# Schedule the speculated tokens for validation
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
# The sampled token and speculated tokens
assert output.total_num_scheduled_tokens == \
len(requests) + sum(len(ids) for ids in spec_tokens)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
if spec_tokens[i]:
assert len(output.scheduled_spec_decode_tokens[req_id]) == \
len(spec_tokens[i])
else:
assert req_id not in output.scheduled_spec_decode_tokens

model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)

assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
assert stats.num_draft_tokens == expected[0]
assert stats.num_accepted_tokens == expected[1]
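
For orientation, the expected `(num_draft_tokens, num_accepted_tokens)` tuples in the parametrization above can be derived by hand: every scheduled speculative token counts as a draft token, and every sampled token except the last one counts as accepted (the final sampled token is either the bonus token or the first corrected token). A minimal standalone sketch; the `expected_spec_stats` helper is hypothetical and not part of the diff:

```python
# Hypothetical helper, not part of the PR: reproduces the counting convention
# behind the `expected` tuples in the parametrization above.
def expected_spec_stats(spec_tokens, output_tokens):
    # Every speculated token that was scheduled counts as a draft token.
    num_draft = sum(len(spec) for spec in spec_tokens)
    # Every sampled token except the last one counts as an accepted draft.
    num_accepted = sum(len(sampled) - 1 for sampled in output_tokens)
    return num_draft, num_accepted


assert expected_spec_stats([[1, 2, 3]], [[1, 5]]) == (3, 1)  # early mismatch
assert expected_spec_stats([[1, 2, 3], [4, 5, 6]],
                           [[1, 2, 7], [4, 8]]) == (6, 3)  # multiple mismatches
```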
15 changes: 13 additions & 2 deletions vllm/v1/core/sched/scheduler.py
@@ -23,6 +23,7 @@
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager

logger = init_logger(__name__)
@@ -567,6 +568,7 @@ def update_from_output(
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
spec_decoding_stats = SpecDecodingStats() if self.log_stats else None
num_scheduled_tokens = scheduler_output.num_scheduled_tokens

new_running: list[Request] = []
@@ -599,6 +601,11 @@
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected

if spec_decoding_stats is not None:
spec_decoding_stats.observe(
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)

cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
# OPTIMIZATION: Avoid list(set) if the set is empty.
@@ -672,7 +679,7 @@
self.running = new_running
engine_core_outputs = EngineCoreOutputs(
outputs=outputs,
scheduler_stats=self.make_stats(),
scheduler_stats=self.make_stats(spec_decoding_stats),
)
if self.include_finished_set:
#TODO currently sending duplicates here, improve this
@@ -739,12 +746,16 @@ def get_num_unscheduled_requests(self) -> int:
def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()

def make_stats(self) -> Optional[SchedulerStats]:
def make_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats] = None,
) -> Optional[SchedulerStats]:
if not self.log_stats:
return None
return SchedulerStats(
num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
gpu_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
spec_decoding_stats=spec_decoding_stats,
)
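
Taken together, the scheduler-side change amounts to one `observe()` call per request that had speculative tokens scheduled, with the resulting stats handed to `make_stats()`. A hedged, standalone sketch of that per-request accounting, using plain values in place of real `Request`/`ModelRunnerOutput` objects and assuming the `SpecDecodingStats` added by this PR is importable:

```python
from vllm.v1.spec_decode.metrics import SpecDecodingStats

spec_decoding_stats = SpecDecodingStats()

# One request: 3 speculative tokens [1, 2, 3] were scheduled, and the model
# sampled [1, 5] -- the first draft matched, then generation diverged.
scheduled_spec_token_ids = [1, 2, 3]
generated_token_ids = [1, 5]

spec_decoding_stats.observe(
    num_draft_tokens=len(scheduled_spec_token_ids),
    # The last sampled token is never a draft acceptance (it is either the
    # bonus token or the first corrected token), hence the -1.
    num_accepted_tokens=len(generated_token_ids) - 1)

assert (spec_decoding_stats.num_draft_tokens,
        spec_decoding_stats.num_accepted_tokens) == (3, 1)
# make_stats(spec_decoding_stats) then attaches the accumulated counts to the
# SchedulerStats emitted for this step.
```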
33 changes: 33 additions & 0 deletions vllm/v1/metrics/loggers.py
@@ -12,6 +12,7 @@
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingMetrics

logger = init_logger(__name__)

@@ -38,6 +39,7 @@ def __init__(self, engine_index: int = 0):
# Prefix cache metrics. This cannot be reset.
# TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
self.spec_decoding_metrics = SpecDecodingMetrics()

def _reset(self, now):
self.last_log_time = now
@@ -65,6 +67,10 @@ def record(self, scheduler_stats: SchedulerStats,

self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.observe(
scheduler_stats.spec_decoding_stats)

self.last_scheduler_stats = scheduler_stats

def log(self):
@@ -94,6 +100,9 @@ def log(self):
self.prefix_caching_metrics.hit_rate * 100,
)

if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_metrics.log()


class PrometheusStatLogger(StatLoggerBase):

@@ -302,6 +311,24 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self.labelname_running_lora_adapters,
])

#
# Speculative Decoding metrics
# The acceptance rate can be calculated using a PromQL query:
#
# rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
# rate(vllm:spec_decode_num_draft_tokens_total[$interval])
#
self.counter_spec_decode_num_draft_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames).labels(*labelvalues)
self.counter_spec_decode_num_accepted_tokens = \
prometheus_client.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames).labels(*labelvalues)

#
# Cache config info metric
#
@@ -338,6 +365,12 @@ def record(self, scheduler_stats: SchedulerStats,
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)

if scheduler_stats.spec_decoding_stats is not None:
self.counter_spec_decode_num_draft_tokens.inc(
scheduler_stats.spec_decoding_stats.num_draft_tokens)
self.counter_spec_decode_num_accepted_tokens.inc(
scheduler_stats.spec_decoding_stats.num_accepted_tokens)

if iteration_stats is None:
return

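
Both counters are monotonic, so the acceptance rate described in the PromQL comment above is just the ratio of their deltas over any window. A small, self-contained illustration using `prometheus_client` directly (demo metric names, not the ones registered by vLLM):

```python
import prometheus_client

# Demo counters mirroring the ones added above (different names so this
# sketch does not collide with a real vLLM metrics registry).
draft = prometheus_client.Counter(
    "demo_spec_decode_num_draft_tokens_total",
    "Number of draft tokens (demo).")
accepted = prometheus_client.Counter(
    "demo_spec_decode_num_accepted_tokens_total",
    "Number of accepted tokens (demo).")

# Two scheduler steps, mirroring the test cases above.
draft.inc(3)
accepted.inc(3)  # perfect match: all 3 drafts accepted
draft.inc(3)
accepted.inc(1)  # early mismatch: only the first draft accepted

# Over any scrape window, acceptance rate = delta(accepted) / delta(draft),
# here 4 / 6, which is exactly what the PromQL query in the comment computes.
```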
4 changes: 4 additions & 0 deletions vllm/v1/metrics/stats.py
@@ -4,6 +4,8 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional

from vllm.v1.spec_decode.metrics import SpecDecodingStats

if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.engine.output_processor import RequestState
@@ -35,6 +37,8 @@ class SchedulerStats:
prefix_cache_stats: PrefixCacheStats = field(
default_factory=PrefixCacheStats)

spec_decoding_stats: Optional[SpecDecodingStats] = None


@dataclass
class LoRAStats:
59 changes: 59 additions & 0 deletions vllm/v1/spec_decode/metrics.py
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import numpy as np

from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclass
class SpecDecodingStats:
num_draft_tokens: int = 0
num_accepted_tokens: int = 0

def take(self):
copied = SpecDecodingStats(self.num_draft_tokens,
self.num_accepted_tokens)
self.reset()
return copied

def reset(self):
self.num_draft_tokens = 0
self.num_accepted_tokens = 0

def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
self.num_draft_tokens += num_draft_tokens
self.num_accepted_tokens += num_accepted_tokens


class SpecDecodingMetrics:

def __init__(self):
self.reset()

def reset(self):
self.num_draft_tokens: list[int] = []
self.num_accepted_tokens: list[int] = []

def observe(self, spec_decoding_stats: SpecDecodingStats):
self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
self.num_accepted_tokens.append(
spec_decoding_stats.num_accepted_tokens)

def log(self):
num_draft_tokens = np.sum(self.num_draft_tokens)
num_accepted_tokens = np.sum(self.num_accepted_tokens)

draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens
if num_draft_tokens > 0 else float("nan"))

logger.info(
"Speculative metrics: "
"Draft acceptance rate: %.3f, "
"Number of accepted tokens: %d, "
"Number of draft tokens: %d, ", draft_acceptance_rate,
num_accepted_tokens, num_draft_tokens)
self.reset()
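
A brief usage sketch of the new `SpecDecodingMetrics` aggregator (standalone, assuming this module is importable): per-step `SpecDecodingStats` are collected via `observe()`, and `log()` reports the cumulative counts and acceptance rate, then resets.

```python
from vllm.v1.spec_decode.metrics import SpecDecodingMetrics, SpecDecodingStats

metrics = SpecDecodingMetrics()

# Two scheduler steps' worth of stats, as the logging StatLogger would see them.
metrics.observe(SpecDecodingStats(num_draft_tokens=6, num_accepted_tokens=3))
metrics.observe(SpecDecodingStats(num_draft_tokens=3, num_accepted_tokens=3))

# Logs a draft acceptance rate of 0.667 (6 accepted / 9 draft tokens) and
# then resets the accumulated lists.
metrics.log()
```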