-
Notifications
You must be signed in to change notification settings - Fork 31.4k
Self-speculation (Layer-Skip Llama) #34240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
47a0963
42f0df6
317edb6
c767197
2bf9c78
c267831
3d24097
6d6ee24
2ab1726
ec9dfa2
7a8cda0
8dfa29a
cc38367
a480556
7850f02
a1aa8c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -433,19 +433,22 @@ def update( | |
| self._seen_tokens += key_states.shape[-2] | ||
|
|
||
| # Update the cache | ||
| if len(self.key_cache) <= layer_idx: | ||
| # There may be skipped layers, fill them with empty lists | ||
| for _ in range(len(self.key_cache), layer_idx): | ||
| self.key_cache.append([]) | ||
| self.value_cache.append([]) | ||
| self.key_cache.append(key_states) | ||
| self.value_cache.append(value_states) | ||
| elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors | ||
| self.key_cache[layer_idx] = key_states | ||
| self.value_cache[layer_idx] = value_states | ||
| else: | ||
| self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) | ||
| self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) | ||
| if key_states is not None: | ||
| if len(self.key_cache) <= layer_idx: | ||
| # There may be skipped layers, fill them with empty lists | ||
| for _ in range(len(self.key_cache), layer_idx): | ||
| self.key_cache.append([]) | ||
| self.value_cache.append([]) | ||
| self.key_cache.append(key_states) | ||
| self.value_cache.append(value_states) | ||
| elif ( | ||
| len(self.key_cache[layer_idx]) == 0 | ||
| ): # fills previously skipped layers; checking for tensor causes errors | ||
| self.key_cache[layer_idx] = key_states | ||
| self.value_cache[layer_idx] = value_states | ||
| else: | ||
| self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) | ||
| self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) | ||
|
Comment on lines
+450
to
+451
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, that is correct! I'm not going to add any check for now, though, and rely on internal tests to detect issues: adding a check here would hurt throughput in the forward pass, and a test can immediately detect issues :) |
||
|
|
||
| return self.key_cache[layer_idx], self.value_cache[layer_idx] | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -670,6 +670,61 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F | |||||
| return | ||||||
|
|
||||||
|
|
||||||
| class EarlyExitCandidateGenerator(AssistedCandidateGenerator): | ||||||
| """ | ||||||
| `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates | ||||||
| candidates through the use of **the model itself**, exiting early. Can only be used with models that support early | ||||||
| exit. | ||||||
|
||||||
| exit. | |
| exit, e.g., `facebook/layerskip-llama3.2-1B` or any of the models listed in this [collection](https://huggingface.co/collections/facebook/layerskip-666b25c50c8ae90e1965727a). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a single model would be enough for me, a collection could give the impression that we are maintaining a list of compatible models there, which is not the case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(added a single model :) )
gante marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI, I recently also added stopping_criteria as well to support integration with Eleuther LM Eval Harness:
facebookresearch/LayerSkip@e38784d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can ignore my comment about supporting StoppingCriteria. I checked out the PR and integrated with LM Eval Harness and found out that we don't need it.
I think I needed it in my custom implementation, but the native HF implementation doesn't.
gante marked this conversation as resolved.
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -803,6 +803,7 @@ def __init__(self, config: CohereConfig): | |
| super().__init__(config) | ||
| self.padding_idx = config.pad_token_id | ||
| self.vocab_size = config.vocab_size | ||
| self.num_hidden_layers = config.num_hidden_layers | ||
|
|
||
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) | ||
| self.layers = nn.ModuleList( | ||
|
|
@@ -890,7 +891,7 @@ def forward( | |
| all_self_attns = () if output_attentions else None | ||
| next_decoder_cache = None | ||
|
|
||
| for decoder_layer in self.layers: | ||
| for decoder_layer in self.layers[: self.num_hidden_layers]: | ||
|
||
| if output_hidden_states: | ||
| all_hidden_states += (hidden_states,) | ||
|
|
||
|
|
||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't mind! It's annoying to have to monkey patch all models, but fine in this case as it is strictly equivalent.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done -- we now use |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nesting was not fully right -- normal "speculative decoding" examples were under "Universal Assisted Decoding". Moved a few things around)