15 changes: 15 additions & 0 deletions docs/contributing/README.md
@@ -150,6 +150,21 @@ For tensor-parallel debugging, you can enable an option to redirect all log output
Set `VLLM_SPYRE_WORKER_LOG_REDIRECT_DIR` to a local directory, and each rank will redirect stdout and stderr into their own file inside the directory.
This can be helpful to avoid having interleaved stack dumps from different ranks in stderr.
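
For example (a minimal sketch; the directory path below is arbitrary, and the per-rank file names inside it are chosen by vllm-spyre):

```python
import os

# Arbitrary example directory. Each rank will write its stdout/stderr to its
# own file under this directory instead of interleaving it on shared stderr.
os.environ["VLLM_SPYRE_WORKER_LOG_REDIRECT_DIR"] = "/tmp/spyre-rank-logs"
```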

### Performance Metrics

When deploying to Kubernetes clusters, Prometheus and Grafana can be installed and configured to scrape metrics from vLLM's `/metrics` endpoint.

vLLM can also be configured to log performance metrics for every finished request to a local file.
Setting both `VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED=1` and `VLLM_SPYRE_PERF_METRIC_LOGGING_DIR=/some/path`, and ensuring that vLLM stat logging is enabled (i.e. `--disable-log-stats` is not set), will write metrics to `/some/path/request_metrics.jsonl`. A sample of this file looks like:

```json
{"timestamp": "2025-10-10T12:25:17.544", "prefill_interrupt_seconds": 0, "decode_only_itl_seconds": 0.05045744727055232, "finish_reason": 1, "num_prompt_tokens": 1, "num_generation_tokens": 16, "max_tokens_param": 16, "e2e_latency_seconds": 0.9784879684448242, "queued_time_seconds": 6.0582999140024185e-05, "prefill_time_seconds": 0.220398832927458, "inference_time_seconds": 0.9772605419857427, "decode_time_seconds": 0.7568617090582848, "mean_time_per_output_token_seconds": 0.05045744727055232}
{"timestamp": "2025-10-10T12:25:19.632", "prefill_interrupt_seconds": 0, "decode_only_itl_seconds": 0.10008190000274529, "finish_reason": 1, "num_prompt_tokens": 1, "num_generation_tokens": 16, "max_tokens_param": 16, "e2e_latency_seconds": 2.0864057540893555, "queued_time_seconds": 0.2935298749944195, "prefill_time_seconds": 0.1466117500094697, "inference_time_seconds": 1.647840250050649, "decode_time_seconds": 1.5012285000411794, "mean_time_per_output_token_seconds": 0.10008190000274529}
{"timestamp": "2025-10-10T12:25:19.632", "prefill_interrupt_seconds": 0.14661192893981934, "decode_only_itl_seconds": 0.1000875825372835, "finish_reason": 1, "num_prompt_tokens": 1, "num_generation_tokens": 16, "max_tokens_param": 16, "e2e_latency_seconds": 2.0864808559417725, "queued_time_seconds": 0.1469848749693483, "prefill_time_seconds": 0.14646116609219462, "inference_time_seconds": 1.7943868330912665, "decode_time_seconds": 1.6479256669990718, "mean_time_per_output_token_seconds": 0.10986171113327145}
{"timestamp": "2025-10-10T12:25:19.632", "prefill_interrupt_seconds": 0.29317212104797363, "decode_only_itl_seconds": 0.10008799746477355, "finish_reason": 1, "num_prompt_tokens": 1, "num_generation_tokens": 16, "max_tokens_param": 16, "e2e_latency_seconds": 2.08658504486084, "queued_time_seconds": 0.0001724999165162444, "prefill_time_seconds": 0.14670966705307364, "inference_time_seconds": 1.9412017500726506, "decode_time_seconds": 1.794492083019577, "mean_time_per_output_token_seconds": 0.11963280553463847}
{"timestamp": "2025-10-10T12:25:19.632", "prefill_interrupt_seconds": 0.4400491714477539, "decode_only_itl_seconds": 0.10009045804229875, "finish_reason": 1, "num_prompt_tokens": 1, "num_generation_tokens": 16, "max_tokens_param": 16, "e2e_latency_seconds": 2.0868380069732666, "queued_time_seconds": 2.9250048100948334e-05, "prefill_time_seconds": 0.1447284579044208, "inference_time_seconds": 2.086134499986656, "decode_time_seconds": 1.9414060420822352, "mean_time_per_output_token_seconds": 0.12942706947214902}
```
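
For local debugging outside of a cluster, a quick way to try this is to set the environment variables before constructing the engine. The snippet below is only a sketch: the model name and output directory are placeholders, and it uses the offline `LLM` entrypoint with stat logging left enabled (the same environment variables apply when running `vllm serve`).

```python
import os

# Placeholder directory and model for illustration; adjust to your setup.
os.environ["VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED"] = "1"
os.environ["VLLM_SPYRE_PERF_METRIC_LOGGING_DIR"] = "/tmp/spyre-perf"

from vllm import LLM

# disable_log_stats=False keeps vLLM stat logging enabled, which this
# feature requires.
llm = LLM(model="ibm-granite/granite-3.3-8b-instruct", disable_log_stats=False)
llm.generate(["Write a short recipe for chicken soup."])

# Each finished request appends one JSON line to
# /tmp/spyre-perf/request_metrics.jsonl
```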

### Topology Aware Allocation

This section is specific to the AIU operator and scheduling workloads onto specific cards.
37 changes: 37 additions & 0 deletions tests/e2e/test_stats_logger.py
@@ -0,0 +1,37 @@
import json
from pathlib import Path

import pytest
from spyre_util import ModelInfo, get_chicken_soup_prompts
from vllm import LLM

from vllm_spyre import envs as envs_spyre


@pytest.mark.cpu
@pytest.mark.cb
def test_file_stats_logger(model: ModelInfo, max_model_len, max_num_seqs,
tmp_path):

prompts = get_chicken_soup_prompts(4)

envs_spyre.override("VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED", "1")
envs_spyre.override("VLLM_SPYRE_PERF_METRIC_LOGGING_DIR", str(tmp_path))
envs_spyre.override("VLLM_SPYRE_USE_CB", "1")
envs_spyre.override("VLLM_SPYRE_DYNAMO_BACKEND", "eager")

model = LLM(model=model.name,
revision=model.revision,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
disable_log_stats=False)
model.generate(prompts=prompts)

assert Path(tmp_path / "request_metrics.jsonl").exists()

with Path(tmp_path / "request_metrics.jsonl").open() as f:
for line in f.readlines():
data = json.loads(line)
assert "prefill_interrupt_seconds" in data
assert "e2e_latency_seconds" in data
assert "timestamp" in data
9 changes: 8 additions & 1 deletion vllm_spyre/envs.py
@@ -103,7 +103,14 @@ def _backend_backwards_compat() -> str:
lambda: bool(int(os.getenv("VLLM_SPYRE_USE_CB", "0"))),

# Enable performance metric logging. This captures startup information
# such as warmup times, and loading times. It is turned off by default.
# such as warmup and loading times.
# When vLLM stat logging is enabled (i.e. `--disable-log-stats` is not set),
# this will also log timing metrics about every finished request into a
# .jsonl file. These are the same metrics that are available in Prometheus
# format on the /metrics endpoint, but it is sometimes helpful to view them
# disaggregated per request to debug performance problems. This logging is
# not designed to be performant and should not be enabled in production
# settings.
# It is turned off by default.
"VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED":
lambda: int(os.getenv("VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED", 0)),

1 change: 1 addition & 0 deletions vllm_spyre/perf_metrics.py
@@ -35,6 +35,7 @@ def __init__(self, rank: int):
self.time_fmt = "%m-%d %H:%M:%S"
self.log_path = os.path.join(envs.VLLM_SPYRE_PERF_METRIC_LOGGING_DIR,
f"perf_log_rank_{str(rank)}.txt")
os.makedirs(envs.VLLM_SPYRE_PERF_METRIC_LOGGING_DIR, exist_ok=True)
# Cleanup previous metrics files
if os.path.exists(self.log_path):
os.remove(self.log_path)
3 changes: 3 additions & 0 deletions vllm_spyre/platform.py
@@ -85,6 +85,9 @@ def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# 🌶️🌶️🌶️ Patch in our perf logger before the engine is created
from vllm_spyre.v1.metrics import patch_async_llm_stat_loggers
patch_async_llm_stat_loggers()

# In case vllm passes a default vllm_config to us.
# This happens when get_current_vllm_config is called
7 changes: 7 additions & 0 deletions vllm_spyre/v1/metrics/__init__.py
@@ -0,0 +1,7 @@
from .stats_logger import (FileStatLogger, file_stat_logger_factory,
patch_async_llm_stat_loggers)

__all__ = [
"patch_async_llm_stat_loggers", "file_stat_logger_factory",
"FileStatLogger"
]
210 changes: 210 additions & 0 deletions vllm_spyre/v1/metrics/stats_logger.py
@@ -0,0 +1,210 @@
import dataclasses
import json
import time
from datetime import datetime
from functools import wraps
from pathlib import Path
from typing import Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.engine import async_llm, llm_engine
from vllm.v1.metrics.loggers import StatLoggerBase, StatLoggerManager
from vllm.v1.metrics.stats import (FinishedRequestStats, IterationStats,
SchedulerStats)

from vllm_spyre import envs as envs_spyre

logger = init_logger(__name__)


@dataclasses.dataclass
class PerfRecord:
"""A record for request_metrics.jsonl.
Contains info about a single finished request"""
# ISO timestamp w/ milliseconds
timestamp: str
# timing info
engine_stats: FinishedRequestStats
# time spent pre-empted for other prefills
prefill_interrupt_seconds: float
# ITL calculated without the prefill interrupts
decode_only_itl_seconds: float

# key names to append with a time unit during json serialization
_TIME_KEYS = [
"e2e_latency", "queued_time", "prefill_time", "inference_time",
"decode_time", "mean_time_per_output_token"
]

def to_json(self) -> str:
json_dict = dataclasses.asdict(self)

# Flatten the engine stats into the top level
engine_dict = json_dict.pop("engine_stats")
json_dict.update(engine_dict)

# add _seconds onto the timing info from the engine
for k in self._TIME_KEYS:
if k in json_dict:
json_dict[k + "_seconds"] = json_dict.pop(k)

return json.dumps(json_dict)


class FileStatLogger(StatLoggerBase):

def __init__(self, vllm_config: VllmConfig, engine_index=0):
self.enabled = (envs_spyre.VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED)

perf_dir = Path(envs_spyre.VLLM_SPYRE_PERF_METRIC_LOGGING_DIR)
perf_dir.mkdir(parents=True, exist_ok=True)

self.perf_file = perf_dir / "request_metrics.jsonl"

if self.enabled and engine_index == 0:
logger.info(
"Initializing vllm-spyre perf debug logger. Writing perf info "
"to: %s", str(self.perf_file))

# Clear any old metrics out first
if self.perf_file.exists():
self.perf_file.unlink()

self.perf_file.touch()

self.iso_format = "%Y-%m-%dT%H:%M:%S.%f"

self._prefill_tuples: list[tuple[float, float]] = []
self._max_batch_size = vllm_config.scheduler_config.max_num_seqs
self._last_ts: float = 0

def record(self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0):
if not self.enabled or engine_idx != 0:
# Only log when enabled, and only from engine index 0
return

if iteration_stats is None:
return

if iteration_stats.num_prompt_tokens > 0:
self._save_prefill_time(iteration_stats)
self._last_ts = iteration_stats.iteration_timestamp

if not iteration_stats.finished_requests:
# Only log finished requests
return

# Convert float timestamp to a human-readable string
text_timestamp = datetime.fromtimestamp(
iteration_stats.iteration_timestamp).strftime(self.iso_format)[:-3]

records_to_write: list[str] = []
for r in iteration_stats.finished_requests:
# Calculate some estimates to add to the engine stats
estimated_prefill_interrupt = \
self.estimate_prefill_interrupt_lower_bound(r)

estimated_decode_itl = (r.decode_time -
estimated_prefill_interrupt) / max(
r.num_generation_tokens - 1, 1)

record = PerfRecord(
timestamp=text_timestamp,
engine_stats=r,
decode_only_itl_seconds=estimated_decode_itl,
prefill_interrupt_seconds=estimated_prefill_interrupt)
records_to_write.append(record.to_json())

with open(self.perf_file, "a") as f:
f.write("\n".join(records_to_write) + "\n")

def log_engine_initialized(self):
pass

def _save_prefill_time(self, iteration_stats: IterationStats):
"""If this iteration was a prefill, then save the a tuple of the current
time and prefill time. This will be used later to estimate a lower bound
of the amount of time that other sequences were
interrupted for this prefill to happen.

This is only relevant because the batching implementation has to pause
the running batch of decoding sequences to prefill a single sequence.
"""
maybe_prefill_time = iteration_stats.iteration_timestamp - self._last_ts
# TTFT here includes queueing and we don't have access to the iteration
# duration itself so we have to try to calculate our own prefill time.
# If we calculate an interval that was less than the reported TTFT, then
# use it as the prefill time
maybe_prefill_time = min(maybe_prefill_time,
iteration_stats.time_to_first_tokens_iter[0])

# Tuple is (timestamp, prefill_time)
self._prefill_tuples.append(
(iteration_stats.iteration_timestamp, maybe_prefill_time))
if len(self._prefill_tuples) > 2 * self._max_batch_size:
# Delete older prefills; we can't hold everything in memory.
# Not guaranteed to be lossless.
self._prefill_tuples.pop(0)

def estimate_prefill_interrupt_lower_bound(
self, finished_request: FinishedRequestStats) -> float:
"""Returns a lower bound estimate on the time (in ms) that this request
was interrupted for other requests to prefill to join the batch"""
estimated_prefill_interrupt: float = 0

# NB: use current time instead of iteration timestamp to ensure that we
# exclude current request's prefill
slop = 0.001
decode_start_time = time.time() - finished_request.decode_time + slop
for i in range(len(self._prefill_tuples)):
if self._prefill_tuples[i][0] > decode_start_time:
# Sum up all prefills past decode start time
estimated_prefill_interrupt = sum(
r[1] for r in self._prefill_tuples[i:])
break
return estimated_prefill_interrupt


def file_stat_logger_factory(config: VllmConfig,
engine_index=0) -> FileStatLogger:
"""Factory method accepted by vllm engine initializers"""
return FileStatLogger(config, engine_index)


def patch_async_llm_stat_loggers():
"""
🌶️🌶️🌶️
Platforms cannot alter the initialization of a vllm engine, and the
`stat_loggers` parameter is not user-settable via `EngineArgs`.

So we resort to patching the initialization of StatLoggerManager to
inject our own stats logger. This _should_ also be compatible with versions
of vllm prior to the addition of the `stat_loggers` engine parameter.
🌶️🌶️🌶️
"""
logger.debug("Setting up perf logger injection")
original_init = StatLoggerManager.__init__

@wraps(original_init)
def new_init(self, *args, **kwargs):
logger.debug("Injecting vllm-spyre perf logger factory")
if "custom_stat_loggers" not in kwargs or kwargs[
"custom_stat_loggers"] is None:
kwargs["custom_stat_loggers"] = []

kwargs["custom_stat_loggers"].append(file_stat_logger_factory)

original_init(self, *args, **kwargs)

async_llm.StatLoggerManager.__init__ = new_init
if hasattr(llm_engine, "StatLoggerManager"):
# 0.10.2 backwards compatibility:
# Once the supported vllm lower bound is past 0.10.2, remove the if check
# but keep this line. The `llm_engine` module uses StatLoggerManager on
# 0.11.0+
llm_engine.StatLoggerManager.__init__ = new_init