Merged

Commits
29 commits
b8d2fa0
[Core][Feature] Input metadata dump on crash
wallashss Feb 17, 2025
08c9f15
fix: mypy complaints
wallashss Feb 18, 2025
0ef83a8
fix: mypy complaints
wallashss Feb 18, 2025
84fe0af
Merge branch 'main' into dump-input-on-crash
wallashss Feb 24, 2025
d8f75b7
fix: dump report
wallashss Feb 25, 2025
d588fff
fix server hang on shutdown
wallashss Feb 25, 2025
58cd9c9
Merge branch 'main' into dump-input-on-crash
wallashss Feb 26, 2025
4eafa5e
Merge branch 'main' of github.com:wallashss/vllm into dump-input-on-c…
wallashss Feb 26, 2025
e3d945f
Merge branch 'main' into dump-input-on-crash
wallashss Mar 5, 2025
ea544f1
review feedback
wallashss Mar 5, 2025
63f24ab
moved vllm/error_report.py to vllm/logging_utils/dump_input.py
wallashss Mar 5, 2025
7d405d4
refact for review
wallashss Mar 7, 2025
b4a83eb
refactoring
wallashss Mar 7, 2025
02e1673
refactoring
wallashss Mar 7, 2025
6fe83ca
Merge branch 'main' into dump-input-on-crash
wallashss Mar 7, 2025
f66489d
fix lint
wallashss Mar 7, 2025
5f82648
reverted change on llm_engine due to test
wallashss Mar 13, 2025
93ae6dc
fix: ensure suppress exception in dump
wallashss Mar 14, 2025
1320c87
Merge branch 'main' into dump-input-on-crash
wallashss Mar 18, 2025
9764218
Merge branch 'main' into dump-input-on-crash
wallashss Mar 25, 2025
80b9751
Merge branch 'main' into dump-input-on-crash
wallashss Apr 2, 2025
3384055
feat: removed v0
wallashss Apr 8, 2025
2f72a2f
Merge branch 'main' of github.com:wallashss/vllm into dump-input-on-c…
wallashss Apr 8, 2025
ad368f1
removed v0 support
wallashss Apr 8, 2025
7c18e20
refact: code clean up
wallashss Apr 10, 2025
94fa3fc
Merge branch 'main' of github.com:wallashss/vllm into dump-input-on-c…
wallashss Apr 10, 2025
307d3cf
Merge branch 'main' into dump-input-on-crash
wallashss May 1, 2025
8cbee30
refact: moved execute model to a separated method
wallashss May 1, 2025
51596e4
Update vllm/v1/engine/core.py
wallashss May 7, 2025
49 changes: 49 additions & 0 deletions tests/basic_correctness/test_basic_correctness.py
@@ -5,11 +5,16 @@
"""
import os
import weakref
from unittest.mock import Mock

import pytest

from vllm import LLM
from vllm.platforms import current_platform
from vllm.v1.engine.core import ModelExecutionV1Error
from vllm.v1.engine.core_client import EngineCoreClient, InprocClient
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
from vllm.worker.worker_base import ModelExecutionError

from ..conftest import VllmRunner
from ..models.utils import check_outputs_equal
@@ -147,3 +152,47 @@ def test_models_distributed(
name_0="hf",
name_1="vllm",
)


def test_failed_model_execution(vllm_runner) -> None:

def make_client(
multiprocess_mode: bool,
asyncio_mode: bool,
vllm_config, # "VllmConfig"
executor_class, # "Type[Executor]"
log_stats: bool,
) -> "EngineCoreClient":
return InprocClient(vllm_config, executor_class, log_stats)

EngineCoreClient.make_client = Mock(side_effect=make_client)
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:

engine = vllm_model.model.llm_engine
mocked_execute_model = Mock(
side_effect=RuntimeError("Mocked Critical Error"))

if isinstance(engine, LLMEngineV1):
is_v1 = True
engine.engine_core.engine_core.model_executor.execute_model =\
mocked_execute_model
else: # V0
is_v1 = False
engine.model_executor.driver_worker.model_runner.execute_model = \
mocked_execute_model

with pytest.raises(RuntimeError) as exc_info:
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
if is_v1:
assert isinstance(exc_info.value, ModelExecutionV1Error)
assert exc_info.value.scheduler_output is not None
else:
assert isinstance(exc_info.value, ModelExecutionError)
assert exc_info.value.model_input is not None
assert "Mocked Critical Error" in str(exc_info.value)
19 changes: 17 additions & 2 deletions vllm/engine/llm_engine.py
@@ -28,6 +28,7 @@
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.error_report import dump_engine_exception
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
@@ -1388,8 +1389,22 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]

outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
try:
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)

except BaseException as err:
stats = self._get_stats(scheduler_outputs=scheduler_outputs)
dump_engine_exception(
err=err,
config=self.vllm_config,
use_cached_outputs=self.use_cached_outputs,
engine_version=0,
stats=stats,
execute_model_req=execute_model_req,
)
# Re-raise exception
raise err

# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
159 changes: 159 additions & 0 deletions vllm/error_report.py
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0

import enum
import json
from typing import Any, Dict, Union

import torch

from vllm.config import VllmConfig
from vllm.engine.metrics import Stats
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SequenceData
from vllm.v1.core.scheduler_output import NewRequestData
from vllm.version import __version__ as VLLM_VERSION
from vllm.worker.worker_base import ModelExecutionError

logger = init_logger(__name__)


# Hacky way to make sure we can serialize the object in JSON format
def is_json_serializable(x):
try:
json.dumps(x)
return True
except (TypeError, OverflowError):
return False


def prepare_object_to_dump(obj):
Collaborator
This function is basically the root cause of the previous input dump being removed. It has many cases to handle and will be affected whenever any of them changes. Primitive types and torch.Tensor are fine, but I'm a bit worried about SequenceData and NewRequestData.

Contributor Author
Yes, I share the same concern.

The custom handlers for these classes exist to obfuscate the prompts, but I cannot guarantee we will always have the right implementation as those classes change in the future. We could add more hardcoded logs, comments, asserts, and tests to warn other developers about this feature, at the cost of increasing its maintenance burden. I am not sure about this, and I would like to hear more feedback or ideas from your side.

Collaborator
That seems to create a burden for developers, and that burden is the reason the previous input dump was removed. Ideally we would have an approach that recursively traverses an input object and serializes it with tensor values ignored. Another direction is to provide these methods on the custom data structures themselves (e.g., SequenceData.dump()) so they live in one place and are easier to maintain.

Also cc @youkaichao

Contributor Author
I added a method called anon_repr to those classes, which is similar to the __repr__ implementation. The two sit close together, and I added a comment there to make other contributors aware of it. prepare_object_to_dump has indirect awareness of this method: during serialization it checks whether the object provides it and, if so, uses it.

Contributor Author
BTW: I changed the dump format to read more like the string representations __repr__ produces, instead of JSON. I think this is more standardized and consistent with what we already use via __repr__.
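As a rough illustration of the hook described in this thread (the method name anon_repr and the repr-style output come from the discussion, not from the revision shown in this diff, so treat the sketch as an assumption):

# Sketch only, not this revision's code: assumes an `anon_repr()` hook on
# classes whose prompts must be obfuscated, checked before any other handling.
def prepare_object_to_dump(obj) -> str:
    if hasattr(obj, "anon_repr"):
        # Classes such as SequenceData would implement anon_repr() next to
        # __repr__, reporting token-id lengths rather than the token ids
        # themselves, so prompts never reach the logs.
        return obj.anon_repr()
    if isinstance(obj, dict):
        items = ", ".join(f"{k}={prepare_object_to_dump(v)}"
                          for k, v in obj.items())
        return "{" + items + "}"
    if isinstance(obj, (list, tuple, set)):
        return "[" + ", ".join(prepare_object_to_dump(v) for v in obj) + "]"
    return repr(obj)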

if isinstance(obj, dict):
return {k: prepare_object_to_dump(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [prepare_object_to_dump(v) for v in obj]
elif isinstance(obj, set):
return [prepare_object_to_dump(v) for v in list(obj)]
elif isinstance(obj, tuple):
return [prepare_object_to_dump(v) for v in obj]
elif isinstance(obj, enum.Enum):
return repr(obj)
elif isinstance(obj, SequenceData):
# Custom representation (based on SequenceData.__repr__)
# to obfuscate some parameters
return {
"class": "SequenceData",
"prompt_token_ids_len": len(obj._prompt_token_ids),
"output_token_ids_len": len(obj.output_token_ids),
"cumulative_logprob": obj.cumulative_logprob,
"get_num_computed_tokens": obj.get_num_computed_tokens()
}

elif isinstance(obj, NewRequestData):
obj_dict: Dict[str, Any] = {'class': type(obj).__name__}
for k, v in obj.__dict__.items():
if k == 'prompt_token_ids':
obj_dict['prompt_token_ids_len'] = len(v)
elif k == 'prompt':
obj_dict['prompt'] = ""
else:
obj_dict[k] = prepare_object_to_dump(v)

return obj_dict
elif isinstance(obj, torch.Tensor):
# We only print the 'draft' of the tensor to not expose sensitive data
# and to get some metadata in case of CUDA illegal memory access
return (f"Tensor(shape={obj.shape}, "
f"device={obj.device},"
f"dtype={obj.dtype})")
elif hasattr(obj, '__dict__'):
obj_dict = {'class': type(obj).__name__}
obj_dict.update(obj.__dict__)
return prepare_object_to_dump(obj_dict)
else:
# Try to make sure we can serialize the object
# to avoid exception
if is_json_serializable(obj):
return obj
else:
return repr(obj)


def dump_engine_exception(err: BaseException,
config: VllmConfig,
engine_version: int,
stats: Union[Stats, None] = None,
use_cached_outputs: Union[bool, None] = None,
execute_model_req: Union[ExecuteModelRequest,
None] = None):

assert engine_version == 0 or engine_version == 1
Collaborator
I feel we don't need to support v0. Reasons:

  1. The code could be much cleaner.
  2. v0 is going to be deprecated soon.

Contributor Author
I see, but I suspect there are still a lot of deployments running right now that are based on V0 (at least on our side). That's why we are interested in supporting both engines.

Do you think you can reconsider?

Collaborator
Since we're planning to freeze v0, I still don't feel we should support it. However, if you really need it, I'd suggest separating the v0/v1 logic completely into different functions (e.g. xxx_v0), so that when we want to deprecate v0 in the future we can easily locate the logic and remove it.

Contributor Author
I split the functions to make each version's path easy to identify.
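A minimal sketch of what that split might look like (the per-version function names are hypothetical, and the snippet reuses the logger and VLLM_VERSION already imported in this module; the revision shown here still keeps a single dump_engine_exception with an engine_version argument):

# Sketch only: one entry point per engine version so the V0 path can be
# deleted wholesale once V0 is deprecated.
def dump_engine_exception_v0(err, config, stats=None,
                             use_cached_outputs=None,
                             execute_model_req=None):
    logger.error("V0 LLM engine (v%s) with config: %s, use_cached_outputs=%s",
                 VLLM_VERSION, config, use_cached_outputs)
    # ModelExecutionError / seq_group_metadata_list reporting would live
    # only here.


def dump_engine_exception_v1(err, config):
    logger.error("V1 LLM engine (v%s) with config: %s", VLLM_VERSION, config)
    # ModelExecutionV1Error / scheduler_output reporting would live only here.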


logger.error("Engine crashed, dumping input data")

if engine_version == 1:
logger.error(
"V1 LLM engine (v%s) with config: %s, ",
VLLM_VERSION,
config,
)
else:
logger.error(
"V0 LLM engine (v%s) with config: %s, "
"use_cached_outputs=%s, ",
VLLM_VERSION,
config,
use_cached_outputs,
)

# For V0
if isinstance(err, ModelExecutionError):
try:
err_json = prepare_object_to_dump(err.model_input)
logger.error("Model input for execution as JSON:")
logger.error(json.dumps(err_json))
except BaseException as exception:
logger.error("Error preparing object to dump")
logger.error(repr(exception))

# In case we do not have a ModelExecutionError we still can
# get information from the batch
if execute_model_req is not None:
batch = execute_model_req.seq_group_metadata_list
requests_count = len(batch)

requests_prompt_token_ids_lengths = [{
k: len(v.prompt_token_ids)
for (k, v) in r.seq_data.items()
} for r in batch]

requests_ids = ', '.join([str(r.request_id) for r in batch])
logger.error(
"Batch info: requests_count=%s, "
"requests_prompt_token_ids_lengths=(%s), "
"requests_ids=(%s)", requests_count,
requests_prompt_token_ids_lengths, requests_ids)

for idx, r in enumerate(batch):
logger.error(
"Errored Batch request #%s: request_id=%s "
"prompt_token_ids_lengths=%s, "
"params=%s, "
"lora_request=%s, prompt_adapter_request=%s ", idx,
r.request_id, str(len(r.seq_data[idx].prompt_token_ids)),
r.sampling_params, r.lora_request, r.prompt_adapter_request)

# TODO: Have stats for V1
if stats is not None:
logger.error("System stats:")
logger.error(stats)

if engine_version == 1:
from vllm.v1.engine.core import ModelExecutionV1Error
if isinstance(err, ModelExecutionV1Error):
try:
err_json = prepare_object_to_dump(err.scheduler_output)
logger.error("Scheduler output for model execution as JSON:")
logger.error(json.dumps(err_json))
except BaseException as exception:
logger.error("Error preparing object to dump")
logger.error(repr(exception))
24 changes: 23 additions & 1 deletion vllm/v1/engine/core.py
@@ -15,6 +15,7 @@
import zmq.asyncio

from vllm.config import VllmConfig
from vllm.error_report import dump_engine_exception
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
@@ -36,6 +37,18 @@
POLLING_TIMEOUT_S = 2.5


class ModelExecutionV1Error(RuntimeError):
scheduler_output: SchedulerOutput

def __init__(self, *args, scheduler_output=None):
super().__init__(*args)
self.scheduler_output = scheduler_output

def __reduce__(self):
# Reconstruct with only the message when pickled;
# scheduler_output may not be picklable across processes
return (self.__class__, (self.args[0], ))


class EngineCore:
"""Inner loop of vLLM's Engine."""

@@ -47,6 +60,7 @@ def __init__(
):
assert vllm_config.model_config.runner_type != "pooling"

self.config = vllm_config
logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)

@@ -151,7 +165,15 @@ def step(self) -> EngineCoreOutputs:
outputs=[], scheduler_stats=self.scheduler.make_stats())

scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
try:
output = self.model_executor.execute_model(scheduler_output)
except BaseException as err:
err = ModelExecutionV1Error(
f"Model execution failure,"
f"reason: {repr(err)}",
scheduler_output=scheduler_output)
dump_engine_exception(err, self.config, 1)
raise err
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output) # type: ignore
return engine_core_outputs
55 changes: 39 additions & 16 deletions vllm/worker/worker_base.py
@@ -208,6 +208,18 @@ def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA")


class ModelExecutionError(RuntimeError):
model_input: BroadcastableModelInput

def __init__(self, *args, model_input=None):
super().__init__(*args)
self.model_input = model_input

def __reduce__(self):
# Reconstruct with only the message when pickled;
# model_input may not be picklable across processes
return (self.__class__, (self.args[0], ))


@dataclasses.dataclass(frozen=True)
class WorkerInput:
"""Local inputs to each worker. May contain device-specific data. These
@@ -416,15 +428,20 @@ def execute_model(
and self.observability_config.collect_model_execute_time):
orig_model_execute_time = intermediate_tensors.tensors.get(
"model_execute_time", torch.tensor(0)).item()

output = self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
**kwargs,
)
try:
output = self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
**kwargs,
)
except BaseException as err:
raise ModelExecutionError(
f"Model execution failure,"
f"reason: {repr(err)}",
model_input=model_input) from err

model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
@@ -474,13 +491,19 @@ def _execute_model_spmd(

kwargs = extract_previous_hidden_states(execute_model_req)

return self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
**kwargs,
)
try:
return self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
**kwargs,
)
except BaseException as err:
raise ModelExecutionError(
f"Model execution failure,"
f"reason: {repr(err)}",
model_input=model_input) from err


class WorkerWrapperBase:
Expand Down