diff --git a/examples/asr/conf/marblenet/marblenet_3x2x64_20ms.yaml b/examples/asr/conf/marblenet/marblenet_3x2x64_20ms.yaml
index 2c98c210eb0e..5329c9e4c3e6 100644
--- a/examples/asr/conf/marblenet/marblenet_3x2x64_20ms.yaml
+++ b/examples/asr/conf/marblenet/marblenet_3x2x64_20ms.yaml
@@ -36,7 +36,7 @@ model:
         max_gain_dbfs: 10.0
       noise:
         prob: 0.6
-        manifest_path: /manifests/vad_noise/freesound_nonspeech_train_FL200.json
+        manifest_path: ???
         min_snr_db: 0
         max_snr_db: 20
         max_gain_db: 300.0
@@ -51,15 +51,15 @@ model:
     pin_memory: true
     val_loss_idx: 0

-  test_ds:
-    manifest_filepath: null
-    sample_rate: ${model.sample_rate}
-    labels: ${model.labels}
-    batch_size: 128
-    shuffle: False
-    num_workers: 8
-    pin_memory: true
-    test_loss_idx: 0
+  # test_ds:
+  #   manifest_filepath: null
+  #   sample_rate: ${model.sample_rate}
+  #   labels: ${model.labels}
+  #   batch_size: 128
+  #   shuffle: False
+  #   num_workers: 8
+  #   pin_memory: true
+  #   test_loss_idx: 0

   preprocessor:
     _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
diff --git a/examples/asr/speech_classification/speech_to_frame_label.py b/examples/asr/speech_classification/speech_to_frame_label.py
index 2a9206efab78..521dd6b1f025 100644
--- a/examples/asr/speech_classification/speech_to_frame_label.py
+++ b/examples/asr/speech_classification/speech_to_frame_label.py
@@ -22,6 +22,7 @@
     --config-path=
     --config-name= \
     model.train_ds.manifest_filepath="" \
+    model.train_ds.augmentor.noise.manifest_path="" \
     model.validation_ds.manifest_filepath=["",""] \
     trainer.devices=2 \
     trainer.accelerator="gpu" \
diff --git a/nemo/collections/asr/models/classification_models.py b/nemo/collections/asr/models/classification_models.py
index 3e64bbdb6445..bedabaa3e88d 100644
--- a/nemo/collections/asr/models/classification_models.py
+++ b/nemo/collections/asr/models/classification_models.py
@@ -1041,19 +1041,27 @@ def _update_decoder_config(self, labels, cfg):
         OmegaConf.set_struct(cfg, True)


-class EncDecFrameClassificationModel(EncDecClassificationModel):
-    @property
-    def output_types(self) -> Optional[Dict[str, NeuralType]]:
-        return {"outputs": NeuralType(('B', 'T', 'C'), LogitsType())}
+class EncDecFrameClassificationModel(_EncDecBaseModel):
+    """
+    EncDecFrameClassificationModel is a model that performs classification on each frame of the input audio.
+    The default config (i.e., marblenet_3x2x64_20ms.yaml) outputs 20ms frames.
+ """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.num_classes = len(cfg.labels) self.eval_loop_cnt = 0 self.ratio_threshold = cfg.get('ratio_threshold', 0.2) + if cfg.get("is_regression_task", False): + raise ValueError("EndDecClassificationModel requires the flag is_regression_task to be set as false") + super().__init__(cfg=cfg, trainer=trainer) self.decoder.output_types = self.output_types self.decoder.output_types_for_export = self.output_types + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return {"outputs": NeuralType(('B', 'T', 'C'), LogitsType())} + @classmethod def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]: results = [] @@ -1065,6 +1073,32 @@ def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]: results.append(model) return results + def _setup_preprocessor(self): + return EncDecClassificationModel.from_config_dict(self._cfg.preprocessor) + + def _setup_encoder(self): + return EncDecClassificationModel.from_config_dict(self._cfg.encoder) + + def _setup_decoder(self): + return EncDecClassificationModel.from_config_dict(self._cfg.decoder) + + def _update_decoder_config(self, labels, cfg): + """ + Update the number of classes in the decoder based on labels provided. + + Args: + labels: The current labels of the model + cfg: The config of the decoder which will be updated. + """ + OmegaConf.set_struct(cfg, False) + + if 'params' in cfg: + cfg.params.num_classes = len(labels) + else: + cfg.num_classes = len(labels) + + OmegaConf.set_struct(cfg, True) + def _setup_metrics(self): self._accuracy = TopKClassificationAccuracy(dist_sync_on_step=True) self._macro_accuracy = Accuracy(num_classes=self.num_classes, average='macro', task="multiclass") @@ -1226,7 +1260,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = self._macro_accuracy.update(preds=metric_logits, target=metric_labels) stats = self._macro_accuracy._final_state() - return { + output = { f'{tag}_loss': loss_value, f'{tag}_correct_counts': correct_counts, f'{tag}_total_counts': total_counts, @@ -1234,6 +1268,18 @@ def validation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = f'{tag}_acc_stats': stats, } + if tag == 'val': + if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(output) + else: + self.validation_step_outputs.append(output) + else: + if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1: + self.test_step_outputs[dataloader_idx].append(output) + else: + self.test_step_outputs.append(output) + return output + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'): val_loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean() correct_counts = torch.stack([x[f'{tag}_correct_counts'] for x in outputs]).sum(axis=0) diff --git a/tests/collections/asr/test_asr_classification_model.py b/tests/collections/asr/test_asr_classification_model.py index 87ab3d73c1ea..daef542aecc4 100644 --- a/tests/collections/asr/test_asr_classification_model.py +++ b/tests/collections/asr/test_asr_classification_model.py @@ -13,9 +13,15 @@ # limitations under the License. 
 import copy
+import json
 import os
+import tempfile
+
+import lightning.pytorch as pl
+import numpy as np
 import pytest
+import soundfile as sf
 import torch
 from omegaconf import DictConfig, ListConfig
@@ -104,12 +110,20 @@ def frame_classification_model():
         },
     }

+    optim = {
+        'name': 'sgd',
+        'lr': 0.01,
+        'weight_decay': 0.001,
+        'momentum': 0.9,
+    }
+
     modelConfig = DictConfig(
         {
             'preprocessor': DictConfig(preprocessor),
             'encoder': DictConfig(encoder),
             'decoder': DictConfig(decoder),
-            'labels': ListConfig(["dummy_cls_{}".format(i + 1) for i in range(5)]),
+            'optim': DictConfig(optim),
+            'labels': ListConfig(["0", "1"]),
         }
     )
     model = EncDecFrameClassificationModel(cfg=modelConfig)
@@ -320,3 +334,57 @@ def test_EncDecClassificationDatasetConfig_for_AudioToMultiSpeechLabelDataset(se
         assert signatures_match
         assert cls_subset is None
         assert dataclass_subset is None
+
+    @pytest.mark.unit
+    def test_frame_classification_model(self, frame_classification_model: EncDecFrameClassificationModel):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            # generate random audio
+            audio = np.random.randn(16000 * 1)
+            # save the audio
+            audio_path = os.path.join(temp_dir, "audio.wav")
+            sf.write(audio_path, audio, 16000)
+
+            dummy_labels = "0 0 0 0 1 1 1 1 0 0 0 0"
+
+            dummy_sample = {
+                "audio_filepath": audio_path,
+                "offset": 0.0,
+                "duration": 1.0,
+                "label": dummy_labels,
+            }
+
+            # create a manifest file
+            manifest_path = os.path.join(temp_dir, "dummy_manifest.json")
+            with open(manifest_path, "w") as f:
+                for i in range(4):
+                    f.write(json.dumps(dummy_sample) + "\n")
+
+            dataloader_cfg = {
+                "batch_size": 2,
+                "manifest_filepath": manifest_path,
+                "sample_rate": 16000,
+                "num_workers": 0,
+                "shuffle": False,
+                "labels": ["0", "1"],
+            }
+
+            trainer_cfg = {
+                "max_epochs": 1,
+                "devices": 1,
+                "accelerator": "auto",
+            }
+
+            optim = {
+                'name': 'sgd',
+                'lr': 0.01,
+                'weight_decay': 0.001,
+                'momentum': 0.9,
+            }
+
+            trainer = pl.Trainer(**trainer_cfg)
+            frame_classification_model.set_trainer(trainer)
+            frame_classification_model.setup_optimization(DictConfig(optim))
+            frame_classification_model.setup_training_data(dataloader_cfg)
+            frame_classification_model.setup_validation_data(dataloader_cfg)
+
+            trainer.fit(frame_classification_model)
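
For reviewers who want to poke at the new class outside the test suite, a minimal inference sketch follows. It is not part of the diff: the repo-relative config path, dropping the dataloader sections so the mandatory noise-manifest value (???) is never resolved, and the usual NeMo classification forward signature are all assumptions based on the changes above.

# Minimal sketch (not part of this PR): build the frame-level model from the example
# config and inspect the per-frame logits it produces.
import torch
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.models import EncDecFrameClassificationModel

cfg = OmegaConf.load("examples/asr/conf/marblenet/marblenet_3x2x64_20ms.yaml")

# Drop the dataloader sections for this inference-only sketch so construction does not
# try to resolve the mandatory (???) noise manifest path under train_ds.augmentor.
with open_dict(cfg.model):
    cfg.model.pop("train_ds", None)
    cfg.model.pop("validation_ds", None)

model = EncDecFrameClassificationModel(cfg=cfg.model)
model.eval()

# One second of random 16 kHz audio as a stand-in for a real waveform.
audio = torch.randn(1, 16000)
length = torch.tensor([16000])

with torch.no_grad():
    logits = model(input_signal=audio, input_signal_length=length)

# Output follows the declared ('B', 'T', 'C') type: roughly one row of class logits
# per 20 ms frame, so on the order of 50 frames for one second of audio.
print(logits.shape)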