From 074815c7e0e32e42dbb78528fa2fc59db3016b34 Mon Sep 17 00:00:00 2001 From: jmamou Date: Wed, 27 Nov 2024 04:28:24 -0800 Subject: [PATCH] fix negative max_new_tokens bug --- src/transformers/generation/candidate_generator.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 7527054bd93a..ccd4b5508410 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -709,8 +709,6 @@ def __init__( # Track sequence lengths and previous assistant IDs self._prev_target_seq_len: int = 0 self._prev_assistant_ids: Optional[torch.LongTensor] = None - # generation max length according to the assistant vocabulary - self.assistant_generation_max_length = -1 def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: """ @@ -719,14 +717,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, input_ids = input_ids.to(self.assistant_model.device) target_input_ids = input_ids.clone() assistant_input_ids = self._prepare_assistant_input_ids(target_input_ids) - if self.assistant_generation_max_length == -1: - self.assistant_generation_max_length = self.generation_config.max_length - input_ids.shape[1] + assistant_input_ids.shape[1] - - # Standard generation steps - target_generation_max_length = self.generation_config.max_length - self.generation_config.max_length = self.assistant_generation_max_length - min_new_tokens, max_new_tokens = self._calculate_new_tokens(assistant_input_ids) - self.generation_config.max_length = target_generation_max_length + min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids) if max_new_tokens == 0: return input_ids, None