Merged
Commits
29 commits
b8d2fa0
[Core][Feature] Input metadata dump on crash
wallashss Feb 17, 2025
08c9f15
fix: mypy complaints
wallashss Feb 18, 2025
0ef83a8
fix: mypy complaints
wallashss Feb 18, 2025
84fe0af
Merge branch 'main' into dump-input-on-crash
wallashss Feb 24, 2025
d8f75b7
fix: dump report
wallashss Feb 25, 2025
d588fff
fix server hang on shutdown
wallashss Feb 25, 2025
58cd9c9
Merge branch 'main' into dump-input-on-crash
wallashss Feb 26, 2025
4eafa5e
Merge branch 'main' of github.com:wallashss/vllm into dump-input-on-c…
wallashss Feb 26, 2025
e3d945f
Merge branch 'main' into dump-input-on-crash
wallashss Mar 5, 2025
ea544f1
review feedback
wallashss Mar 5, 2025
63f24ab
moved vllm/error_report.py to vllm/logging_utils/dump_input.py
wallashss Mar 5, 2025
7d405d4
refact for review
wallashss Mar 7, 2025
b4a83eb
refactoring
wallashss Mar 7, 2025
02e1673
refactoring
wallashss Mar 7, 2025
6fe83ca
Merge branch 'main' into dump-input-on-crash
wallashss Mar 7, 2025
f66489d
fix lint
wallashss Mar 7, 2025
5f82648
reverted change on llm_engine due to test
wallashss Mar 13, 2025
93ae6dc
fix: ensure suppress exception in dump
wallashss Mar 14, 2025
1320c87
Merge branch 'main' into dump-input-on-crash
wallashss Mar 18, 2025
9764218
Merge branch 'main' into dump-input-on-crash
wallashss Mar 25, 2025
80b9751
Merge branch 'main' into dump-input-on-crash
wallashss Apr 2, 2025
3384055
feat: removed v0
wallashss Apr 8, 2025
2f72a2f
Merge branch 'main' of github.com:wallashss/vllm into dump-input-on-c…
wallashss Apr 8, 2025
ad368f1
removed v0 support
wallashss Apr 8, 2025
7c18e20
refact: code clean up
wallashss Apr 10, 2025
94fa3fc
Merge branch 'main' of github.com:wallashss/vllm into dump-input-on-c…
wallashss Apr 10, 2025
307d3cf
Merge branch 'main' into dump-input-on-crash
wallashss May 1, 2025
8cbee30
refact: moved execute model to a separated method
wallashss May 1, 2025
51596e4
Update vllm/v1/engine/core.py
wallashss May 7, 2025
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/400-bug-report.yml
@@ -75,7 +75,7 @@ body:
```

```
The error message you got, with the full traceback.
The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present.
```
validations:
required: true
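For reference, the log lines the updated template asks for look roughly like the sketch below (illustrative only: the timestamp format, vLLM version, and the `##` line numbers will differ, and the config and scheduler output are heavily truncated here):

```
ERROR ... [dump_input.py:##] Dumping input data
ERROR ... [dump_input.py:##] V1 LLM engine (v0.x.y) with config: ...
ERROR ... [dump_input.py:##] Dumping scheduler output for model execution:
ERROR ... [dump_input.py:##] SchedulerOutput(...)
```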
49 changes: 43 additions & 6 deletions tests/basic_correctness/test_basic_correctness.py
@@ -5,11 +5,13 @@
"""
import os
import weakref
from unittest.mock import Mock

import pytest

from vllm import LLM
from vllm.platforms import current_platform
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1

from ..conftest import VllmRunner
from ..models.utils import check_outputs_equal
@@ -152,9 +154,44 @@ def test_models_distributed(
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


def test_failed_model_execution(vllm_runner, monkeypatch) -> None:

from vllm.envs import VLLM_USE_V1

if not VLLM_USE_V1:
pytest.skip("Skipping V0 test, dump input not supported")

# Needed to mock an error in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')

with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
if isinstance(vllm_model.model.llm_engine, LLMEngineV1):
v1_test_failed_model_execution(vllm_model)


def v1_test_failed_model_execution(vllm_model):

engine = vllm_model.model.llm_engine
mocked_execute_model = Mock(
side_effect=RuntimeError("Mocked Critical Error"))
engine.engine_core.engine_core.model_executor.execute_model =\
mocked_execute_model

with pytest.raises(RuntimeError) as exc_info:
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
assert isinstance(exc_info.value, RuntimeError)
assert "Mocked Critical Error" in str(exc_info.value)
84 changes: 84 additions & 0 deletions vllm/logging_utils/dump_input.py
@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0

import contextlib
import enum
import json
from typing import Optional

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.metrics.stats import SchedulerStats
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)


def prepare_object_to_dump(obj) -> str:
if isinstance(obj, str):
return "'{obj}'" # Double quotes
elif isinstance(obj, dict):
dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \
for k, v in obj.items()})
return f'{{{dict_str}}}'
elif isinstance(obj, list):
return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
elif isinstance(obj, set):
return f"[{', '.join([prepare_object_to_dump(v) for v in list(obj)])}]"
# return [prepare_object_to_dump(v) for v in list(obj)]
elif isinstance(obj, tuple):
return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]"
elif isinstance(obj, enum.Enum):
return repr(obj)
elif isinstance(obj, torch.Tensor):
        # We only print the 'draft' of the tensor so we do not expose sensitive
        # data, while still keeping some metadata in case the CUDA runtime crashed
return (f"Tensor(shape={obj.shape}, "
f"device={obj.device},"
f"dtype={obj.dtype})")
elif hasattr(obj, 'anon_repr'):
return obj.anon_repr()
elif hasattr(obj, '__dict__'):
items = obj.__dict__.items()
dict_str = ','.join([f'{str(k)}={prepare_object_to_dump(v)}' \
for k, v in items])
return (f"{type(obj).__name__}({dict_str})")
else:
# Hacky way to make sure we can serialize the object in JSON format
try:
return json.dumps(obj)
except (TypeError, OverflowError):
return repr(obj)


def dump_engine_exception(config: VllmConfig,
scheduler_output: SchedulerOutput,
scheduler_stats: Optional[SchedulerStats]):
    # NOTE: ensure we can log extra info without risking that the
    # logging itself raises unexpected errors
with contextlib.suppress(BaseException):
_dump_engine_exception(config, scheduler_output, scheduler_stats)


def _dump_engine_exception(config: VllmConfig,
scheduler_output: SchedulerOutput,
scheduler_stats: Optional[SchedulerStats]):
logger.error("Dumping input data")

logger.error(
"V1 LLM engine (v%s) with config: %s, ",
VLLM_VERSION,
config,
)

try:
dump_obj = prepare_object_to_dump(scheduler_output)
logger.error("Dumping scheduler output for model execution:")
logger.error(dump_obj)
if scheduler_stats:
logger.error(scheduler_stats)
except BaseException as exception:
logger.error("Error preparing object to dump")
logger.error(repr(exception))
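To make the dump format concrete, here is a minimal usage sketch (not part of the diff), assuming the module path added by this PR. Dict entry order may vary because the dict branch joins a set comprehension, and the tensor's device/dtype depend on the local build:

```python
import torch

from vllm.logging_utils.dump_input import prepare_object_to_dump

sample = {
    "req_id": "req-0",            # strings are quoted
    "block_ids": [1, 2, 3],       # lists/sets/tuples are rendered as [...]
    "logits": torch.empty(2, 4),  # tensors: shape/device/dtype only, no data
}
print(prepare_object_to_dump(sample))
# Roughly: {req_id: 'req-0', block_ids: [1, 2, 3],
#           logits: Tensor(shape=torch.Size([2, 4]), device=cpu,dtype=torch.float32)}
```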
29 changes: 29 additions & 0 deletions vllm/v1/core/sched/output.py
@@ -48,6 +48,35 @@ def from_request(
lora_request=request.lora_request,
)

def __repr__(self):
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
f"prompt={self.prompt},"
f"mm_inputs={self.mm_inputs},"
f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
")")

# Version of __repr__ with the prompt data obfuscated
def anon_repr(self):
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
f"prompt='',"
f"mm_inputs={self.mm_inputs},"
f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
")")


@dataclass
class CachedRequestData:
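Since `prepare_object_to_dump` checks for `anon_repr` before falling back to `__dict__`, `NewRequestData` entries inside a dumped `SchedulerOutput` are rendered in the obfuscated form above rather than with their raw prompt data. A minimal sketch with a hypothetical stand-in class (assuming the module path from this PR):

```python
from vllm.logging_utils.dump_input import prepare_object_to_dump


class Req:
    """Stand-in exposing the same anon_repr hook as NewRequestData."""

    def __init__(self):
        self.prompt = "secret user text"
        self.prompt_token_ids = [101, 102, 103]

    def anon_repr(self):
        return (f"Req(prompt_token_ids_len={len(self.prompt_token_ids)}, "
                f"prompt='')")


# anon_repr takes precedence over the generic __dict__ rendering, so neither
# the prompt text nor the token ids end up in the crash dump.
print(prepare_object_to_dump(Req()))  # -> Req(prompt_token_ids_len=3, prompt='')
```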
12 changes: 11 additions & 1 deletion vllm/v1/engine/core.py
@@ -19,6 +19,7 @@
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
@@ -58,6 +59,7 @@ def __init__(
):
assert vllm_config.model_config.runner_type != "pooling"

self.vllm_config = vllm_config
logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)

@@ -203,7 +205,15 @@ def step(self) -> EngineCoreOutputs:
scheduler_stats=self.scheduler.make_stats(),
)
scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
try:
output = self.model_executor.execute_model(scheduler_output)
except BaseException as err:
            # NOTE: dump_engine_exception() is exception-free,
            # so it cannot mask the original error
dump_engine_exception(self.vllm_config, scheduler_output,
self.scheduler.make_stats())
# Re-raise exception
raise err

engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output) # type: ignore

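Taken together with `contextlib.suppress(BaseException)` inside `dump_engine_exception`, this try/except makes the dump strictly best-effort: even if formatting the scheduler output itself fails, the original execution error still propagates to the caller. A self-contained toy illustration of the pattern (not vLLM code):

```python
import contextlib


def best_effort_dump():
    # Mirrors dump_engine_exception: any failure while logging is swallowed.
    with contextlib.suppress(BaseException):
        raise ValueError("formatting the dump failed")


def step():
    try:
        raise RuntimeError("Mocked Critical Error")  # model execution fails
    except BaseException as err:
        best_effort_dump()  # never raises, so it cannot mask `err`
        raise err           # the original error reaches the caller


try:
    step()
except RuntimeError as e:
    print(e)  # -> Mocked Critical Error
```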