Add Type Hints to modeling_utils.py (Closes #3911) · PR #3948
@@ -17,7 +17,7 @@
 import inspect
 import logging
 import os
-from typing import Callable, Tuple
+from typing import Callable, Dict, Iterable, Optional, Tuple

 import torch
 from torch import Tensor, device, dtype, nn
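For readers less familiar with the `typing` names being imported here, a minimal sketch (not part of the diff) of how they read in the annotated signatures below; the toy function and values are illustrative only.

```python
from typing import Dict, Iterable, Optional

# Optional[int] means "int or None"; the argument is only optional because of the
# "= None" default, not because of the annotation itself.
def resize(new_num_tokens: Optional[int] = None) -> int:
    return 0 if new_num_tokens is None else new_num_tokens

# Dict and Iterable describe container arguments, e.g. a layer-index -> head-indices
# mapping like the one prune_heads() further down accepts (values here are made up).
heads_to_prune: Dict[int, Iterable[int]] = {0: [1, 2], 3: [0]}
```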
@@ -131,7 +131,7 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
         encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
         return encoder_extended_attention_mask

-    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
+    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device) -> Tensor:
         """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.

         Arguments:
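As a hedged illustration of what the annotated `get_extended_attention_mask` returns (shapes and values inferred from the surrounding context lines, not taken from the diff itself):

```python
import torch

# A padding mask of shape (batch, seq_len): 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 0]])

# Broadcast to (batch, 1, 1, seq_len) so it can be added to attention scores,
# then turn masked positions into large negative values, mirroring the
# "(1.0 - mask) * -10000.0" pattern visible in the context above.
extended_attention_mask = attention_mask[:, None, None, :].float()
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
print(extended_attention_mask.shape)  # torch.Size([1, 1, 1, 4])
```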
@@ -175,7 +175,7 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         return extended_attention_mask

-    def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
+    def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
         """
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
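A small sketch of the `head_mask` tensor the annotated `get_head_mask` expects, following the comment in the context above ("1.0 in head_mask indicate we keep the head"); the layer and head counts are illustrative:

```python
import torch

num_hidden_layers, num_attention_heads = 12, 12

# 2-D mask of shape (num_hidden_layers, num_attention_heads):
# 1.0 keeps a head, 0.0 masks it out.
head_mask = torch.ones(num_hidden_layers, num_attention_heads)
head_mask[0, 3] = 0.0  # mask head 3 of layer 0
```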
@@ -269,7 +269,7 @@ def get_input_embeddings(self):
         else:
             raise NotImplementedError

-    def set_input_embeddings(self, value):
+    def set_input_embeddings(self, value: nn.Module):
         """
         Set model's input embeddings

Review thread on the `value: nn.Module` annotation:

Member: is this not an

Contributor (Author): Going through some tests, looks like it's expected to be either
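To make the review thread concrete, a hedged usage sketch: the `value` passed in is typically an `nn.Embedding`, which the broader `nn.Module` annotation also admits. The checkpoint name is only an example.

```python
import torch.nn as nn
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

# Swap in a freshly initialized embedding table of the same shape.
new_embeddings = nn.Embedding(model.config.vocab_size, model.config.hidden_size)
model.set_input_embeddings(new_embeddings)
assert model.get_input_embeddings() is new_embeddings
```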
@@ -321,7 +321,7 @@ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
             output_embeddings.out_features = input_embeddings.num_embeddings

-    def resize_token_embeddings(self, new_num_tokens=None):
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
         Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

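A hedged usage sketch of the annotated `resize_token_embeddings`, for example after adding tokens to a tokenizer; the checkpoint name and added token are illustrative.

```python
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

tokenizer.add_tokens(["<new_token>"])
# new_num_tokens is Optional[int]; passing None leaves the embeddings untouched.
model.resize_token_embeddings(new_num_tokens=len(tokenizer))
```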
@@ -354,18 +354,22 @@ def _resize_token_embeddings(self, new_num_tokens):
         self.set_input_embeddings(new_embeddings)
         return self.get_input_embeddings()

-    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
+    def _get_resized_embeddings(
+        self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
+    ) -> torch.nn.Embedding:
         """ Build a resized Embedding Module from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
             Reducing the size will remove vectors from the end

         Args:
+            old_embeddings: ``torch.nn.Embedding``
+                Old embeddings to be resized.
             new_num_tokens: (`optional`) int
                 New number of tokens in the embedding matrix.
                 Increasing the size will add newly initialized vectors at the end
                 Reducing the size will remove vectors from the end
                 If not provided or None: return the provided token Embedding Module.
-        Return: ``torch.nn.Embeddings``
+        Return: ``torch.nn.Embedding``
             Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
         """
         if new_num_tokens is None:
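A simplified sketch (assuming a plain `nn.Embedding` and default initialization) of the copy-the-overlap idea described in the docstring above:

```python
import torch.nn as nn

old_embeddings = nn.Embedding(100, 16)
new_num_tokens = 120

# Newly initialized table of the new size; copy the rows both tables share,
# so growing the vocabulary appends fresh vectors and shrinking drops the tail.
new_embeddings = nn.Embedding(new_num_tokens, 16)
num_tokens_to_copy = min(old_embeddings.num_embeddings, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
```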
@@ -400,7 +404,7 @@ def init_weights(self):
         # Tie weights if needed
         self.tie_weights()

-    def prune_heads(self, heads_to_prune):
+    def prune_heads(self, heads_to_prune: Dict):
         """ Prunes heads of the base model.

             Arguments:
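The `Dict` annotation on `heads_to_prune` maps layer indices to the head indices to remove in that layer; a hedged usage sketch with illustrative values:

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

# {layer index: [head indices to prune in that layer]}
model.prune_heads({0: [0, 2], 5: [1]})
```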
@@ -768,28 +772,28 @@ def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output
     @torch.no_grad()
     def generate(
         self,
-        input_ids=None,
-        max_length=None,
-        min_length=None,
-        do_sample=None,
-        early_stopping=None,
-        num_beams=None,
-        temperature=None,
-        top_k=None,
-        top_p=None,
-        repetition_penalty=None,
-        bad_words_ids=None,
-        bos_token_id=None,
-        pad_token_id=None,
-        eos_token_id=None,
-        length_penalty=None,
-        no_repeat_ngram_size=None,
-        num_return_sequences=None,
-        attention_mask=None,
-        decoder_start_token_id=None,
-        use_cache=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        max_length: Optional[int] = None,
+        min_length: Optional[int] = None,
+        do_sample: Optional[bool] = None,
+        early_stopping: Optional[bool] = None,
+        num_beams: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+        bad_words_ids: Optional[Iterable[int]] = None,
+        bos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        num_return_sequences: Optional[int] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_start_token_id: Optional[int] = None,
+        use_cache: Optional[bool] = None,
         **model_specific_kwargs
-    ):
+    ) -> torch.LongTensor:
         r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

         Adapted in part from `Facebook's XLM beam search code`_.
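A hedged usage sketch of the newly annotated `generate()`; the checkpoint, prompt, and sampling settings are illustrative, and only a few of the typed keyword arguments are shown.

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("Type hints make", return_tensors="pt")  # torch.LongTensor

output_ids = model.generate(
    input_ids=input_ids,         # Optional[torch.LongTensor]
    max_length=30,               # Optional[int]
    do_sample=True,              # Optional[bool]
    top_k=50,                    # Optional[int]
    top_p=0.95,                  # Optional[float]
    num_return_sequences=1,      # Optional[int]
)                                # -> torch.LongTensor of generated token ids

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```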
@@ -1571,7 +1575,7 @@ def _get_generated_ngrams(hypo_idx):
     return banned_tokens


-def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
+def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
     banned_tokens = []

     def _tokens_match(prev_tokens, tokens):
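For context on the `bad_words_ids` argument this helper consumes, a hedged sketch of its expected structure (a list of token-id sequences); the words and tokenizer are illustrative.

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Each entry is the token-id sequence of a word that must not be generated.
bad_words_ids = [tokenizer.encode(" " + word) for word in ["ugly", "rude"]]
# later: model.generate(..., bad_words_ids=bad_words_ids)
```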
@@ -1607,7 +1611,13 @@ def _tokens_match(prev_tokens, tokens):
     return banned_tokens


-def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
+def top_k_top_p_filtering(
+    logits: Tensor,
+    top_k: int = 0,
+    top_p: float = 1.0,
+    filter_value: float = -float("Inf"),
+    min_tokens_to_keep: int = 1,
+) -> Tensor:
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
             logits: logits distribution shape (batch size, vocabulary size)
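A hedged usage sketch of the reformatted `top_k_top_p_filtering` signature: filter a batch of logits, then sample from what is left. The vocabulary size is illustrative.

```python
import torch
from transformers.modeling_utils import top_k_top_p_filtering

logits = torch.randn(1, 50257)  # (batch size, vocabulary size)

# Keep only the 50 highest-scoring tokens and the smallest set covering 95% of the
# probability mass; everything else is set to filter_value (-inf by default).
filtered_logits = top_k_top_p_filtering(logits, top_k=50, top_p=0.95)

probs = torch.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```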