Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def __init__(

self.target_tokenizer = target_tokenizer
self.assistant_tokenizer = assistant_tokenizer
self.prev_target_ids = None
self.prev_target_ids_len: Optional[int] = None
self.prev_assistant_ids = None
self.target_lookbehind = assistant_model.generation_config.target_lookbehind
self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
Expand Down Expand Up @@ -465,11 +465,11 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)

# Update state
self.prev_target_ids = input_ids
self.prev_target_ids_len = input_ids.shape[1]
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
self.prev_assistant_ids = assistant_output.sequences

if input_ids.shape[1] >= new_target_ids.shape[1]:
if self.prev_target_ids_len >= new_target_ids.shape[1]:
return input_ids, None

return new_target_ids, None
Expand All @@ -482,9 +482,9 @@ def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[tor
}
remove_from_pkv = 0

if self.prev_assistant_ids is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
if self.prev_assistant_ids is not None and self.prev_target_ids_len > self.target_lookbehind:
# input_ids contains all target prompt input ids and some new target input ids
start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind
start_index_in_target_window = self.prev_target_ids_len - self.target_lookbehind

new_assistant_ids = self.convert_source_tokens_to_target_tokens(
input_ids[:, start_index_in_target_window:], **convert_kwargs
Expand Down Expand Up @@ -516,7 +516,7 @@ def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[tor
assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
else:
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
self.prev_target_ids = input_ids
self.prev_target_ids_len = input_ids.shape[1]

return assistant_input_ids, remove_from_pkv

Expand Down