src/transformers/modeling_utils.py: 72 changes (41 additions, 31 deletions)
@@ -17,7 +17,7 @@
import inspect
import logging
import os
from typing import Callable, Tuple
from typing import Callable, Dict, Iterable, Optional, Tuple

import torch
from torch import Tensor, device, dtype, nn
@@ -131,7 +131,7 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
return encoder_extended_attention_mask

def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device) -> Tensor:
"""Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.

Arguments:
@@ -175,7 +175,7 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
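A minimal sketch (not part of the diff; encoder case only, values chosen for illustration) of the additive mask this method builds: 1s in a 2-D padding mask become 0.0, 0s become -10000.0, and the result broadcasts over the (batch, num_heads, query_len, key_len) attention scores.

```python
# Illustration only, not part of the diff: encoder-style extended attention mask.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])          # (batch_size, seq_len); 0 marks padding
extended = attention_mask[:, None, None, :].float()    # (batch_size, 1, 1, seq_len)
extended = (1.0 - extended) * -10000.0                 # same transform as the line above
print(extended)                                        # tensor([[[[-0., -0., -0., -10000.]]]])
```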

def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
"""
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
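The hunk is truncated mid-docstring here. As a rough illustration (not part of the diff; sizes are made up) of the convention it describes, a 1-D head mask with 1.0 = keep and 0.0 = drop can be expanded to one copy per hidden layer, giving a 5-D shape that broadcasts against the per-layer attention probabilities.

```python
# Illustration only, not part of the diff.
import torch

num_hidden_layers, num_heads = 12, 12
head_mask = torch.ones(num_heads)
head_mask[3] = 0.0                                     # 0.0 masks head 3 in every layer
head_mask = head_mask[None, None, :, None, None]       # (1, 1, num_heads, 1, 1)
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
print(head_mask.shape)                                 # torch.Size([12, 1, 12, 1, 1])
```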
@@ -269,7 +269,7 @@ def get_input_embeddings(self):
else:
raise NotImplementedError

def set_input_embeddings(self, value):
def set_input_embeddings(self, value: nn.Module):
Member commented:
is this not an nn.Layer or even a nn.Embedding?

@bglearning (Contributor, Author) replied on May 9, 2020:
Going through some tests, it looks like the value is expected to be either nn.Embedding or AdaptiveEmbedding (from modeling_transfo_xl.py). The docstring says nn.Module ("A module mapping vocabulary to hidden states"). Maybe that is intentional, to keep the annotation general? Not sure.

"""
Set model's input embeddings
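Following up on the review thread above, a rough check (a sketch, not part of the diff; the bert-base-uncased checkpoint is only an example): most models hand an nn.Embedding to this setter, while Transfo-XL uses its AdaptiveEmbedding, which is why nn.Module is the broadest annotation that covers both.

```python
# Sketch only, not part of the diff.
import torch.nn as nn
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
print(type(model.get_input_embeddings()))              # <class 'torch.nn.modules.sparse.Embedding'>
model.set_input_embeddings(nn.Embedding(model.config.vocab_size, model.config.hidden_size))
```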

@@ -321,7 +321,7 @@ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings

def resize_token_embeddings(self, new_num_tokens=None):
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

@@ -354,18 +354,22 @@ def _resize_token_embeddings(self, new_num_tokens):
self.set_input_embeddings(new_embeddings)
return self.get_input_embeddings()
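A hedged usage sketch (not part of the diff; model, tokenizer, and token names are illustrative) of the public entry point: growing the vocabulary after adding a token appends a newly initialized row, shrinking removes rows from the end, and weights are re-tied afterwards as the docstring above notes.

```python
# Usage sketch only, not part of the diff.
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

old_size = model.get_input_embeddings().num_embeddings
tokenizer.add_tokens(["<new_token>"])
embeddings = model.resize_token_embeddings(new_num_tokens=len(tokenizer))
print(embeddings.num_embeddings == old_size + 1)        # True: one row appended at the end

model.resize_token_embeddings(new_num_tokens=old_size)  # shrinking removes rows from the end
```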

def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
def _get_resized_embeddings(
self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
) -> torch.nn.Embedding:
""" Build a resized Embedding Module from a provided token Embedding Module.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end

Args:
old_embeddings: ``torch.nn.Embedding``
Old embeddings to be resized.
new_num_tokens: (`optional`) int
New number of tokens in the embedding matrix.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
If not provided or None: return the provided token Embedding Module.
Return: ``torch.nn.Embeddings``
Return: ``torch.nn.Embedding``
Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
"""
if new_num_tokens is None:
@@ -400,7 +404,7 @@ def init_weights(self):
# Tie weights if needed
self.tie_weights()

def prune_heads(self, heads_to_prune):
def prune_heads(self, heads_to_prune: Dict):
""" Prunes heads of the base model.

Arguments:
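The docstring is cut off by the diff here. A brief usage sketch (not part of the diff; checkpoint and indices are illustrative): heads_to_prune maps a layer index to the list of attention-head indices to remove in that layer.

```python
# Usage sketch only, not part of the diff.
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.prune_heads({0: [0, 2], 11: [5]})   # drop heads 0 and 2 of layer 0 and head 5 of layer 11
```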
@@ -768,28 +772,28 @@ def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output
@torch.no_grad()
def generate(
self,
input_ids=None,
max_length=None,
min_length=None,
do_sample=None,
early_stopping=None,
num_beams=None,
temperature=None,
top_k=None,
top_p=None,
repetition_penalty=None,
bad_words_ids=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
input_ids: Optional[torch.LongTensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
**model_specific_kwargs
):
) -> torch.LongTensor:
r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

Adapted in part from `Facebook's XLM beam search code`_.
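A hedged usage sketch (not part of the diff; the gpt2 checkpoint and prompt are illustrative) exercising several of the newly annotated arguments; the returned torch.LongTensor stacks num_return_sequences generations per input row.

```python
# Usage sketch only, not part of the diff.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("Type hints make the API", return_tensors="pt")  # torch.LongTensor
output_ids = model.generate(
    input_ids=input_ids,
    max_length=30,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    num_return_sequences=2,
)
print(output_ids.shape)                                          # (2, <= 30)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```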
@@ -1571,7 +1575,7 @@ def _get_generated_ngrams(hypo_idx):
return banned_tokens


def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
banned_tokens = []

def _tokens_match(prev_tokens, tokens):
@@ -1607,7 +1611,7 @@ def _tokens_match(prev_tokens, tokens):
return banned_tokens
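A paraphrased sketch (not part of the diff and not the library implementation; token ids are made up) of the matching rule this helper applies: a bad word's final token is banned once the hypothesis already ends with that bad word's prefix, and a single-token bad word is always banned.

```python
# Paraphrased sketch only, not part of the diff.
from typing import List

def would_be_banned(prev_tokens: List[int], bad_word: List[int]) -> bool:
    # The last token of a bad word is banned when the hypothesis ends with
    # the bad word's prefix; an empty prefix (single-token ban) always matches.
    prefix = bad_word[:-1]
    return prefix == [] or prev_tokens[-len(prefix):] == prefix

prev_tokens = [5, 17, 828]
bad_words_ids = [[828, 46], [393]]
banned_next = [seq[-1] for seq in bad_words_ids if would_be_banned(prev_tokens, seq)]
print(banned_next)   # [46, 393]: 46 would complete [828, 46]; 393 is banned outright
```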


def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
def top_k_top_p_filtering(
logits: Tensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> Tensor:
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
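The argument list is cut off here. A hedged usage sketch (not part of the diff; vocabulary size and thresholds are illustrative) of sampling from filtered logits, assuming the function keeps masking filtered positions with filter_value so they receive effectively zero probability after the softmax.

```python
# Usage sketch only, not part of the diff.
import torch
from transformers.modeling_utils import top_k_top_p_filtering

logits = torch.randn(2, 50257)                          # (batch size, vocabulary size)
filtered = top_k_top_p_filtering(logits.clone(), top_k=50, top_p=0.95)
probs = torch.softmax(filtered, dim=-1)                 # filtered entries get ~zero probability
next_tokens = torch.multinomial(probs, num_samples=1)   # one sampled token id per row
```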