-
Notifications
You must be signed in to change notification settings - Fork 31.4k
Self-speculation (Layer-Skip Llama) #34240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
47a0963
42f0df6
317edb6
c767197
2bf9c78
c267831
3d24097
6d6ee24
2ab1726
ec9dfa2
7a8cda0
8dfa29a
cc38367
a480556
7850f02
a1aa8c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -433,19 +433,22 @@ def update( | |
| self._seen_tokens += key_states.shape[-2] | ||
|
|
||
| # Update the cache | ||
| if len(self.key_cache) <= layer_idx: | ||
| # There may be skipped layers, fill them with empty lists | ||
| for _ in range(len(self.key_cache), layer_idx): | ||
| self.key_cache.append([]) | ||
| self.value_cache.append([]) | ||
| self.key_cache.append(key_states) | ||
| self.value_cache.append(value_states) | ||
| elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors | ||
| self.key_cache[layer_idx] = key_states | ||
| self.value_cache[layer_idx] = value_states | ||
| else: | ||
| self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) | ||
| self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) | ||
| if key_states is not None: | ||
| if len(self.key_cache) <= layer_idx: | ||
| # There may be skipped layers, fill them with empty lists | ||
| for _ in range(len(self.key_cache), layer_idx): | ||
| self.key_cache.append([]) | ||
| self.value_cache.append([]) | ||
| self.key_cache.append(key_states) | ||
| self.value_cache.append(value_states) | ||
| elif ( | ||
| len(self.key_cache[layer_idx]) == 0 | ||
| ): # fills previously skipped layers; checking for tensor causes errors | ||
| self.key_cache[layer_idx] = key_states | ||
| self.value_cache[layer_idx] = value_states | ||
| else: | ||
| self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) | ||
| self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) | ||
|
Comment on lines
+450
to
+451
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, that is correct! I'm not going to add any check for now, though, and rely on internal tests to detect issues: adding a check here would hurt throughput in the forward pass, and a test can immediately detect issues :) |
||
|
|
||
| return self.key_cache[layer_idx], self.value_cache[layer_idx] | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -670,6 +670,61 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F | |
| return | ||
|
|
||
|
|
||
| class EarlyExitCandidateGenerator(AssistedCandidateGenerator): | ||
| """ | ||
| `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates | ||
| candidates through the use of **the model itself**, exiting early. Can only be used with models that support early | ||
| exit, e.g., `facebook/layerskip-llama3.2-1B`. | ||
|
|
||
| Args: | ||
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | ||
| Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) | ||
| assistant_model (`PreTrainedModel`): | ||
| The original model. This model must support early exit (i.e. is trained to compute logits in earlier | ||
| layers). | ||
| generation_config (`~generation.GenerationConfig`, *optional*): | ||
| The generation configuration to be used as base parametrization for the generation call. | ||
| logits_processor (`LogitsProcessorList`): | ||
| An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] | ||
| used to modify the prediction scores of the language modeling head applied at each generation step. | ||
| model_kwargs (`Dict`): | ||
| The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant | ||
| model as well. | ||
| inputs_tensor (`torch.Tensor`, *optional*): | ||
| The model input tensor. In encoder-decoder models, this is the encoder input. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| input_ids: torch.LongTensor, | ||
| assistant_model: "PreTrainedModel", | ||
| generation_config: "GenerationConfig", | ||
| model_kwargs: Dict, | ||
| inputs_tensor: Optional[torch.Tensor] = None, | ||
| logits_processor: "LogitsProcessorList" = None, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI, I recently also added
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can ignore my comment about supporting |
||
| ): | ||
| super().__init__( | ||
| input_ids=input_ids, | ||
| assistant_model=assistant_model, | ||
| generation_config=generation_config, | ||
| model_kwargs=model_kwargs, | ||
| inputs_tensor=inputs_tensor, | ||
| logits_processor=logits_processor, | ||
| ) | ||
| # We have to move early exit out of the generation config, otherwise the assistant will also call `generate` | ||
| # with early exit | ||
| self.assistant_early_exit = self.generation_config.assistant_early_exit | ||
| self.generation_config.assistant_early_exit = None | ||
|
|
||
| def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: | ||
| # Temporarily sets the number of hidden layers to the early exit value | ||
| base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix) | ||
| base_model.num_hidden_layers = self.assistant_early_exit | ||
| candidate_ids, candidate_logits = super().get_candidates(input_ids) | ||
gante marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| base_model.num_hidden_layers = base_model.config.num_hidden_layers | ||
| return candidate_ids, candidate_logits | ||
|
|
||
|
|
||
| def _crop_past_key_values(model, past_key_values, max_length): | ||
| """Crops the past key values up to a certain maximum length.""" | ||
| new_past = [] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -803,6 +803,7 @@ def __init__(self, config: CohereConfig): | |
| super().__init__(config) | ||
| self.padding_idx = config.pad_token_id | ||
| self.vocab_size = config.vocab_size | ||
| self.num_hidden_layers = config.num_hidden_layers | ||
|
|
||
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) | ||
| self.layers = nn.ModuleList( | ||
|
|
@@ -890,7 +891,7 @@ def forward( | |
| all_self_attns = () if output_attentions else None | ||
| next_decoder_cache = None | ||
|
|
||
| for decoder_layer in self.layers: | ||
| for decoder_layer in self.layers[: self.num_hidden_layers]: | ||
|
||
| if output_hidden_states: | ||
| all_hidden_states += (hidden_states,) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nesting was not fully right -- normal "speculative decoding" examples were under "Universal Assisted Decoding". Moved a few things around)