Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions examples/asr/conf/marblenet/marblenet_3x2x64_20ms.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ model:
max_gain_dbfs: 10.0
noise:
prob: 0.6
manifest_path: /manifests/vad_noise/freesound_nonspeech_train_FL200.json
manifest_path: ???
min_snr_db: 0
max_snr_db: 20
max_gain_db: 300.0
Expand All @@ -51,15 +51,15 @@ model:
pin_memory: true
val_loss_idx: 0

test_ds:
manifest_filepath: null
sample_rate: ${model.sample_rate}
labels: ${model.labels}
batch_size: 128
shuffle: False
num_workers: 8
pin_memory: true
test_loss_idx: 0
# test_ds:
# manifest_filepath: null
# sample_rate: ${model.sample_rate}
# labels: ${model.labels}
# batch_size: 128
# shuffle: False
# num_workers: 8
# pin_memory: true
# test_loss_idx: 0

preprocessor:
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
--config-path=<path to dir of configs e.g. "../conf/marblenet"> \
--config-name=<name of config without .yaml e.g. "marblenet_3x2x64_20ms"> \
model.train_ds.manifest_filepath="<path to train manifest>" \
model.train_ds.augmentor.noise.manifest_path="<path to noise manifest>" \
model.validation_ds.manifest_filepath=["<path to val manifest>","<path to test manifest>"] \
trainer.devices=2 \
trainer.accelerator="gpu" \
Expand Down
56 changes: 51 additions & 5 deletions nemo/collections/asr/models/classification_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,19 +1041,27 @@ def _update_decoder_config(self, labels, cfg):
OmegaConf.set_struct(cfg, True)


class EncDecFrameClassificationModel(EncDecClassificationModel):
@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"outputs": NeuralType(('B', 'T', 'C'), LogitsType())}
class EncDecFrameClassificationModel(_EncDecBaseModel):
"""
EncDecFrameClassificationModel is a model that performs classification on each frame of the input audio.
The default config (i.e., marblenet_3x2x64_20ms.yaml) outputs 20ms frames.
"""

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """
    Initialize the frame-level classification model from its config.

    Args:
        cfg: Model config. ``cfg.labels`` is read to derive the number of classes,
            ``ratio_threshold`` is optional (defaults to 0.2), and
            ``is_regression_task`` must be absent or falsy.
        trainer: Optional Trainer instance, forwarded to the base-class constructor.

    Raises:
        ValueError: If ``cfg.is_regression_task`` is truthy.
    """
    # These attributes are assigned before super().__init__();
    # NOTE(review): presumably because the base constructor triggers _setup_metrics(),
    # which reads self.num_classes — confirm against the base class.
    self.num_classes = len(cfg.labels)
    self.eval_loop_cnt = 0
    self.ratio_threshold = cfg.get('ratio_threshold', 0.2)
    if cfg.get("is_regression_task", False):
        # The message names this class so the failing model is identifiable from the traceback.
        raise ValueError("EncDecFrameClassificationModel requires the flag is_regression_task to be set as false")

    super().__init__(cfg=cfg, trainer=trainer)
    # Make the decoder advertise the model-level output types, both for normal use and for export.
    self.decoder.output_types = self.output_types
    self.decoder.output_types_for_export = self.output_types

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
    """Return the model's output types: one ``outputs`` entry of logits with axes ``('B', 'T', 'C')``."""
    frame_logits = NeuralType(('B', 'T', 'C'), LogitsType())
    return {"outputs": frame_logits}

@classmethod
def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
results = []
Expand All @@ -1065,6 +1073,32 @@ def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
results.append(model)
return results

def _setup_preprocessor(self):
    """Instantiate the preprocessor module from the ``preprocessor`` section of the model config."""
    preprocessor_cfg = self._cfg.preprocessor
    return EncDecClassificationModel.from_config_dict(preprocessor_cfg)

def _setup_encoder(self):
    """Instantiate the encoder module from the ``encoder`` section of the model config."""
    encoder_cfg = self._cfg.encoder
    return EncDecClassificationModel.from_config_dict(encoder_cfg)

def _setup_decoder(self):
    """Instantiate the decoder module from the ``decoder`` section of the model config."""
    decoder_cfg = self._cfg.decoder
    return EncDecClassificationModel.from_config_dict(decoder_cfg)

def _update_decoder_config(self, labels, cfg):
    """
    Write the number of classes implied by ``labels`` into the decoder config.

    The value is stored under ``cfg.params.num_classes`` when the config has a
    ``params`` section, and under ``cfg.num_classes`` otherwise.

    Args:
        labels: The current labels of the model; only their count is used.
        cfg: The decoder config, modified in place.
    """
    # Struct mode is switched off for the assignment and switched back on afterwards.
    OmegaConf.set_struct(cfg, False)

    target_node = cfg.params if 'params' in cfg else cfg
    target_node.num_classes = len(labels)

    OmegaConf.set_struct(cfg, True)

def _setup_metrics(self):
self._accuracy = TopKClassificationAccuracy(dist_sync_on_step=True)
self._macro_accuracy = Accuracy(num_classes=self.num_classes, average='macro', task="multiclass")
Expand Down Expand Up @@ -1226,14 +1260,26 @@ def validation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
self._macro_accuracy.update(preds=metric_logits, target=metric_labels)
stats = self._macro_accuracy._final_state()

return {
output = {
f'{tag}_loss': loss_value,
f'{tag}_correct_counts': correct_counts,
f'{tag}_total_counts': total_counts,
f'{tag}_acc_micro': acc,
f'{tag}_acc_stats': stats,
}

if tag == 'val':
if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(output)
else:
self.validation_step_outputs.append(output)
else:
if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(output)
else:
self.test_step_outputs.append(output)
return output

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'):
val_loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean()
correct_counts = torch.stack([x[f'{tag}_correct_counts'] for x in outputs]).sum(axis=0)
Expand Down
70 changes: 69 additions & 1 deletion tests/collections/asr/test_asr_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@
# limitations under the License.

import copy
import json
import os

import tempfile

import lightning.pytorch as pl
import numpy as np
import pytest
import soundfile as sf
import torch
from omegaconf import DictConfig, ListConfig

Expand Down Expand Up @@ -104,12 +110,20 @@ def frame_classification_model():
},
}

optim = {
'name': 'sgd',
'lr': 0.01,
'weight_decay': 0.001,
'momentum': 0.9,
}

modelConfig = DictConfig(
{
'preprocessor': DictConfig(preprocessor),
'encoder': DictConfig(encoder),
'decoder': DictConfig(decoder),
'labels': ListConfig(["dummy_cls_{}".format(i + 1) for i in range(5)]),
'optim': DictConfig(optim),
'labels': ListConfig(["0", "1"]),
}
)
model = EncDecFrameClassificationModel(cfg=modelConfig)
Expand Down Expand Up @@ -320,3 +334,57 @@ def test_EncDecClassificationDatasetConfig_for_AudioToMultiSpeechLabelDataset(se
assert signatures_match
assert cls_subset is None
assert dataclass_subset is None

@pytest.mark.unit
def test_frame_classification_model(self, frame_classification_model: EncDecFrameClassificationModel):
    """Smoke test: fit the frame classification model for one epoch on a tiny synthetic manifest."""
    with tempfile.TemporaryDirectory() as temp_dir:
        # One second of random noise at 16 kHz serves as the only audio file.
        audio_path = os.path.join(temp_dir, "audio.wav")
        sf.write(audio_path, np.random.randn(16000 * 1), 16000)

        # Every manifest entry points at the same clip and carries the same label string.
        dummy_sample = {
            "audio_filepath": audio_path,
            "offset": 0.0,
            "duration": 1.0,
            "label": "0 0 0 0 1 1 1 1 0 0 0 0",
        }
        manifest_line = json.dumps(dummy_sample) + "\n"

        # The manifest holds four identical JSON lines.
        manifest_path = os.path.join(temp_dir, "dummy_manifest.json")
        with open(manifest_path, "w") as manifest_file:
            manifest_file.writelines(manifest_line for _ in range(4))

        # The same dataloader config is used for both training and validation.
        dataloader_cfg = {
            "batch_size": 2,
            "manifest_filepath": manifest_path,
            "sample_rate": 16000,
            "num_workers": 0,
            "shuffle": False,
            "labels": ["0", "1"],
        }

        optim_cfg = DictConfig(
            {
                'name': 'sgd',
                'lr': 0.01,
                'weight_decay': 0.001,
                'momentum': 0.9,
            }
        )

        trainer = pl.Trainer(max_epochs=1, devices=1, accelerator="auto")
        frame_classification_model.set_trainer(trainer)
        frame_classification_model.setup_optimization(optim_cfg)
        frame_classification_model.setup_training_data(dataloader_cfg)
        frame_classification_model.setup_validation_data(dataloader_cfg)

        # Fitting must happen while the temporary directory (audio + manifest) still exists.
        trainer.fit(frame_classification_model)
Loading