Commit 54739a3

ArthurZucker, gante, mostafaelhoushi, and pcuenca authored
Self-speculation (Layer-Skip Llama) (#34240)
* 😅
* early exit (#34244)
* mvp
* docs and tests
* a few fixes
* no shared cache
* Apply suggestions from code review
  Co-authored-by: Mostafa Elhoushi <[email protected]>
* docs
* make fix-copies
* cohere fix
* [test all]
* [test all] consistent model code copies
* [test all] make fix-copies :D
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <[email protected]>
  Co-authored-by: Mostafa Elhoushi <[email protected]>
* Update src/transformers/generation/candidate_generator.py
* Update src/transformers/generation/configuration_utils.py
  Co-authored-by: Pedro Cuenca <[email protected]>
* [test all] don't use a stand-alone attribute; fix test

---------

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 5de58d5 commit 54739a3

File tree

15 files changed: +178 -51 lines changed


docs/source/en/generation_strategies.md

Lines changed: 47 additions & 23 deletions
@@ -416,16 +416,6 @@ Assisted decoding assumes the main and assistant models have the same tokenizer,
 Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
 To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

-#### Universal Assisted Decoding
-
-Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
-To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
-Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
-in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
-The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
-Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
-to ensure the new tokens include the correct prompt suffix.
-
 To enable assisted decoding, set the `assistant_model` argument with a model.

 ```python
@@ -445,7 +435,36 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```

-If the main and assistant models have different tokenizers, use Universal Assisted Decoding.
+When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
+just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
+
+```python
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
+>>> set_seed(42) # For reproducibility
+
+>>> prompt = "Alice and Bob"
+>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
+>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
+
+>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+>>> inputs = tokenizer(prompt, return_tensors="pt")
+
+>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
+>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
+>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
+>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+['Alice and Bob, a couple of friends of mine, who are both in the same office as']
+```
+
+#### Universal Assisted Decoding
+
+Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
+To use it, simply pass the tokenizers using the `tokenizer` and `assistant_tokenizer` arguments (see below).
+Internally, the main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the assistant encoding, which are
+in turn re-encoded into main model candidate tokens. Validation then proceeds as explained above.
+The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
+Since re-encoding the tokens may result in tokenization discrepancies, UAD finds the longest common subsequence between the source and target encodings,
+to ensure the new tokens include the correct prompt suffix.

 ```python
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -465,30 +484,35 @@ If the main and assistant models have different tokenizers, use Universal Assist
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```

-When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
-just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
+#### Prompt Lookup
+
+Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
+to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
+
+#### Self-Speculative Decoding
+
+An LLM can be trained to also use its language modeling head with earlier hidden states as input, effectively
+skipping layers to yield a lower-quality output -- a technique called early exiting.
+We use the lower-quality early exit output as an assistant output, and apply self-speculation to fix the output using the remaining layers. The final generation of that self-speculative solution is the same (or has the same distribution) as the original model's generation.
+If the model you're using was trained to do early exit, you can pass
+`assistant_early_exit` (integer). In this case, the assistant model will be the same model but exiting early, hence the
+"self-speculative" name. Because the assistant model is a portion of the target model, caches and weights can be shared, which results in lower memory requirements. As in other assisted generation methods, the final generated result has the same quality as if no assistant had been used.

 ```python
->>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
->>> set_seed(42) # For reproducibility
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer

 >>> prompt = "Alice and Bob"
->>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
->>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
+>>> checkpoint = "facebook/layerskip-llama3.2-1B"

 >>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 >>> inputs = tokenizer(prompt, return_tensors="pt")

 >>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
->>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
->>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
+>>> outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=20)
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
-['Alice and Bob, a couple of friends of mine, who are both in the same office as']
+['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```

-Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
-to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
-
 ### DoLa Decoding

 **D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy to improve the factuality and reduce the
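
As a follow-up to the new Self-Speculative Decoding section above: `assistant_early_exit` should also compose with the sampling controls documented earlier in that file. A minimal sketch, not part of the diff, using the `facebook/layerskip-llama3.2-1B` checkpoint shown above (the exact output will vary):

```python
# Sketch: self-speculative decoding combined with sampling. `assistant_early_exit=4`
# drafts tokens with the first 4 decoder layers and verifies them with the full model;
# the sampling parameters apply to the verified output.
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

set_seed(42)  # for reproducibility

checkpoint = "facebook/layerskip-llama3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("Alice and Bob", return_tensors="pt")
outputs = model.generate(
    **inputs,
    assistant_early_exit=4,
    do_sample=True,
    temperature=0.5,
    max_new_tokens=20,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```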

src/transformers/cache_utils.py

Lines changed: 16 additions & 13 deletions
@@ -433,19 +433,22 @@ def update(
             self._seen_tokens += key_states.shape[-2]

         # Update the cache
-        if len(self.key_cache) <= layer_idx:
-            # There may be skipped layers, fill them with empty lists
-            for _ in range(len(self.key_cache), layer_idx):
-                self.key_cache.append([])
-                self.value_cache.append([])
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors
-            self.key_cache[layer_idx] = key_states
-            self.value_cache[layer_idx] = value_states
-        else:
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+        if key_states is not None:
+            if len(self.key_cache) <= layer_idx:
+                # There may be skipped layers, fill them with empty lists
+                for _ in range(len(self.key_cache), layer_idx):
+                    self.key_cache.append([])
+                    self.value_cache.append([])
+                self.key_cache.append(key_states)
+                self.value_cache.append(value_states)
+            elif (
+                len(self.key_cache[layer_idx]) == 0
+            ): # fills previously skipped layers; checking for tensor causes errors
+                self.key_cache[layer_idx] = key_states
+                self.value_cache[layer_idx] = value_states
+            else:
+                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

         return self.key_cache[layer_idx], self.value_cache[layer_idx]
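
The `key_states is not None` guard and the empty-list placeholders exist because an early-exit assistant only writes the first few layers of a cache whose later slots may never be touched. A minimal sketch of the placeholder behavior, assuming the `DynamicCache` class this `update` method belongs to:

```python
# Sketch: updating a DynamicCache out of layer order. Skipped layers are back-filled
# with empty lists, and a later update replaces the placeholder in place.
import torch
from transformers import DynamicCache

cache = DynamicCache()
k = torch.randn(1, 8, 5, 64)  # (batch, num_heads, seq_len, head_dim)
v = torch.randn(1, 8, 5, 64)

cache.update(k, v, layer_idx=3)  # layers 0-2 become empty-list placeholders
print([type(x).__name__ for x in cache.key_cache])  # ['list', 'list', 'list', 'Tensor']

cache.update(k, v, layer_idx=1)  # fills a previously skipped layer
print(type(cache.key_cache[1]).__name__)  # 'Tensor'
```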

src/transformers/generation/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -49,6 +49,7 @@
     _import_structure["candidate_generator"] = [
         "AssistedCandidateGenerator",
         "CandidateGenerator",
+        "EarlyExitCandidateGenerator",
         "PromptLookupCandidateGenerator",
     ]
     _import_structure["logits_process"] = [
@@ -206,7 +207,12 @@
     else:
         from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
         from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
-        from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
+        from .candidate_generator import (
+            AssistedCandidateGenerator,
+            CandidateGenerator,
+            EarlyExitCandidateGenerator,
+            PromptLookupCandidateGenerator,
+        )
         from .logits_process import (
             AlternatingCodebooksLogitsProcessor,
             ClassifierFreeGuidanceLogitsProcessor,

src/transformers/generation/candidate_generator.py

Lines changed: 56 additions & 0 deletions
@@ -670,6 +670,62 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
         return


+class EarlyExitCandidateGenerator(AssistedCandidateGenerator):
+    """
+    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
+    candidates through the use of **the model itself**, exiting early. Can only be used with models that support early
+    exit, e.g., `facebook/layerskip-llama3.2-1B`.
+
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+        assistant_model (`PreTrainedModel`):
+            The original model. This model must support early exit (i.e. is trained to compute logits in earlier
+            layers).
+        generation_config (`~generation.GenerationConfig`, *optional*):
+            The generation configuration to be used as base parametrization for the generation call.
+        logits_processor (`LogitsProcessorList`):
+            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+            used to modify the prediction scores of the language modeling head applied at each generation step.
+        model_kwargs (`Dict`):
+            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
+            model as well.
+        inputs_tensor (`torch.Tensor`, *optional*):
+            The model input tensor. In encoder-decoder models, this is the encoder input.
+    """
+
+    def __init__(
+        self,
+        input_ids: torch.LongTensor,
+        assistant_model: "PreTrainedModel",
+        generation_config: "GenerationConfig",
+        model_kwargs: Dict,
+        inputs_tensor: Optional[torch.Tensor] = None,
+        logits_processor: "LogitsProcessorList" = None,
+    ):
+        super().__init__(
+            input_ids=input_ids,
+            assistant_model=assistant_model,
+            generation_config=generation_config,
+            model_kwargs=model_kwargs,
+            inputs_tensor=inputs_tensor,
+            logits_processor=logits_processor,
+        )
+        # We have to move early exit out of the generation config, otherwise the assistant will also call `generate`
+        # with early exit
+        self.assistant_early_exit = self.generation_config.assistant_early_exit
+        self.generation_config.assistant_early_exit = None
+
+    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
+        # Temporarily sets the number of hidden layers to the early exit value
+        base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix)
+        original_num_hidden_layers = base_model.config.num_hidden_layers
+        base_model.config.num_hidden_layers = self.assistant_early_exit
+        candidate_ids, candidate_logits = super().get_candidates(input_ids)
+        base_model.config.num_hidden_layers = original_num_hidden_layers
+        return candidate_ids, candidate_logits
+
+
 def _crop_past_key_values(model, past_key_values, max_length):
     """Crops the past key values up to a certain maximum length."""
     new_past = []
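
The core of `get_candidates` is a temporary override of `config.num_hidden_layers`, which only works because the decoder loops (see the modeling changes below) slice `self.layers[: self.config.num_hidden_layers]`. The same pattern in isolation, as a hypothetical helper that is not part of the commit:

```python
# Sketch: a hypothetical context manager mirroring what get_candidates does around the
# assistant's generate call -- shrink the layer count, draft, then restore it.
from contextlib import contextmanager


@contextmanager
def early_exit(model, num_layers: int):
    """Temporarily run `model` with only its first `num_layers` decoder layers."""
    base_model = getattr(model, model.base_model_prefix)  # e.g. the LlamaModel inside LlamaForCausalLM
    original = base_model.config.num_hidden_layers
    base_model.config.num_hidden_layers = num_layers
    try:
        yield model
    finally:
        base_model.config.num_hidden_layers = original


# Usage sketch: draft 5 tokens with only the first 4 layers, as the candidate generator does.
# with early_exit(model, 4):
#     draft_ids = model.generate(**inputs, max_new_tokens=5, do_sample=False)
```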

src/transformers/generation/configuration_utils.py

Lines changed: 11 additions & 5 deletions
@@ -353,10 +353,13 @@ class GenerationConfig(PushToHubMixin):
             than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
             (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
             from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://arxiv.org/abs/2405.04304>.
-        prompt_lookup_num_tokens (`int`, *optional*, default to `None`):
+        prompt_lookup_num_tokens (`int`, *optional*):
             The number of tokens to be output as candidate tokens.
-        max_matching_ngram_size (`int`, *optional*, default to `None`):
+        max_matching_ngram_size (`int`, *optional*):
             The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
+        assistant_early_exit(`int`, *optional*):
+            If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
+            models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).

         > Wild card

@@ -454,10 +457,9 @@ def __init__(self, **kwargs):
         self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
         self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
         self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)
-
-        # Prompt lookup decoding
         self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
         self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
+        self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)

         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
@@ -534,7 +536,11 @@ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = Non
             generation_mode = GenerationMode.BEAM_SEARCH

         # Assisted generation may extend some generation modes
-        if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
+        if (
+            assistant_model is not None
+            or self.prompt_lookup_num_tokens is not None
+            or self.assistant_early_exit is not None
+        ):
             if generation_mode in ("greedy_search", "sample"):
                 generation_mode = GenerationMode.ASSISTED_GENERATION
             else:
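
With this change, setting `assistant_early_exit` on its own is enough to route generation into assisted mode, even though no separate assistant model is involved. A small sketch, assuming the `GenerationConfig.get_generation_mode` method shown above:

```python
# Sketch: assistant_early_exit alone selects assisted generation.
from transformers import GenerationConfig

config = GenerationConfig(assistant_early_exit=4, max_new_tokens=20)
print(config.get_generation_mode())  # expected: GenerationMode.ASSISTED_GENERATION

config = GenerationConfig(max_new_tokens=20)
print(config.get_generation_mode())  # expected: GenerationMode.GREEDY_SEARCH
```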

src/transformers/generation/utils.py

Lines changed: 11 additions & 1 deletion
@@ -54,6 +54,7 @@
     AssistedCandidateGenerator,
     AssistedCandidateGeneratorDifferentTokenizers,
     CandidateGenerator,
+    EarlyExitCandidateGenerator,
     PromptLookupCandidateGenerator,
     _crop_past_key_values,
     _prepare_attention_mask,
@@ -822,7 +823,16 @@ def _get_candidate_generator(
         """
         different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer))

-        if generation_config.prompt_lookup_num_tokens is not None:
+        if generation_config.assistant_early_exit is not None:
+            candidate_generator = EarlyExitCandidateGenerator(
+                input_ids=input_ids,
+                assistant_model=self,
+                generation_config=generation_config,
+                model_kwargs=model_kwargs,
+                inputs_tensor=inputs_tensor,
+                logits_processor=logits_processor,
+            )
+        elif generation_config.prompt_lookup_num_tokens is not None:
             candidate_generator = PromptLookupCandidateGenerator(
                 eos_token_id=generation_config._eos_token_tensor,
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
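
Note the dispatch order: `assistant_early_exit` is checked before `prompt_lookup_num_tokens`, so if both were set, the early-exit generator would win. A hypothetical sketch of such a call (reusing `model` and `inputs` from the docs example above; treat it purely as an illustration of the precedence):

```python
# Hypothetical sketch: both assisted-generation triggers set at once. Per the
# if/elif above, EarlyExitCandidateGenerator is selected and prompt lookup is ignored.
outputs = model.generate(
    **inputs,
    assistant_early_exit=4,        # checked first -> EarlyExitCandidateGenerator
    prompt_lookup_num_tokens=3,    # would otherwise select PromptLookupCandidateGenerator
    max_new_tokens=20,
)
```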

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 1 addition & 1 deletion
@@ -890,7 +890,7 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None

-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

src/transformers/models/gemma/modeling_gemma.py

Lines changed: 1 addition & 1 deletion
@@ -808,7 +808,7 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None

-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

src/transformers/models/gemma/modular_gemma.py

Lines changed: 1 addition & 1 deletion
@@ -886,7 +886,7 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None

-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 1 addition & 1 deletion
@@ -823,7 +823,7 @@ def forward(
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None

-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
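
All four modeling changes are the same one-line edit: slicing the layer list by `config.num_hidden_layers` lets the early-exit candidate generator shrink the effective depth at runtime without touching the weights. A toy, self-contained illustration (hypothetical module, not transformers code):

```python
# Toy sketch: lowering config.num_hidden_layers at runtime silently drops the tail
# layers from the forward pass, which is what the early-exit assistant relies on.
import torch
from torch import nn


class ToyConfig:
    def __init__(self, num_hidden_layers):
        self.num_hidden_layers = num_hidden_layers


class ToyDecoder(nn.Module):
    def __init__(self, config, hidden_size=16):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(config.num_hidden_layers)])

    def forward(self, x):
        # Mirrors the modeling change: only the first `config.num_hidden_layers` layers run.
        for layer in self.layers[: self.config.num_hidden_layers]:
            x = layer(x)
        return x


config = ToyConfig(num_hidden_layers=8)
model = ToyDecoder(config)
x = torch.randn(1, 16)

full = model(x)               # all 8 layers
config.num_hidden_layers = 4
early = model(x)              # only the first 4 layers, same weights
config.num_hidden_layers = 8  # restore, as EarlyExitCandidateGenerator does
```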
