diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md
index 621edeb20e8e..380b39fe62ac 100644
--- a/docs/source/en/generation_strategies.md
+++ b/docs/source/en/generation_strategies.md
@@ -416,16 +416,6 @@ Assisted decoding assumes the main and assistant models have the same tokenizer,
 Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
 To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
 
-#### Universal Assisted Decoding
-
-Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
-To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
-Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
-in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
-The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
-Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
-to ensure the new tokens include the correct prompt suffix.
-
 To enable assisted decoding, set the `assistant_model` argument with a model.
 
 ```python
@@ -445,7 +435,36 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```
 
-If the main and assistant models have different tokenizers, use Universal Assisted Decoding.
+When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
+just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
+
+```python
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
+>>> set_seed(42)  # For reproducibility
+
+>>> prompt = "Alice and Bob"
+>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
+>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
+
+>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+>>> inputs = tokenizer(prompt, return_tensors="pt")
+
+>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
+>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
+>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
+>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+['Alice and Bob, a couple of friends of mine, who are both in the same office as']
+```
+
+#### Universal Assisted Decoding
+
+Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
+To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
+Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
+in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
+The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
+Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
+to ensure the new tokens include the correct prompt suffix.
 
 ```python
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -465,30 +484,35 @@ If the main and assistant models have different tokenizers, use Universal Assist
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```
 
-When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
-just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
+#### Prompt Lookup
+
+Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
+to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
+
+#### Self-Speculative Decoding
+
+An LLM can be trained to also use its language modeling head with earlier hidden states as input, effectively
+skipping layers to yield a lower-quality output -- a technique called early exiting.
+We use the lower-quality early exit output as an assistant output, and apply self-speculation to fix the output using the remaining layers. The final generation of that self-speculative solution is the same (or has the same distribution) as the original model's generation.
+If the model you're using was trained to do early exit, you can pass
+`assistant_early_exit` (integer). In this case, the assistant model will be the same model but exiting early, hence the
+"self-speculative" name. Because the assistant model is a portion of the target model, caches and weights can be shared, which results in lower memory requirements. As in other assisted generation methods, the final generated result has the same quality as if no assistant had been used.
 
 ```python
->>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
->>> set_seed(42)  # For reproducibility
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
 
 >>> prompt = "Alice and Bob"
->>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
->>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
+>>> checkpoint = "facebook/layerskip-llama3.2-1B"
 
 >>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 >>> inputs = tokenizer(prompt, return_tensors="pt")
 
 >>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
->>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
->>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
+>>> outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
-['Alice and Bob, a couple of friends of mine, who are both in the same office as']
+['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```
 
-Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
-to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
-
 ### DoLa Decoding
 
 **D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the
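The re-encoding that the new Universal Assisted Decoding section above describes can be hard to picture from prose alone. The sketch below is an illustration only: it is not the implementation in `candidate_generator.py`, the checkpoints are placeholders picked because their tokenizers differ, and `difflib.SequenceMatcher` merely stands in for the longest common subsequence matching. It walks through the three steps: decode the assistant's drafted ids to text, re-encode that text with the main tokenizer, and align the result against the original prompt encoding so that only the genuinely new suffix tokens are kept.

```python
from difflib import SequenceMatcher

from transformers import AutoTokenizer

main_tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m-deduped")
assistant_tok = AutoTokenizer.from_pretrained("gpt2")

prompt = "Alice and Bob"
main_prompt_ids = main_tok(prompt).input_ids

# 1. The assistant drafts candidate tokens in its own vocabulary. Here the
#    "draft" is faked by encoding a plain continuation string.
assistant_ids = assistant_tok(prompt + " are sitting in a bar").input_ids

# 2. Decode the assistant ids back to text and re-encode with the main
#    tokenizer, which may split the shared prefix differently.
reencoded_ids = main_tok(assistant_tok.decode(assistant_ids)).input_ids

# 3. Align the original prompt encoding with the re-encoded sequence and keep
#    only the tokens that come after the matched prompt suffix.
matcher = SequenceMatcher(None, main_prompt_ids, reencoded_ids, autojunk=False)
match = matcher.find_longest_match(0, len(main_prompt_ids), 0, len(reencoded_ids))
new_candidate_ids = reencoded_ids[match.b + match.size :]
print(main_tok.decode(new_candidate_ids))  # the genuinely new tokens, e.g. " are sitting in a bar"
```

The class that handles this case in the library is `AssistedCandidateGeneratorDifferentTokenizers`, visible in the `generation/utils.py` import list further down in this diff.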
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 0f696cc3ac6a..aeb184f7400c 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -433,19 +433,22 @@ def update(
             self._seen_tokens += key_states.shape[-2]
 
         # Update the cache
-        if len(self.key_cache) <= layer_idx:
-            # There may be skipped layers, fill them with empty lists
-            for _ in range(len(self.key_cache), layer_idx):
-                self.key_cache.append([])
-                self.value_cache.append([])
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        elif len(self.key_cache[layer_idx]) == 0:  # fills previously skipped layers; checking for tensor causes errors
-            self.key_cache[layer_idx] = key_states
-            self.value_cache[layer_idx] = value_states
-        else:
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+        if key_states is not None:
+            if len(self.key_cache) <= layer_idx:
+                # There may be skipped layers, fill them with empty lists
+                for _ in range(len(self.key_cache), layer_idx):
+                    self.key_cache.append([])
+                    self.value_cache.append([])
+                self.key_cache.append(key_states)
+                self.value_cache.append(value_states)
+            elif (
+                len(self.key_cache[layer_idx]) == 0
+            ):  # fills previously skipped layers; checking for tensor causes errors
+                self.key_cache[layer_idx] = key_states
+                self.value_cache[layer_idx] = value_states
+            else:
+                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
 
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index b487fa3c7fe6..e2ed48433b16 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -49,6 +49,7 @@
     _import_structure["candidate_generator"] = [
         "AssistedCandidateGenerator",
         "CandidateGenerator",
+        "EarlyExitCandidateGenerator",
         "PromptLookupCandidateGenerator",
     ]
     _import_structure["logits_process"] = [
@@ -206,7 +207,12 @@
 else:
     from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
     from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
-    from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
+    from .candidate_generator import (
+        AssistedCandidateGenerator,
+        CandidateGenerator,
+        EarlyExitCandidateGenerator,
+        PromptLookupCandidateGenerator,
+    )
     from .logits_process import (
         AlternatingCodebooksLogitsProcessor,
         ClassifierFreeGuidanceLogitsProcessor,
diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 1e4d7a470245..d8344c25a652 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -670,6 +670,62 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
         return
 
 
+class EarlyExitCandidateGenerator(AssistedCandidateGenerator):
+    """
+    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
+    candidates through the use of **the model itself**, exiting early. Can only be used with models that support early
+    exit, e.g., `facebook/layerskip-llama3.2-1B`.
+
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+        assistant_model (`PreTrainedModel`):
+            The original model. This model must support early exit (i.e. is trained to compute logits in earlier
+            layers).
+        generation_config (`~generation.GenerationConfig`, *optional*):
+            The generation configuration to be used as base parametrization for the generation call.
+        logits_processor (`LogitsProcessorList`):
+            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+            used to modify the prediction scores of the language modeling head applied at each generation step.
+        model_kwargs (`Dict`):
+            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
+            model as well.
+        inputs_tensor (`torch.Tensor`, *optional*):
+            The model input tensor. In encoder-decoder models, this is the encoder input.
+    """
+
+    def __init__(
+        self,
+        input_ids: torch.LongTensor,
+        assistant_model: "PreTrainedModel",
+        generation_config: "GenerationConfig",
+        model_kwargs: Dict,
+        inputs_tensor: Optional[torch.Tensor] = None,
+        logits_processor: "LogitsProcessorList" = None,
+    ):
+        super().__init__(
+            input_ids=input_ids,
+            assistant_model=assistant_model,
+            generation_config=generation_config,
+            model_kwargs=model_kwargs,
+            inputs_tensor=inputs_tensor,
+            logits_processor=logits_processor,
+        )
+        # We have to move early exit out of the generation config, otherwise the assistant will also call `generate`
+        # with early exit
+        self.assistant_early_exit = self.generation_config.assistant_early_exit
+        self.generation_config.assistant_early_exit = None
+
+    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
+        # Temporarily sets the number of hidden layers to the early exit value
+        base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix)
+        original_num_hidden_layers = base_model.config.num_hidden_layers
+        base_model.config.num_hidden_layers = self.assistant_early_exit
+        candidate_ids, candidate_logits = super().get_candidates(input_ids)
+        base_model.config.num_hidden_layers = original_num_hidden_layers
+        return candidate_ids, candidate_logits
+
+
 def _crop_past_key_values(model, past_key_values, max_length):
     """Crops the past key values up to a certain maximum length."""
     new_past = []
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 9b543f6c3571..de62ee767aed 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -353,10 +353,13 @@ class GenerationConfig(PushToHubMixin):
             than this threshold, the assistant model stops the current token generation iteration, even if the number of
             _speculative tokens_ (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic
             speculation lookahead from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models .
-        prompt_lookup_num_tokens (`int`, *optional*, default to `None`):
+        prompt_lookup_num_tokens (`int`, *optional*):
             The number of tokens to be output as candidate tokens.
-        max_matching_ngram_size (`int`, *optional*, default to `None`):
+        max_matching_ngram_size (`int`, *optional*):
             The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
+        assistant_early_exit (`int`, *optional*):
+            If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
+            models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
 
         > Wild card
 
@@ -454,10 +457,9 @@ def __init__(self, **kwargs):
         self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
         self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
         self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)
-
-        # Prompt lookup decoding
         self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
         self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
+        self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)
 
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -534,7 +536,11 @@ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = Non
                 generation_mode = GenerationMode.BEAM_SEARCH
 
         # Assisted generation may extend some generation modes
-        if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
+        if (
+            assistant_model is not None
+            or self.prompt_lookup_num_tokens is not None
+            or self.assistant_early_exit is not None
+        ):
             if generation_mode in ("greedy_search", "sample"):
                 generation_mode = GenerationMode.ASSISTED_GENERATION
             else:
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 6e6d5b8bdce7..c37ec64c085e 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -54,6 +54,7 @@
     AssistedCandidateGenerator,
     AssistedCandidateGeneratorDifferentTokenizers,
     CandidateGenerator,
+    EarlyExitCandidateGenerator,
     PromptLookupCandidateGenerator,
     _crop_past_key_values,
     _prepare_attention_mask,
@@ -822,7 +823,16 @@ def _get_candidate_generator(
         """
         different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer))
 
-        if generation_config.prompt_lookup_num_tokens is not None:
+        if generation_config.assistant_early_exit is not None:
+            candidate_generator = EarlyExitCandidateGenerator(
+                input_ids=input_ids,
+                assistant_model=self,
+                generation_config=generation_config,
+                model_kwargs=model_kwargs,
+                inputs_tensor=inputs_tensor,
+                logits_processor=logits_processor,
+            )
+        elif generation_config.prompt_lookup_num_tokens is not None:
             candidate_generator = PromptLookupCandidateGenerator(
                 eos_token_id=generation_config._eos_token_tensor,
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index b215fb6561bf..52d9420cf861 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -890,7 +890,7 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index fa3fadc4349a..9bb4b63f3687 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -805,7 +805,7 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py
index 807f91ff9e6b..ad1348ae5e31 100644
--- a/src/transformers/models/gemma/modular_gemma.py
+++ b/src/transformers/models/gemma/modular_gemma.py
@@ -886,7 +886,7 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 626e5537fc06..cec05ab16ef5 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -820,7 +820,7 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index dacaca1c7ef4..ff2d42d671c3 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -653,7 +653,7 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index 248ec4021791..edec4e173e8f 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -787,7 +787,7 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 4d95f01849d6..1b045f11e733 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -930,7 +930,7 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index cbb8db0f59dd..765092f32c4e 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -995,7 +995,7 @@ def forward(
         all_router_logits = () if output_router_logits else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index cbe851e97e9a..6630fc2ba9d1 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -4108,6 +4108,28 @@ def test_generate_compile_fullgraph_tiny(self):
         gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
         self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1])  # some text was generated
 
+    def test_assisted_generation_early_exit(self):
+        """
+        Tests that assisted generation with early exit works as expected. Under the hood, this has complex cache
+        manipulation, which will cause the test to fail if something goes wrong there.
+        """
+        expected_output = "Alice and Bob are playing a game of poker. Alice has a pair of 8s and Bob has a pair"
+
+        prompt = "Alice and Bob"
+        checkpoint = "facebook/layerskip-llama3.2-1B"
+
+        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+        inputs = tokenizer(prompt, return_tensors="pt").to(torch_device)
+
+        model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
+        original_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)
+        original_decoded = tokenizer.batch_decode(original_outputs, skip_special_tokens=True)
+        self.assertEqual(original_decoded, [expected_output])
+
+        outputs_assisted = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)
+        decoded_assisted = tokenizer.batch_decode(outputs_assisted, skip_special_tokens=True)
+        self.assertEqual(decoded_assisted, [expected_output])
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):
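The new test exercises this cache behaviour only indirectly. The sketch below is a minimal, self-contained illustration of the `DynamicCache.update` change at the top of this diff: when the early-exit assistant writes only the first few layers, the skipped layer slots are padded with empty lists and are filled the first time a deeper layer writes to them. The tensor shapes are arbitrary and chosen purely for the example.

```python
import torch

from transformers import DynamicCache

# Dummy key/value states with shape (batch, num_heads, seq_len, head_dim).
k = torch.randn(1, 2, 3, 4)
v = torch.randn(1, 2, 3, 4)

cache = DynamicCache()

# While drafting, the early-exit assistant only writes the first layers.
cache.update(k, v, layer_idx=0)
cache.update(k, v, layer_idx=1)

# During verification the full model reaches a deeper layer for the first
# time; the intermediate slots are padded with empty lists instead of erroring.
cache.update(k, v, layer_idx=5)

print(len(cache.key_cache))  # 6
print([len(entry) for entry in cache.key_cache[2:5]])  # [0, 0, 0], i.e. placeholder empty lists
```

This is also why the modeling files above now iterate over `self.layers[: self.config.num_hidden_layers]`: while drafting, `EarlyExitCandidateGenerator` temporarily lowers `num_hidden_layers`, so only a prefix of the layers runs and only their cache entries are written.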