Add Type Hints to modeling_utils.py (Closes #3911) #3948
Conversation
Add type hints to methods in `modeling_utils.py`. Note: the coverage isn't 100%; I mostly skipped internal methods (and some I wasn't sure of).
LysandreJik left a comment:
This is great, thanks!
I think this is good for merge, no? @julien-c
src/transformers/modeling_utils.py (Outdated)

```diff
 import logging
 import os
-from typing import Callable, Tuple
+from typing import Callable, Sequence, Tuple
```
We usually use `typing.Iterable`; any reason to use `Sequence` here instead?
Changed to `Iterable`.
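For context, the distinction the review is drawing: `Sequence` promises indexing and `len()`, while `Iterable` only promises iteration, so it also accepts generators. A minimal sketch (hypothetical functions, not from the PR):

```python
from typing import Iterable, Sequence

def first_id(ids: Sequence[int]) -> int:
    # Sequence guarantees __getitem__ and __len__.
    return ids[0]

def contains_id(ids: Iterable[int], target: int) -> bool:
    # Iterable only guarantees iteration, so generators qualify too.
    return any(i == target for i in ids)

contains_id((i * i for i in range(10)), 49)   # fine: a generator is Iterable
# first_id(i * i for i in range(10))          # a type checker rejects this: a generator is not a Sequence
```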
```diff
         raise NotImplementedError

-    def set_input_embeddings(self, value):
+    def set_input_embeddings(self, value: nn.Module):
```
Is this not an `nn.Layer`, or even an `nn.Embedding`?
Going through some tests, it looks like it's expected to be either `nn.Embedding` or `AdaptiveEmbedding` (from `modeling_transfo_xl.py`). The docstring has `nn.Module` ("a module mapping vocabulary to hidden states"). Maybe that's to keep it general? Not sure.
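A sketch of the trade-off being discussed (the `AdaptiveEmbedding` class below is a stand-in; the real one lives in `modeling_transfo_xl.py`):

```python
import torch.nn as nn

class AdaptiveEmbedding(nn.Module):
    """Stand-in for transformers' AdaptiveEmbedding; it is not an nn.Embedding subclass."""

class Model(nn.Module):
    # Broad annotation, matching the docstring ("a module mapping vocabulary
    # to hidden states"): both nn.Embedding and AdaptiveEmbedding are accepted.
    def set_input_embeddings(self, value: nn.Module) -> None:
        self.embeddings = value

    # A narrower alternative would be
    #     def set_input_embeddings(self, value: typing.Union[nn.Embedding, AdaptiveEmbedding]) -> None:
    # at the cost of rejecting any other embedding-like module.
```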
src/transformers/modeling_utils.py (Outdated)

```diff
         output_embeddings.out_features = input_embeddings.num_embeddings

-    def resize_token_embeddings(self, new_num_tokens=None):
+    def resize_token_embeddings(self, new_num_tokens: int = None):
```
`Optional[int] = None`?
Done
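The reasoning behind the request, as a sketch: `new_num_tokens: int = None` declares an `int` but defaults to `None`, which strict checkers (e.g. mypy with implicit-`Optional` disabled) reject. A hypothetical simplified signature:

```python
from typing import Optional

def resize_token_embeddings(new_num_tokens: Optional[int] = None) -> None:
    # None means "keep the current size". `new_num_tokens: int = None` would
    # contradict its own default; Optional[int] makes the None case explicit.
    if new_num_tokens is None:
        return
```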
src/transformers/modeling_utils.py (Outdated)

```diff
         self.tie_weights()

-    def prune_heads(self, heads_to_prune):
+    def prune_heads(self, heads_to_prune: dict):
```
Suggested change:

```diff
-    def prune_heads(self, heads_to_prune: dict):
+    def prune_heads(self, heads_to_prune: Dict):
```
Changed it.
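Going one step further than the suggestion, the annotation could be fully parametrized. Per the method's docstring, `heads_to_prune` maps a layer index to the list of heads to prune in that layer; a sketch:

```python
from typing import Dict, List

class Model:
    def prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        # e.g. heads_to_prune = {1: [0, 2], 2: [2, 3]} prunes heads 0 and 2
        # of layer 1 and heads 2 and 3 of layer 2.
        ...
```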
src/transformers/modeling_utils.py (Outdated)

```python
        num_return_sequences: int = None,
        attention_mask: torch.LongTensor = None,
        decoder_start_token_id: int = None,
        use_cache: bool = None,
```
All of those should be `Optional[...]`.
Done
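After the change, the fragment presumably reads as below (a sketch, wrapping each of the annotations shown above in `Optional`; the truncated `generate()` signature is illustrative only):

```python
from typing import Optional

import torch

class Model:
    def generate(
        self,
        num_return_sequences: Optional[int] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
    ):
        # The real generate() takes many more parameters, all defaulting
        # to None and wrapped the same way.
        ...
```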
src/transformers/modeling_utils.py (Outdated)

```diff
-def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
+def calc_banned_bad_words_ids(prev_input_ids: Sequence, bad_words_ids: Sequence) -> Sequence:
```
Parametrize by their generic type, i.e. `Sequence[int]`.
Done
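The point of parametrizing: a bare `Sequence` is implicitly `Sequence[Any]`, so element-level mistakes pass unchecked, while `Sequence[int]` (or `Iterable[int]`, per the earlier thread) lets the checker verify the elements. A hypothetical simplified version (the real helper works per batch item on nested lists of token ids):

```python
from typing import Iterable, List

def calc_banned_bad_words_ids(
    prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]
) -> List[int]:
    # Materialize once so single-pass iterables (generators) are also fine.
    seen = set(prev_input_ids)
    return [token for token in bad_words_ids if token in seen]
```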
Codecov Report

```
@@            Coverage Diff             @@
##           master    #3948      +/-   ##
==========================================
- Coverage   78.39%   78.38%    -0.02%
==========================================
  Files         120      120
  Lines       19925    19925
==========================================
- Hits        15620    15618        -2
- Misses       4305     4307        +2
```

Continue to review the full report at Codecov.
Thanks @bglearning!