-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
[Bugfix] Do not crash V0 engine on input errors #13101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
df0fb3f
36ac644
54cd257
aeaf2eb
234a826
49ad2b1
e90af36
5ba9c4c
f7ed8c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -75,6 +75,24 @@ | |
| torch._dynamo.config.accumulated_cache_size_limit = 128 | ||
|
|
||
|
|
||
class InputProcessingError(Exception):
    """Signals that input preparation failed for one sequence group.

    Raising this instead of the underlying error lets the engine reject
    only the offending request rather than crashing the whole batch.
    """

    def __init__(self, request_id, message):
        """Store the offending group's request id alongside the message."""
        # Kept as attributes so the engine can route the failure back to
        # the specific request.
        self.request_id = request_id
        self.message = message
        super().__init__(self.message)

    def __str__(self):
        return (f"Failed to prepare inputs for sequence group with "
                f"request id: {self.request_id}, Error: {self.message}")
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class ModelInputForGPU(ModelRunnerInputBase): | ||
| """ | ||
|
|
@@ -731,36 +749,42 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, | |
|
|
||
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
    """Add a sequence group to the builder.

    Builds (or reinitializes) the cached inter-data entry for this group,
    appends it to ``self.inter_data_list``, then runs every registered
    per-sequence and per-group compute function over it.
    """
    seq_ids = seq_group_metadata.seq_data.keys()
    n_seqs = len(seq_ids)
    is_prompt = seq_group_metadata.is_prompt

    if is_prompt:
        # A prompt (prefill) group carries exactly one sequence; seeing
        # one means this batch is not decode-only.
        assert n_seqs == 1
        self.decode_only = False

    encoder_seq_len = 0

    # Encoder-decoder models additionally need the encoder's length.
    if self.runner.model_config.is_encoder_decoder:
        encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()

    # reinit=True / reinit_use_defaults=True: presumably reuses a cached
    # entry for this request with fields reset to defaults rather than
    # allocating a fresh object — confirm against init_cached_inter_data.
    inter_data = self.init_cached_inter_data(
        request_id=seq_group_metadata.request_id,
        seq_ids=seq_ids,
        is_prompt=is_prompt,
        block_tables=seq_group_metadata.block_tables,
        computed_block_nums=seq_group_metadata.computed_block_nums,
        reinit=True,
        reinit_use_defaults=True,
        encoder_seq_len=encoder_seq_len)

    self.inter_data_list.append(inter_data)

    # Per-sequence hooks run for every sequence in the group; per-group
    # hooks run once afterwards.
    for seq_idx in range(n_seqs):
        for per_seq_fn in self.per_seq_compute_fns:
            per_seq_fn(inter_data, seq_idx, seq_group_metadata)
    for per_seq_group_fn in self.per_seq_group_compute_fns:
        per_seq_group_fn(inter_data, seq_group_metadata)
|
|
||
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
    """Add a sequence group to the builder.

    Any exception raised while preparing this group's inputs is re-raised
    as an InputProcessingError carrying the request id, so the engine can
    drop just the offending request instead of failing the entire batch.
    """
    try:
        sequence_ids = seq_group_metadata.seq_data.keys()
        num_sequences = len(sequence_ids)
        prompt_stage = seq_group_metadata.is_prompt

        if prompt_stage:
            # Prefill groups carry exactly one sequence; the batch is no
            # longer decode-only once a prompt group is added.
            assert num_sequences == 1
            self.decode_only = False

        enc_len = 0
        if self.runner.model_config.is_encoder_decoder:
            enc_len = seq_group_metadata.encoder_seq_data.get_len()

        entry = self.init_cached_inter_data(
            request_id=seq_group_metadata.request_id,
            seq_ids=sequence_ids,
            is_prompt=prompt_stage,
            block_tables=seq_group_metadata.block_tables,
            computed_block_nums=seq_group_metadata.computed_block_nums,
            reinit=True,
            reinit_use_defaults=True,
            encoder_seq_len=enc_len)
        self.inter_data_list.append(entry)

        # Per-sequence hooks for each sequence, then per-group hooks once.
        for idx in range(num_sequences):
            for compute_fn in self.per_seq_compute_fns:
                compute_fn(entry, idx, seq_group_metadata)
        for group_fn in self.per_seq_group_compute_fns:
            group_fn(entry, seq_group_metadata)
    except Exception as exc:
        # Tag the failure with the offending request's id so the engine
        # can surface it without crashing the whole batch.
        raise InputProcessingError(seq_group_metadata.request_id,
                                   str(exc)) from exc
|
||
|
|
||
| def _use_captured_graph(self, | ||
| batch_size: int, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.