From 5ee00e2b2d78c29e6a6a0e9671f9228ea0cfe7c6 Mon Sep 17 00:00:00 2001
From: Nadav Timor
Date: Wed, 4 Dec 2024 20:56:51 -0500
Subject: [PATCH] introduce `self.prev_target_ids_len`

---
 src/transformers/generation/candidate_generator.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 75ddc225c180..c690e0e3badd 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -319,7 +319,7 @@ def __init__(
 
         self.target_tokenizer = target_tokenizer
         self.assistant_tokenizer = assistant_tokenizer
-        self.prev_target_ids = None
+        self.prev_target_ids_len: Optional[int] = None
         self.prev_assistant_ids = None
         self.target_lookbehind = assistant_model.generation_config.target_lookbehind
         self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
@@ -465,11 +465,11 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)
 
         # Update state
-        self.prev_target_ids = input_ids
+        self.prev_target_ids_len = input_ids.shape[1]
         self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
         self.prev_assistant_ids = assistant_output.sequences
 
-        if input_ids.shape[1] >= new_target_ids.shape[1]:
+        if self.prev_target_ids_len >= new_target_ids.shape[1]:
             return input_ids, None
 
         return new_target_ids, None
@@ -482,9 +482,9 @@ def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[tor
         }
         remove_from_pkv = 0
 
-        if self.prev_assistant_ids is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
+        if self.prev_assistant_ids is not None and self.prev_target_ids_len > self.target_lookbehind:
             # input_ids contains all target prompt input ids and some new target input ids
-            start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind
+            start_index_in_target_window = self.prev_target_ids_len - self.target_lookbehind
 
             new_assistant_ids = self.convert_source_tokens_to_target_tokens(
                 input_ids[:, start_index_in_target_window:], **convert_kwargs
@@ -516,7 +516,7 @@ def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[tor
             assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
         else:
             assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
-            self.prev_target_ids = input_ids
+            self.prev_target_ids_len = input_ids.shape[1]
 
         return assistant_input_ids, remove_from_pkv
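
Note: the patch only ever used `prev_target_ids` through its sequence length (`shape[1]`), so caching the integer length is sufficient and avoids holding a reference to the full target tensor between steps. The snippet below is a minimal standalone sketch of that idea, not the Transformers code; the lookbehind value and tensor shapes are illustrative.

    # Minimal sketch: only the *length* of the previous target ids is needed,
    # so an int can replace the cached tensor. Values here are made up.
    import torch

    target_lookbehind = 10

    # Before the patch: the whole tensor is kept alive on the instance.
    prev_target_ids = torch.randint(0, 32000, (1, 2048))
    start_old = prev_target_ids.shape[1] - target_lookbehind  # only the length is used

    # After the patch: cache just the length as a plain int.
    prev_target_ids_len = prev_target_ids.shape[1]
    start_new = prev_target_ids_len - target_lookbehind

    assert start_old == start_new  # the lookbehind window start is identical

Since the candidate generator persists this state across generation steps, storing an int also lets the previous `input_ids` tensor be freed as soon as the caller no longer references it.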