Add typing to speech_to_text_finetune.py #15326
Changes from 3 commits
The PR adds type annotations at five locations in `speech_to_text_finetune.py`:

```diff
@@ -55,7 +55,7 @@
 """
 import time
 import lightning.pytorch as pl
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf

 from nemo.collections.asr.models import ASRModel
 from nemo.core.config import hydra_runner
@@ -65,7 +65,7 @@
 from nemo.utils.trainer_utils import resolve_trainer_cfg


-def get_base_model(trainer, cfg):
+def get_base_model(trainer: pl.Trainer, cfg: DictConfig) -> ASRModel:
     """
     Returns the base model to be fine-tuned.
     Currently supports two types of initializations:
@@ -112,7 +112,7 @@ def get_base_model(trainer, cfg):
     return asr_model


-def check_vocabulary(asr_model, cfg):
+def check_vocabulary(asr_model: ASRModel, cfg: DictConfig) -> ASRModel:
     """
     Checks if the decoder and vocabulary of the model needs to be updated.
     If either of them needs to be updated, it updates them and returns the updated model.
@@ -139,7 +139,7 @@ def check_vocabulary(asr_model, cfg):
     return asr_model


-def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type):
+def update_tokenizer(asr_model: ASRModel, tokenizer_dir, tokenizer_type) -> ASRModel:
     """
     Updates the tokenizer of the model and also reinitializes the decoder if the vocabulary size
     of the new tokenizer differs from that of the loaded model.
@@ -173,7 +173,7 @@ def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type):
     return asr_model


-def setup_dataloaders(asr_model, cfg):
+def setup_dataloaders(asr_model: ASRModel, cfg: DictConfig) -> ASRModel:
     """
     Sets up the training, validation and test dataloaders for the model.
     Args:
```
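A side note on the pattern the diff follows: the annotations above pull in `pl.Trainer` and `ASRModel`, which requires importing lightning and nemo at module load time. An alternative (not what this PR does) is to place such imports behind a `TYPE_CHECKING` guard with postponed annotation evaluation, so type checkers see the full types while the module stays importable without the heavy dependencies. A minimal, self-contained sketch with stand-in logic:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

# Illustrative sketch only -- with postponed evaluation of annotations
# (PEP 563), the guarded imports below are seen by static type checkers
# but never executed at runtime, so this file runs even where lightning
# and nemo are not installed.
if TYPE_CHECKING:
    from lightning.pytorch import Trainer                 # checker-only import
    from nemo.collections.asr.models import ASRModel      # checker-only import


def get_base_model(trainer: Trainer, cfg: dict) -> ASRModel:
    """Annotated like the PR's version; the real body is elided here."""
    raise NotImplementedError

# The annotations are stored as plain strings and cost nothing at runtime:
print(get_base_model.__annotations__)
# {'trainer': 'Trainer', 'cfg': 'dict', 'return': 'ASRModel'}
```

The trade-off is that the types are then unavailable for runtime introspection as real classes, which matters for some frameworks; for a training script like this one, direct imports (as the PR uses) are the simpler choice.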
Review comment: The type annotation for parameter `global_adapter_cfg` is missing. Based on the function body (lines 131-135), this parameter can be a dataclass, `DictConfig`, or `dict`. The appropriate type annotation would be `Union[DictConfig, dict]`, since the function checks for these types explicitly. This would require ensuring `Union` is imported from the `typing` module.
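The shape of the reviewer's suggestion can be sketched with only the standard library. `AdapterCfg` and `to_plain_dict` below are hypothetical stand-ins, not names from the PR, and `DictConfig` is left out of the `Union` so the sketch stays self-contained; the real annotation would include it.

```python
from dataclasses import asdict, dataclass, is_dataclass
from typing import Union


# Hypothetical stand-in for a config parameter that may arrive as a
# dataclass instance or a plain dict.  The real code would also accept an
# omegaconf DictConfig and list it in the Union; it is omitted here so the
# sketch needs only the standard library.
@dataclass
class AdapterCfg:
    dim: int = 32
    activation: str = "relu"


def to_plain_dict(cfg: Union[AdapterCfg, dict]) -> dict:
    """Normalize the config to a plain dict via explicit type dispatch,
    mirroring the per-type checks the reviewer refers to."""
    if is_dataclass(cfg):
        return asdict(cfg)
    if isinstance(cfg, dict):
        return dict(cfg)
    raise TypeError(f"unsupported config type: {type(cfg).__name__}")


print(to_plain_dict(AdapterCfg()))   # {'dim': 32, 'activation': 'relu'}
print(to_plain_dict({"dim": 64}))    # {'dim': 64}
```

Annotating the parameter this way lets a checker such as mypy flag callers that pass an unsupported type, instead of the mismatch surfacing as a runtime error inside the function.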