From 299faab763369b1b9a4b5715c77b1daff2dfde96 Mon Sep 17 00:00:00 2001
From: Nadav Timor
Date: Thu, 5 Dec 2024 20:03:01 -0500
Subject: [PATCH] [WIP] drafting a fix - cropping the kv cache

---
 src/transformers/generation/candidate_generator.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 25148b3a6b83..f1ab4b5992ef 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -717,13 +717,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         """
         input_ids = input_ids.to(self.assistant_model.device)
         target_input_ids = input_ids.clone()
-        assistant_input_ids = self._prepare_assistant_input_ids(target_input_ids)
+        assistant_input_ids, remove_from_kv = self._prepare_assistant_input_ids(
+            target_input_ids
+        )
 
         min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids)
         if max_new_tokens == 0:
             return input_ids, None
 
-        self._update_past_and_masks(assistant_input_ids)
+        self._update_past_and_masks(assistant_input_ids, remove_from_kv)
         generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
         self.assistant_kwargs.pop("attention_mask", None)
 
@@ -745,12 +747,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
 
         return target_ids, target_logits
 
-    def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor:
+    def _prepare_assistant_input_ids(
+        self, target_input_ids: torch.LongTensor
+    ) -> Tuple[torch.LongTensor, int]:
         """
         Simplified token conversion that only processes new tokens.
         """
         # Calculate new tokens since last call
         target_seq_len = target_input_ids.shape[-1]
+        remove_from_pkv = target_seq_len - 1 - self._prev_target_seq_len
         new_token_count = target_seq_len - self._prev_target_seq_len
         target_new_ids = target_input_ids[:, -new_token_count:]
         self._prev_target_seq_len = target_seq_len
@@ -769,7 +774,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
         else:
             self._prev_assistant_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
 
-        return self._prev_assistant_ids
+        return self._prev_assistant_ids, remove_from_pkv
 
 
 class PromptLookupCandidateGenerator(CandidateGenerator):
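
Notes (not part of the patch): a minimal sketch of the cropping idea this WIP
is working toward, assuming the assistant model's past key values are held in a
transformers DynamicCache. `remove_from_pkv` follows the patch; the helper name
`crop_assistant_cache` is hypothetical, while `DynamicCache.crop` and
`DynamicCache.get_seq_length` are existing methods of that class.

    from transformers import DynamicCache

    def crop_assistant_cache(past_key_values: DynamicCache, remove_from_pkv: int) -> None:
        # When the target model rejects drafted tokens, the assistant's cached
        # keys/values extend past the accepted prefix. Dropping the last
        # `remove_from_pkv` positions realigns the cache with the input ids the
        # assistant will be fed on the next get_candidates call.
        if remove_from_pkv > 0:
            past_key_values.crop(past_key_values.get_seq_length() - remove_from_pkv)

Computing the positive target length explicitly keeps the intent obvious, though
DynamicCache.crop also accepts a negative argument to drop that many trailing
positions directly.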