Merged · Changes from 3 commits
2 changes: 1 addition & 1 deletion examples/asr/asr_adapters/train_asr_adapter.py
@@ -126,7 +126,7 @@ def update_model_cfg(original_cfg, new_cfg):
    return new_cfg


-def add_global_adapter_cfg(model, global_adapter_cfg):
+def add_global_adapter_cfg(model: ASRModel, global_adapter_cfg):
Copilot AI commented on Jan 27, 2026:
The type annotation for parameter global_adapter_cfg is missing. Based on the function body (lines 131-135), this parameter can be a dataclass, DictConfig, or dict. The appropriate type annotation would be Union[DictConfig, dict] since the function checks for these types explicitly. This would require ensuring Union is imported from the typing module.

    # Convert to DictConfig from dict or Dataclass
    if is_dataclass(global_adapter_cfg):
        global_adapter_cfg = OmegaConf.structured(global_adapter_cfg)
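The annotation the review suggests can be sketched as below. Because the real function depends on NeMo and omegaconf, this stand-alone sketch substitutes a plain `dict` for `DictConfig`, a hypothetical `AdapterCfg` dataclass for a real adapter config, and `dataclasses.asdict` for `OmegaConf.structured` — all of these stand-ins are assumptions, not the actual NeMo API:

```python
from dataclasses import asdict, dataclass, is_dataclass
from typing import Union

@dataclass
class AdapterCfg:
    # Hypothetical adapter config; stands in for a real NeMo adapter dataclass.
    dim: int = 32
    activation: str = "swish"

def add_global_adapter_cfg(model, global_adapter_cfg: Union[AdapterCfg, dict]) -> dict:
    # Normalize a dataclass into a plain mapping, mirroring how the real
    # function converts a dataclass or dict into a DictConfig.
    if is_dataclass(global_adapter_cfg):
        global_adapter_cfg = asdict(global_adapter_cfg)
    return dict(global_adapter_cfg)
```

The point of the `Union` annotation is simply to document at the signature what the body's `is_dataclass` check already implies: the parameter may arrive in more than one shape.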
10 changes: 5 additions & 5 deletions examples/asr/speech_to_text_finetune.py
@@ -55,7 +55,7 @@
"""
import time
import lightning.pytorch as pl
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.models import ASRModel
from nemo.core.config import hydra_runner
@@ -65,7 +65,7 @@
from nemo.utils.trainer_utils import resolve_trainer_cfg


-def get_base_model(trainer, cfg):
+def get_base_model(trainer: pl.Trainer, cfg: DictConfig) -> ASRModel:
    """
    Returns the base model to be fine-tuned.
    Currently supports two types of initializations:
@@ -112,7 +112,7 @@ def get_base_model(trainer, cfg):
    return asr_model


-def check_vocabulary(asr_model, cfg):
+def check_vocabulary(asr_model: ASRModel, cfg: DictConfig) -> ASRModel:
    """
    Checks if the decoder and vocabulary of the model needs to be updated.
    If either of them needs to be updated, it updates them and returns the updated model.
@@ -139,7 +139,7 @@ def check_vocabulary(asr_model, cfg):
    return asr_model


-def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type):
+def update_tokenizer(asr_model: ASRModel, tokenizer_dir, tokenizer_type) -> ASRModel:
Copilot AI commented on Jan 27, 2026:
The type annotations for parameters tokenizer_dir and tokenizer_type are missing. Based on the change_vocabulary method signature in the ASRModel classes (e.g., rnnt_bpe_models.py, lines 340-344), these should be typed as:

  • tokenizer_dir: Union[str, DictConfig] (either a tokenizer directory path string, or a DictConfig for the 'agg' tokenizer type)
  • tokenizer_type: str

This would also require adding Union to the imports from the typing module.

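A runnable sketch of the types this comment proposes. `DictConfig` is aliased to a plain `dict` here so the snippet does not require omegaconf, and `validate_tokenizer_args` is a hypothetical helper illustrating the two accepted shapes — it is not part of the real update_tokenizer:

```python
from typing import Union

# Stand-in for omegaconf.DictConfig so this sketch runs without omegaconf.
DictConfig = dict

def validate_tokenizer_args(tokenizer_dir: Union[str, DictConfig], tokenizer_type: str) -> str:
    # Hypothetical check mirroring the comment's suggestion: an 'agg'
    # tokenizer takes a DictConfig of per-language tokenizer dirs, while
    # other tokenizer types take a single directory path string.
    if tokenizer_type == "agg":
        if not isinstance(tokenizer_dir, DictConfig):
            raise TypeError("'agg' tokenizer requires a DictConfig of tokenizer dirs")
    elif not isinstance(tokenizer_dir, str):
        raise TypeError(f"{tokenizer_type!r} tokenizer requires a directory path string")
    return tokenizer_type
```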
    """
    Updates the tokenizer of the model and also reinitializes the decoder if the vocabulary size
    of the new tokenizer differs from that of the loaded model.
@@ -173,7 +173,7 @@ def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type):
    return asr_model


-def setup_dataloaders(asr_model, cfg):
+def setup_dataloaders(asr_model: ASRModel, cfg: DictConfig) -> ASRModel:
    """
    Sets up the training, validation and test dataloaders for the model.
    Args: