From 25c3dac6937c9d4d6cdf94b01def39f1845fc1cc Mon Sep 17 00:00:00 2001
From: Bijay Gurung
Date: Fri, 24 Apr 2020 22:41:57 +0545
Subject: [PATCH 1/8] Add Type Hints to modeling_utils.py

Closes #3911

Add Type Hints to methods in `modeling_utils.py`

Note: The coverage isn't 100%. Mostly skipped internal methods.
---
 src/transformers/modeling_utils.py | 64 +++++++++++++++---------------
 1 file changed, 33 insertions(+), 31 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f0df0c1ee51f..80d68b3cb1c0 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -17,7 +17,7 @@

 import logging
 import os
-from typing import Callable, Tuple
+from typing import Callable, Tuple, Sequence

 import torch
 from torch import Tensor, device, dtype, nn
@@ -131,7 +131,7 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
         return encoder_extended_attention_mask

-    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
+    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device) -> Tensor:
         """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.

         Arguments:
@@ -175,7 +175,7 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         return extended_attention_mask

-    def get_head_mask(self, head_mask, num_hidden_layers):
+    def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int) -> Tensor:
         """
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -267,7 +267,7 @@ def get_input_embeddings(self):
         else:
             raise NotImplementedError

-    def set_input_embeddings(self, value):
+    def set_input_embeddings(self, value: nn.Module):
         """ Set model's input embeddings
@@ -319,7 +319,7 @@ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
             output_embeddings.out_features = input_embeddings.num_embeddings

-    def resize_token_embeddings(self, new_num_tokens=None):
+    def resize_token_embeddings(self, new_num_tokens: int = None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
         Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
@@ -352,18 +352,20 @@ def _resize_token_embeddings(self, new_num_tokens):
         self.set_input_embeddings(new_embeddings)
         return self.get_input_embeddings()

-    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
+    def _get_resized_embeddings(self, old_embeddings: torch.nn.Embedding, new_num_tokens: int = None) -> torch.nn.Embedding:
         """ Build a resized Embedding Module from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
             Reducing the size will remove vectors from the end

         Args:
+            old_embeddings: ``torch.nn.Embedding``
+                Old embeddings to be resized.
             new_num_tokens: (`optional`) int
                 New number of tokens in the embedding matrix.
                 Increasing the size will add newly initialized vectors at the end
                 Reducing the size will remove vectors from the end
                 If not provided or None: return the provided token Embedding Module.
-        Return: ``torch.nn.Embeddings``
+        Return: ``torch.nn.Embedding``
             Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
         """
         if new_num_tokens is None:
@@ -398,7 +400,7 @@ def init_weights(self):
         # Tie weights if needed
         self.tie_weights()

-    def prune_heads(self, heads_to_prune):
+    def prune_heads(self, heads_to_prune: dict):
         """ Prunes heads of the base model.

             Arguments:
@@ -763,27 +765,27 @@ def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output
     @torch.no_grad()
     def generate(
         self,
-        input_ids=None,
-        max_length=None,
-        min_length=None,
-        do_sample=None,
-        early_stopping=None,
-        num_beams=None,
-        temperature=None,
-        top_k=None,
-        top_p=None,
-        repetition_penalty=None,
-        bad_words_ids=None,
-        bos_token_id=None,
-        pad_token_id=None,
-        eos_token_id=None,
-        length_penalty=None,
-        no_repeat_ngram_size=None,
-        num_return_sequences=None,
-        attention_mask=None,
-        decoder_start_token_id=None,
-        use_cache=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        max_length: int = None,
+        min_length: int = None,
+        do_sample: bool = None,
+        early_stopping: bool = None,
+        num_beams: int = None,
+        temperature: float = None,
+        top_k: int = None,
+        top_p: float = None,
+        repetition_penalty: float = None,
+        bad_words_ids: Sequence = None,
+        bos_token_id: int = None,
+        pad_token_id: int = None,
+        eos_token_id: int = None,
+        length_penalty: float = None,
+        no_repeat_ngram_size: int = None,
+        num_return_sequences: int = None,
+        attention_mask: torch.LongTensor = None,
+        decoder_start_token_id: int = None,
+        use_cache: bool = None,
+    ) -> torch.LongTensor:
         r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

         Adapted in part from `Facebook's XLM beam search code`_.
@@ -1563,7 +1565,7 @@ def _get_generated_ngrams(hypo_idx):
     return banned_tokens


-def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
+def calc_banned_bad_words_ids(prev_input_ids: Sequence, bad_words_ids: Sequence) -> Sequence:
     banned_tokens = []

     def _tokens_match(prev_tokens, tokens):
@@ -1599,7 +1601,7 @@ def _tokens_match(prev_tokens, tokens):
     return banned_tokens


-def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
+def top_k_top_p_filtering(logits: Sequence, top_k: int = 0, top_p: float = 1.0, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1) -> Sequence:
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
             logits: logits distribution shape (batch size, vocabulary size)

From 179270ca876fab0270f9e8c0a253e1ce0d932380 Mon Sep 17 00:00:00 2001
From: Bijay Gurung
Date: Fri, 24 Apr 2020 22:57:00 +0545
Subject: [PATCH 2/8] Reformat according to `black` and `isort`

---
 src/transformers/modeling_utils.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 80d68b3cb1c0..3552e92e3cf8 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -17,7 +17,7 @@

 import logging
 import os
-from typing import Callable, Tuple, Sequence
+from typing import Callable, Sequence, Tuple

 import torch
 from torch import Tensor, device, dtype, nn
@@ -352,7 +352,9 @@ def _resize_token_embeddings(self, new_num_tokens):
         self.set_input_embeddings(new_embeddings)
         return self.get_input_embeddings()

-    def _get_resized_embeddings(self, old_embeddings: torch.nn.Embedding, new_num_tokens: int = None) -> torch.nn.Embedding:
+    def _get_resized_embeddings(
+        self, old_embeddings: torch.nn.Embedding, new_num_tokens: int = None
+    ) -> torch.nn.Embedding:
         """ Build a resized Embedding Module from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
             Reducing the size will remove vectors from the end
@@ -1601,7 +1603,13 @@ def _tokens_match(prev_tokens, tokens):
     return banned_tokens


-def top_k_top_p_filtering(logits: Sequence, top_k: int = 0, top_p: float = 1.0, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1) -> Sequence:
+def top_k_top_p_filtering(
+    logits: Sequence,
+    top_k: int = 0,
+    top_p: float = 1.0,
+    filter_value: float = -float("Inf"),
+    min_tokens_to_keep: int = 1,
+) -> Sequence:
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
             logits: logits distribution shape (batch size, vocabulary size)

From c394156eda2b794617feb998fb74c49b10a34d77 Mon Sep 17 00:00:00 2001
From: Bijay Gurung
Date: Sat, 9 May 2020 08:35:52 +0545
Subject: [PATCH 3/8] Use typing.Iterable instead of Sequence

---
 src/transformers/modeling_utils.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 3552e92e3cf8..fe82a5f11494 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -17,7 +17,7 @@

 import logging
 import os
-from typing import Callable, Sequence, Tuple
+from typing import Callable, Iterable, Tuple

 import torch
 from torch import Tensor, device, dtype, nn
@@ -777,7 +777,7 @@ def generate(
         top_k: int = None,
         top_p: float = None,
         repetition_penalty: float = None,
-        bad_words_ids: Sequence = None,
+        bad_words_ids: Iterable = None,
         bos_token_id: int = None,
         pad_token_id: int = None,
         eos_token_id: int = None,
@@ -1567,7 +1567,7 @@ def _get_generated_ngrams(hypo_idx):
     return banned_tokens


-def calc_banned_bad_words_ids(prev_input_ids: Sequence, bad_words_ids: Sequence) -> Sequence:
+def calc_banned_bad_words_ids(prev_input_ids: Iterable, bad_words_ids: Iterable) -> Iterable:
     banned_tokens = []

     def _tokens_match(prev_tokens, tokens):
@@ -1604,12 +1604,12 @@ def _tokens_match(prev_tokens, tokens):


 def top_k_top_p_filtering(
-    logits: Sequence,
+    logits: Iterable,
     top_k: int = 0,
     top_p: float = 1.0,
     filter_value: float = -float("Inf"),
     min_tokens_to_keep: int = 1,
-) -> Sequence:
+) -> Iterable:
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
             logits: logits distribution shape (batch size, vocabulary size)

From 62308000eff7d515cbae4fb4a9a263de44e576cd Mon Sep 17 00:00:00 2001
From: Bijay Gurung
Date: Sat, 9 May 2020 13:01:02 +0545
Subject: [PATCH 4/8] Parameterize Iterable by its generic type

---
 src/transformers/modeling_utils.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index fe82a5f11494..bf9d57aa8718 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -17,7 +17,7 @@

 import logging
 import os
-from typing import Callable, Iterable, Tuple
+from typing import Callable, Iterable, Tuple, Dict

 import torch
 from torch import Tensor, device, dtype, nn
@@ -402,7 +402,7 @@ def init_weights(self):
         # Tie weights if needed
         self.tie_weights()

-    def prune_heads(self, heads_to_prune: dict):
+    def prune_heads(self, heads_to_prune: Dict):
         """ Prunes heads of the base model.
             Arguments:
@@ -777,7 +777,7 @@ def generate(
         top_k: int = None,
         top_p: float = None,
         repetition_penalty: float = None,
-        bad_words_ids: Iterable = None,
+        bad_words_ids: Iterable[int] = None,
         bos_token_id: int = None,
         pad_token_id: int = None,
         eos_token_id: int = None,
@@ -1567,7 +1567,7 @@ def _get_generated_ngrams(hypo_idx):
     return banned_tokens


-def calc_banned_bad_words_ids(prev_input_ids: Iterable, bad_words_ids: Iterable) -> Iterable:
+def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
     banned_tokens = []

     def _tokens_match(prev_tokens, tokens):
@@ -1604,12 +1604,12 @@ def _tokens_match(prev_tokens, tokens):


 def top_k_top_p_filtering(
-    logits: Iterable,
+    logits: Tensor,
     top_k: int = 0,
     top_p: float = 1.0,
     filter_value: float = -float("Inf"),
     min_tokens_to_keep: int = 1,
-) -> Iterable:
+) -> Tensor:
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
             logits: logits distribution shape (batch size, vocabulary size)

From 37942ca55654f0406023864fba1b0753aa5e64bb Mon Sep 17 00:00:00 2001
From: Bijay Gurung
Date: Sat, 9 May 2020 13:31:11 +0545
Subject: [PATCH 5/8] Use typing.Optional when None is the default value

---
 src/transformers/modeling_utils.py | 46 +++++++++++++++---------------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index bf9d57aa8718..761a22aad1a5 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -17,7 +17,7 @@

 import logging
 import os
-from typing import Callable, Iterable, Tuple, Dict
+from typing import Callable, Dict, Iterable, Optional, Tuple

 import torch
 from torch import Tensor, device, dtype, nn
@@ -319,7 +319,7 @@ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
             output_embeddings.out_features = input_embeddings.num_embeddings

-    def resize_token_embeddings(self, new_num_tokens: int = None):
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
         Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
@@ -353,7 +353,7 @@ def _resize_token_embeddings(self, new_num_tokens):
         return self.get_input_embeddings()

     def _get_resized_embeddings(
-        self, old_embeddings: torch.nn.Embedding, new_num_tokens: int = None
+        self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
     ) -> torch.nn.Embedding:
         """ Build a resized Embedding Module from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
@@ -767,26 +767,26 @@ def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output
     @torch.no_grad()
     def generate(
         self,
-        input_ids: torch.LongTensor = None,
-        max_length: int = None,
-        min_length: int = None,
-        do_sample: bool = None,
-        early_stopping: bool = None,
-        num_beams: int = None,
-        temperature: float = None,
-        top_k: int = None,
-        top_p: float = None,
-        repetition_penalty: float = None,
-        bad_words_ids: Iterable[int] = None,
-        bos_token_id: int = None,
-        pad_token_id: int = None,
-        eos_token_id: int = None,
-        length_penalty: float = None,
-        no_repeat_ngram_size: int = None,
-        num_return_sequences: int = None,
-        attention_mask: torch.LongTensor = None,
-        decoder_start_token_id: int = None,
-        use_cache: bool = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        max_length: Optional[int] = None,
+        min_length: Optional[int] = None,
+        do_sample: Optional[bool] = None,
+        early_stopping: Optional[bool] = None,
+        num_beams: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+        bad_words_ids: Optional[Iterable[int]] = None,
+        bos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        num_return_sequences: Optional[int] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_start_token_id: Optional[int] = None,
+        use_cache: Optional[bool] = None,
     ) -> torch.LongTensor:
         r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
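Note (illustration, not part of the patch series): a minimal standalone sketch of the typing conventions the commits above converge on, namely `Optional[...]` whenever `None` is the default and `Iterable` parameterized by its element type. The helper `pad_to_max_length` is hypothetical and does not exist in `modeling_utils.py`.

    from typing import Iterable, List, Optional


    def pad_to_max_length(ids: Iterable[int], max_length: Optional[int] = None, pad_token_id: int = 0) -> List[int]:
        """Pad a sequence of token ids to max_length; return it unchanged when max_length is None."""
        ids = list(ids)
        if max_length is None:
            return ids
        return ids + [pad_token_id] * max(0, max_length - len(ids))


    # The pre-patch style `max_length: int = None` is flagged by strict type checkers
    # because None is not an int; Optional[int] states the same default explicitly.
    print(pad_to_max_length([5, 7, 11], max_length=5))  # [5, 7, 11, 0, 0]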
From 99848e8abb42833383ffbd1322342aa363a6f31d Mon Sep 17 00:00:00 2001
From: Bijay Gurung
Date: Sat, 9 May 2020 14:34:33 +0545
Subject: [PATCH 6/8] Adhere to style guideline

---
 src/transformers/modeling_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 8ec263df8926..c94bbd2a69fc 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -175,7 +175,9 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         return extended_attention_mask

-    def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: Optional[bool] = False) -> Tensor:
+    def get_head_mask(
+        self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: Optional[bool] = False
+    ) -> Tensor:
         """
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head

From a67d3a1b91e14f36cef05c401aae9492970dc3b0 Mon Sep 17 00:00:00 2001
From: Julien Chaumond
Date: Fri, 22 May 2020 19:02:34 -0400
Subject: [PATCH 7/8] Update src/transformers/modeling_utils.py

---
 src/transformers/modeling_utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index c94bbd2a69fc..147a0404e38b 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -175,9 +175,7 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         return extended_attention_mask

-    def get_head_mask(
-        self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: Optional[bool] = False
-    ) -> Tensor:
+    def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
         """
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head

From 599ecbccb56c886e5f57645580d001cf4a173477 Mon Sep 17 00:00:00 2001
From: Julien Chaumond
Date: Fri, 22 May 2020 19:06:09 -0400
Subject: [PATCH 8/8] Update src/transformers/modeling_utils.py

---
 src/transformers/modeling_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 147a0404e38b..17df8ac51565 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -131,7 +131,7 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
         return encoder_extended_attention_mask

-    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device) -> Tensor:
+    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
         """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.

         Arguments:
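Note (illustration, not part of the patch series): a minimal sketch of how the finally annotated `top_k_top_p_filtering` signature reads at a call site, assuming a transformers checkout at the revision these patches target; the tensor sizes are arbitrary dummies.

    import torch
    from transformers.modeling_utils import top_k_top_p_filtering

    # Dummy scores for a batch of 2 over a vocabulary of 8, matching the documented
    # (batch size, vocabulary size) shape; after patch 4 both the logits argument
    # and the return value are annotated as Tensor.
    logits = torch.randn(2, 8)

    # top_k=3: every logit outside the 3 largest per row is set to filter_value (-inf).
    filtered = top_k_top_p_filtering(logits.clone(), top_k=3)
    assert torch.isinf(filtered).sum(dim=-1).eq(8 - 3).all()

    # top_p=0.9: keep the smallest set of tokens whose cumulative probability exceeds 0.9.
    filtered_p = top_k_top_p_filtering(logits.clone(), top_p=0.9)
    print(filtered_p.shape)  # torch.Size([2, 8])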