Conversation

@wallashss
Contributor

@wallashss wallashss commented Feb 17, 2025

This PR adds a feature to dump input metadata when the vLLM engine crashes. In essence, this change is the spiritual successor to #8305, which was recently removed in #12582. However, I tried to solve it differently, since this feature can give us more hints to help debug crashes in production environments. So, I would like to propose it again to the community and give it a second chance.

Summary:

  • The dump is only logged (instead of pickled, as in [MISC] Dump model runner inputs when crashing #8305)
  • Implemented for both the V0 and V1 engines
  • Only tensor metadata is dumped, so the dump still works on CUDA crashes and tensor contents stay obfuscated to avoid leaking sensitive data (a small sketch follows this list)
  • Introduced custom exceptions that may be useful for other kinds of custom error handling
  • Some fields, such as the prompt and the prompt token IDs, are removed to avoid logging sensitive data in production environments
  • Dumps system stats to capture the system's status at the moment of the crash (TODO for V1)
  • Prints the engine config again, so the setup can still be recovered from truncated logs
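
To make the tensor-metadata point concrete, here is a minimal sketch of formatting only a tensor's shape, device, and dtype. This is not the PR's actual helper; the summarize_tensor name is hypothetical.

import torch

def summarize_tensor(t: torch.Tensor) -> str:
    # Format metadata only; the tensor's contents are never read, so this
    # still works after a CUDA error and cannot leak prompt data into logs.
    return f"Tensor(shape={tuple(t.shape)}, device={t.device}, dtype={t.dtype})"

# e.g. summarize_tensor(torch.zeros(26, dtype=torch.int64))
# -> "Tensor(shape=(26,), device=cpu, dtype=torch.int64)"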
V0 dump sample
ERROR 03-07 12:56:03 [dump_input.py:101] Dumping model input for execution:
ERROR 03-07 12:56:03 [dump_input.py:102] ModelInputForGPUWithSamplingMetadata(input_tokens=Tensor(shape=torch.Size([26]), device=cuda:0,dtype=torch.int64),input_positions=Tensor(shape=torch.Size([26]), device=cuda:0,dtype=torch.int64),token_types=null,seq_lens=[6, 8, 6, 6],query_lens=[6, 8, 6, 6],lora_mapping=null,lora_requests=[],attn_metadata=FlashAttentionMetadata(num_prefills=4,num_prefill_tokens=26,num_decode_tokens=0,slot_mapping=Tensor(shape=torch.Size([26]), device=cuda:0,dtype=torch.int64),multi_modal_placeholder_index_maps={},enable_kv_scales_calculation=true,seq_lens=[6, 8, 6, 6],seq_lens_tensor=Tensor(shape=torch.Size([4]), device=cuda:0,dtype=torch.int32),max_prefill_seq_len=8,max_decode_seq_len=0,context_lens_tensor=Tensor(shape=torch.Size([4]), device=cuda:0,dtype=torch.int32),block_tables=Tensor(shape=torch.Size([4, 0]), device=cuda:0,dtype=torch.int32),use_cuda_graph=false,max_query_len=8,max_decode_query_len=1,query_start_loc=Tensor(shape=torch.Size([5]), device=cuda:0,dtype=torch.int32),seq_start_loc=Tensor(shape=torch.Size([5]), device=cuda:0,dtype=torch.int32),_cached_prefill_metadata=null,_cached_decode_metadata=null,encoder_seq_lens=null,encoder_seq_lens_tensor=null,encoder_seq_start_loc=null,max_encoder_seq_len=null,num_encoder_tokens=null,cross_slot_mapping=null,cross_block_tables=null),prompt_adapter_mapping=null,prompt_adapter_requests=[],multi_modal_kwargs={},request_ids_to_seq_ids={2: [2], 1: [1], 3: [3], 0: [0]},finished_requests_ids=[],virtual_engine=0,async_callback=null,scheduler_outputs=null,previous_hidden_states=null,sampling_metadata=SamplingMetadata(seq_groups=[SequenceGroupToSample(seq_ids=[0],sampling_params=SamplingParams(sampling_type=<SamplingType.GREEDY: 0>),seq_data={0: SequenceData(prompt_token_ids_len=6, output_token_ids_len=0, cumulative_logprob=0.0, get_num_computed_tokens=0)},seq_len=6,query_len=6,generator=null,is_prompt=true,prompt_logprob_indices=[],sample_indices=[0]), SequenceGroupToSample(seq_ids=[1],sampling_params=SamplingParams(sampling_type=<SamplingType.GREEDY: 0>),seq_data={1: SequenceData(prompt_token_ids_len=8, output_token_ids_len=0, cumulative_logprob=0.0, get_num_computed_tokens=0)},seq_len=8,query_len=8,generator=null,is_prompt=true,prompt_logprob_indices=[],sample_indices=[1]), SequenceGroupToSample(seq_ids=[2],sampling_params=SamplingParams(sampling_type=<SamplingType.GREEDY: 0>),seq_data={2: SequenceData(prompt_token_ids_len=6, output_token_ids_len=0, cumulative_logprob=0.0, get_num_computed_tokens=0)},seq_len=6,query_len=6,generator=null,is_prompt=true,prompt_logprob_indices=[],sample_indices=[2]), SequenceGroupToSample(seq_ids=[3],sampling_params=SamplingParams(sampling_type=<SamplingType.GREEDY: 0>),seq_data={3: SequenceData(prompt_token_ids_len=6, output_token_ids_len=0, cumulative_logprob=0.0, get_num_computed_tokens=0)},seq_len=6,query_len=6,generator=null,is_prompt=true,prompt_logprob_indices=[],sample_indices=[3])],selected_token_indices=Tensor(shape=torch.Size([4]), device=cuda:0,dtype=torch.int64),categorized_sample_indices={0: Tensor(shape=torch.Size([4]), device=cuda:0,dtype=torch.int32), 1: Tensor(shape=torch.Size([0]), device=cuda:0,dtype=torch.int32), 2: Tensor(shape=torch.Size([0]), device=cuda:0,dtype=torch.int32)},num_prompts=4,skip_sampler_cpu_output=false,reuse_sampling_tensors=false),is_prompt=true)
ERROR 03-07 12:56:03 [dump_input.py:120] Batch info: requests_count=4, requests_prompt_token_ids_lenghts=([{0: 6}, {1: 8}, {2: 6}, {3: 6}]), requests_ids=(0, 1, 2, 3)
ERROR 03-07 12:56:03 [dump_input.py:127] Errored Batch request #0: request_id=0 prompt_token_ids_lengths=6, params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None), lora_request=None, prompt_adapter_request=None 
ERROR 03-07 12:56:03 [dump_input.py:127] Errored Batch request #1: request_id=1 prompt_token_ids_lengths=8, params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None), lora_request=None, prompt_adapter_request=None 
ERROR 03-07 12:56:03 [dump_input.py:127] Errored Batch request #2: request_id=2 prompt_token_ids_lengths=6, params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None), lora_request=None, prompt_adapter_request=None 
ERROR 03-07 12:56:03 [dump_input.py:127] Errored Batch request #3: request_id=3 prompt_token_ids_lengths=6, params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None), lora_request=None, prompt_adapter_request=None 
ERROR 03-07 12:56:03 [dump_input.py:136] System stats:
ERROR 03-07 12:56:03 [dump_input.py:137] Stats(now=1741352163.1743438, num_running_sys=4, num_waiting_sys=0, num_swapped_sys=0, gpu_cache_usage_sys=3.115750116844396e-05, cpu_cache_usage_sys=0.0, cpu_prefix_cache_hit_rate=-1, gpu_prefix_cache_hit_rate=-1, num_prompt_tokens_iter=26, num_generation_tokens_iter=0, num_tokens_iter=26, time_to_first_tokens_iter=[], time_per_output_tokens_iter=[], num_preemption_iter=0, time_e2e_requests=[], time_queue_requests=[], time_inference_requests=[], time_prefill_requests=[], time_decode_requests=[], time_in_queue_requests=[], model_forward_time_requests=[], model_execute_time_requests=[], num_prompt_tokens_requests=[], num_generation_tokens_requests=[], n_requests=[], max_num_generation_tokens_requests=[], max_tokens_requests=[], finished_reason_requests=[], waiting_lora_adapters=[], running_lora_adapters=[], max_lora='0', spec_decode_metrics=None)
V1 dump sample
ERROR 03-07 12:55:56 [dump_input.py:58] Dumping input data
ERROR 03-07 12:55:56 [dump_input.py:60] V1 LLM engine (v0.6.6.dev967+g07c435305) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=facebook/opt-125m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}, 
ERROR 03-07 12:55:56 [dump_input.py:72] Dumping scheduler output for model execution:
ERROR 03-07 12:55:56 [dump_input.py:73] SchedulerOutput(scheduled_new_reqs=[NewRequestData(req_id=0,prompt_token_ids_len=6,prompt='',mm_inputs=[],mm_hashes=[],mm_positions=[],sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None),block_ids=[0, 1, 2, 3, 4],num_computed_tokens=0,lora_request=None), NewRequestData(req_id=1,prompt_token_ids_len=8,prompt='',mm_inputs=[],mm_hashes=[],mm_positions=[],sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None),block_ids=[5, 6, 7, 8, 9],num_computed_tokens=0,lora_request=None), NewRequestData(req_id=2,prompt_token_ids_len=6,prompt='',mm_inputs=[],mm_hashes=[],mm_positions=[],sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None),block_ids=[10, 11, 12, 13, 14],num_computed_tokens=0,lora_request=None), NewRequestData(req_id=3,prompt_token_ids_len=6,prompt='',mm_inputs=[],mm_hashes=[],mm_positions=[],sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=200, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None),block_ids=[15, 16, 17, 18, 19],num_computed_tokens=0,lora_request=None)],scheduled_cached_reqs=[],num_scheduled_tokens={0: 6, 1: 8, 3: 6, 2: 6},total_num_scheduled_tokens=26,scheduled_spec_decode_tokens={},scheduled_encoder_inputs={},num_common_prefix_blocks=0,finished_req_ids=[],free_encoder_input_ids=[])

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Feb 17, 2025
Signed-off-by: Wallas Santos <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
@joerunde joerunde added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 18, 2025
@tjohnson31415
Contributor

@wallashss Thanks for writing up this PR. I think it will be useful to have details for debugging printed to the logs when a crash happens!

When I tried out these changes in my dev environment, running in online mode with

vllm serve meta-llama/Llama-3.2-11B-Vision-Instruct --max-num-seqs 4 --enforce-eager  --max-model-len 8192

and sent a request with a large prompt that requests prompt_logprobs to trigger an OOM:

curl http://localhost:8000/v1/completions     -H "Content-Type: application/json"     -d '{
        "model": "model",
        "prompt": "'"$(seq -s ' ' 1 1500)"'",
        "max_tokens": 100,
        "prompt_logprobs": 10
    }'

I see the ModelExecutionError raised, but then the server seems to hang, never dumping the debug info or exiting. The logs in this case look like:

ERROR 02-20 20:47:16 engine.py:139] vllm.worker.worker_base.ModelExecutionError: Model execution failure,reason: OutOfMemoryError('CUDA out of memory. Tried to allocate 3.35 GiB. GPU 0 has a total capacity of 79.14 GiB of which 1.11 GiB is free. Process 3865663 has 78.01 GiB memory in use. Of the allocated memory 77.27 GiB is allocated by PyTorch, and 246.93 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)')
[rank0]:[W220 20:47:17.410788625 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
   <<and after a few seconds of hanging>>
ERROR 02-20 20:47:25 client.py:300] RuntimeError('Engine process (pid 4107) died.')
ERROR 02-20 20:47:25 client.py:300] NoneType: None

The above seems to happen only when the first request to the server crashes it. If I send a shortened request first (e.g. a prompt from seq -s ' ' 1 100), then it does crash on the second request, but with an exception in the error reporter:

ERROR 02-20 21:05:09 engine.py:139] During handling of the above exception, another exception occurred:
ERROR 02-20 21:05:09 engine.py:139] 
ERROR 02-20 21:05:09 engine.py:139] Traceback (most recent call last):
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 137, in start
ERROR 02-20 21:05:09 engine.py:139]     self.run_engine_loop()
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 200, in run_engine_loop
ERROR 02-20 21:05:09 engine.py:139]     request_outputs = self.engine_step()
ERROR 02-20 21:05:09 engine.py:139]                       ^^^^^^^^^^^^^^^^^^
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 218, in engine_step
ERROR 02-20 21:05:09 engine.py:139]     raise e
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 209, in engine_step
ERROR 02-20 21:05:09 engine.py:139]     return self.engine.step()
ERROR 02-20 21:05:09 engine.py:139]            ^^^^^^^^^^^^^^^^^^
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/engine/llm_engine.py", line 1393, in step
ERROR 02-20 21:05:09 engine.py:139]     dump_engine_exception(
ERROR 02-20 21:05:09 engine.py:139]   File "/workspace/my-vllm/lib64/python3.12/site-packages/vllm/error_report.py", line 124, in dump_engine_exception
ERROR 02-20 21:05:09 engine.py:139]     str(len(r.seq_data[idx].prompt_token_ids))
ERROR 02-20 21:05:09 engine.py:139]             ~~~~~~~~~~^^^^^
ERROR 02-20 21:05:09 engine.py:139] KeyError: 0

@wallashss
Contributor Author

Thanks for trying it out, @tjohnson31415. I'll take a look at your examples and address these issues.

@wallashss
Contributor Author

wallashss commented Feb 25, 2025

Hey @tjohnson31415, I think I've addressed your issues:

I see the ModelExecutionError raised, but then the server seems to hang, never dumping the debug info or exiting... The logs in this case look like:

I reproduced your setup and saw the logs; the dump appears before the stack trace. The server was indeed hanging: the issue was the pickling of my custom error, which I've fixed.

The above seems to happen only when the first request to the server crashes it. If I send a shortened request first (e.g. a prompt from seq -s ' ' 1 100), then it does crash on the second request, but with an exception in the error reporter:

Fixed that too, thanks for reporting it.

@wallashss
Contributor Author

Hey @comaniac, do you think this feature, implemented this way, makes sense for vLLM?

I'd love to hear your feedback.

Thanks!

Collaborator

@comaniac comaniac left a comment


The idea looks good. Let's polish the code, and then I'm OK with giving it a green light. The most important point I'd like to highlight again is that I feel we only need to support this feature in V1, for simplicity.

    return False


def prepare_object_to_dump(obj):
Collaborator


This function is basically the root cause of removing the previous input dump. It has many cases to handle and will be affected if any of them changes. Specifically, primitive types and torch.Tensor are fine, but I'm a bit worried about SequenceData and NewRequestData.

Contributor Author


Yes, I share the same concern.

The custom handler for these classes exists to obfuscate the prompts. But I cannot guarantee that we will always have the right implementation for future changes. I guess we could add more hardcoded logs, comments, asserts, and tests to warn other developers about this feature, at the cost of increasing its maintenance burden. I am not sure about this, so I would like to hear more feedback or ideas from your side.

Collaborator


That seems to create a burden for developers, which is the reason the previous input dump was removed. Ideally we could have an approach that recursively traverses an input object and serializes it with tensor values ignored. Another direction is to provide these methods on the custom data structures themselves (e.g., SequenceData.dump()) so that they live in one place, which eases maintenance.

Also cc @youkaichao

Contributor Author


I added a method called anon_repr to those classes, which is similar to the __repr__ implementation. The two are kept close together, and I added a comment there to make other contributors aware of that. prepare_object_to_dump has only indirect awareness of this method: during serialization it checks whether the object provides it and uses it if so.
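
A rough sketch of that dispatch, assuming the shape of prepare_object_to_dump described above rather than the exact merged code:

import torch

def prepare_object_to_dump(obj) -> str:
    # Classes with sensitive fields expose anon_repr(), a __repr__-like
    # method that obfuscates prompts and token ids; prefer it when present.
    if hasattr(obj, "anon_repr"):
        return obj.anon_repr()
    if isinstance(obj, torch.Tensor):
        # Metadata only, never the values.
        return f"Tensor(shape={tuple(obj.shape)}, device={obj.device}, dtype={obj.dtype})"
    if isinstance(obj, (list, tuple)):
        return "[" + ", ".join(prepare_object_to_dump(o) for o in obj) + "]"
    if isinstance(obj, dict):
        return "{" + ", ".join(f"{k}: {prepare_object_to_dump(v)}" for k, v in obj.items()) + "}"
    return repr(obj)  # primitives and anything else fall back to repr()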

Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW: I changed the dump format to be more like how __repr__ outputs the string representation of objects, instead of JSON. I think this is more standardized and consistent with what we already use for __repr__.
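
For illustration, a repr-style formatter over dataclass fields might look like the sketch below; BatchInfo and repr_style are hypothetical names, not part of the PR.

from dataclasses import dataclass, fields

@dataclass
class BatchInfo:  # hypothetical example type
    requests_count: int
    total_tokens: int

def repr_style(obj) -> str:
    # Render ClassName(field=value,...) instead of a JSON document,
    # matching the __repr__ style already seen in the dumps above.
    body = ",".join(f"{f.name}={getattr(obj, f.name)!r}" for f in fields(obj))
    return f"{type(obj).__name__}({body})"

print(repr_style(BatchInfo(requests_count=4, total_tokens=26)))
# -> BatchInfo(requests_count=4,total_tokens=26)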

        execute_model_req: Union[ExecuteModelRequest, None] = None):

    assert engine_version == 0 or engine_version == 1
Collaborator


I feel we don't need to support v0. Reasons:

  1. The code could be much cleaner.
  2. v0 is going to be deprecated soon.

Contributor Author


I see, but I guess there are still a lot of deployments running right now that are based on V0 (at least on our side). That's why we are interested in supporting both engines.

Do you think you could reconsider?

Collaborator


Since we're planning to freeze V0, I still don't feel we should support it. However, if you really need it, I'd suggest we separate the V0/V1 logic completely into different functions (e.g. xxx_v0), so that in the future, when we want to deprecate V0, we can easily locate the logic and remove it.

Contributor Author


I split the functions into separate V0/V1 versions to make it easier to identify which engine each one supports.

@mergify

mergify bot commented Feb 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wallashss.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 27, 2025
Signed-off-by: Wallas Santos <[email protected]>
@mergify mergify bot removed the needs-rebase label Mar 5, 2025
@mergify

mergify bot commented Mar 28, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wallashss.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 28, 2025
@mergify mergify bot removed the needs-rebase label Apr 2, 2025
Signed-off-by: Wallas Santos <[email protected]>
added logs for scheduler stats
minor fixes

Signed-off-by: Wallas Santos <[email protected]>
@wallashss
Contributor Author

Finally all green!

This PR has been open for a while, and I am now convinced it's right to remove the V0 support. I also made some minor updates. If any of you could take a second look and merge, I'd appreciate it. Thanks!

@youkaichao @njhill @comaniac

@mergify mergify bot added the tpu Related to Google TPUs label Apr 9, 2025
Member

@njhill njhill left a comment


Thanks @wallashss, and sorry I have a couple more comments!

Comment on lines 53 to 79
class ModelExecutionError(RuntimeError):
    """Custom RuntimeError with input data for model execution

    In a nutshell, this object is useful for custom handling of exception for
    the case the engine raises an error. For instance, it is used to log the
    input metadata that is useful for debugging on engine crashes.

    Args:
        scheduler_output: SchedulerOutput object that contains the input
            data for model execution
    """
    scheduler_output: SchedulerOutput
    scheduler_stats: SchedulerStats

    def __init__(self, *args, scheduler_output=None, scheduler_stats=None):
        super().__init__(*args)
        self.scheduler_output = scheduler_output
        self.scheduler_stats = scheduler_stats

    def __reduce__(self):
        # To avoid pickle errors.
        # This happens when we exchange this object between processes
        # since scheduler_output can have objects that only makes sense
        # to their context/process we remove them from the serialization
        # and only send the summary of the error as a regular RuntimeError.
        return (self.__class__, (self.args[0], ))
Member


Move this class to dump_input.py to keep core.py cleaner?

Contributor Author

@wallashss wallashss Apr 10, 2025


After removing V0 and reviewing it, I think we can remove this class entirely and end up with cleaner code.

@mergify mergify bot removed the tpu Related to Google TPUs label Apr 9, 2025
Member

@njhill njhill left a comment


Thanks @wallashss, and sorry for forgetting to come back to this. It looks good to me; I just have one more minor comment.

@wallashss
Contributor Author

Thanks @njhill !

Could you point me to the comment? Is it a new one, or did I miss it?

Member

@njhill njhill left a comment


@wallashss apologies, I must have failed to press the button to include the comment in the review.

Member

@njhill njhill left a comment


@wallashss sorry for the extra nits; I just want to keep the core engine loop as clean as possible. I will merge as soon as these are in!

Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
@wallashss wallashss force-pushed the dump-input-on-crash branch from 104a8b8 to 51596e4 Compare May 7, 2025 19:14
Member

@njhill njhill left a comment


Thanks @wallashss

@njhill njhill enabled auto-merge (squash) May 7, 2025 20:01
@njhill njhill merged commit d43f914 into vllm-project:main May 7, 2025
51 checks passed
@wallashss
Contributor Author

Thanks a lot @njhill !

princepride pushed a commit to princepride/vllm that referenced this pull request May 10, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025