Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/asr/asr_adapters/train_asr_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"""
import os
from dataclasses import is_dataclass
from typing import Union

import lightning.pytorch as pl
from omegaconf import DictConfig, OmegaConf, open_dict
Expand Down Expand Up @@ -126,7 +127,7 @@ def update_model_cfg(original_cfg, new_cfg):
return new_cfg


def add_global_adapter_cfg(model, global_adapter_cfg):
def add_global_adapter_cfg(model: ASRModel, global_adapter_cfg: Union[DictConfig, dict]):
# Convert to DictConfig from dict or Dataclass
if is_dataclass(global_adapter_cfg):
global_adapter_cfg = OmegaConf.structured(global_adapter_cfg)
Expand Down
12 changes: 7 additions & 5 deletions examples/asr/speech_to_text_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
"""
import time
from typing import Union

import lightning.pytorch as pl
from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.models import ASRModel
from nemo.core.config import hydra_runner
Expand All @@ -65,7 +67,7 @@
from nemo.utils.trainer_utils import resolve_trainer_cfg


def get_base_model(trainer, cfg):
def get_base_model(trainer: pl.Trainer, cfg: DictConfig) -> ASRModel:
"""
Returns the base model to be fine-tuned.
Currently supports two types of initializations:
Expand Down Expand Up @@ -112,7 +114,7 @@ def get_base_model(trainer, cfg):
return asr_model


def check_vocabulary(asr_model, cfg):
def check_vocabulary(asr_model: ASRModel, cfg: DictConfig) -> ASRModel:
"""
Checks if the decoder and vocabulary of the model needs to be updated.
If either of them needs to be updated, it updates them and returns the updated model.
Expand All @@ -139,7 +141,7 @@ def check_vocabulary(asr_model, cfg):
return asr_model


def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type):
def update_tokenizer(asr_model: ASRModel, tokenizer_dir: Union[str, DictConfig], tokenizer_type: str) -> ASRModel:
"""
Updates the tokenizer of the model and also reinitializes the decoder if the vocabulary size
of the new tokenizer differs from that of the loaded model.
Expand Down Expand Up @@ -173,7 +175,7 @@ def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type):
return asr_model


def setup_dataloaders(asr_model, cfg):
def setup_dataloaders(asr_model: ASRModel, cfg: DictConfig) -> ASRModel:
"""
Sets up the training, validation and test dataloaders for the model.
Args:
Expand Down
Loading