examples/asr/conf/marblenet/marblenet_3x2x64_20ms.yaml (20 changes: 10 additions & 10 deletions)
@@ -36,7 +36,7 @@ model:
         max_gain_dbfs: 10.0
       noise:
         prob: 0.6
-        manifest_path: /manifests/vad_noise/freesound_nonspeech_train_FL200.json
+        manifest_path: ???
         min_snr_db: 0
         max_snr_db: 20
         max_gain_db: 300.0
@@ -51,15 +51,15 @@ model:
     pin_memory: true
     val_loss_idx: 0
 
-  test_ds:
-    manifest_filepath: null
-    sample_rate: ${model.sample_rate}
-    labels: ${model.labels}
-    batch_size: 128
-    shuffle: False
-    num_workers: 8
-    pin_memory: true
-    test_loss_idx: 0
+  # test_ds:
+  # manifest_filepath: null
+  # sample_rate: ${model.sample_rate}
+  # labels: ${model.labels}
+  # batch_size: 128
+  # shuffle: False
+  # num_workers: 8
+  # pin_memory: true
+  # test_loss_idx: 0
 
   preprocessor:
     _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
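Note on the config change above: "???" is OmegaConf's marker for a mandatory value, so the noise manifest must now be supplied explicitly (for example through a Hydra command-line override) instead of silently pointing at a hard-coded path. A minimal sketch of that behaviour, assuming only that omegaconf is installed; the manifest path used below is illustrative:

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

# Mirror of the changed block: "???" marks manifest_path as mandatory.
cfg = OmegaConf.create({"noise": {"prob": 0.6, "manifest_path": "???"}})

print(OmegaConf.is_missing(cfg.noise, "manifest_path"))  # True

try:
    _ = cfg.noise.manifest_path  # accessing it while unset raises
except MissingMandatoryValue:
    print("noise.manifest_path must be overridden")

# Equivalent of a CLI override such as
#   model.train_ds.augmentor.noise.manifest_path=<path to noise manifest>
cfg.noise.manifest_path = "/data/noise_manifest.json"  # illustrative path
print(cfg.noise.manifest_path)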
@@ -22,6 +22,7 @@
     --config-path=<path to dir of configs e.g. "../conf/marblenet">
     --config-name=<name of config without .yaml e.g. "marblenet_3x2x64_20ms"> \
     model.train_ds.manifest_filepath="<path to train manifest>" \
+    model.train_ds.augmentor.noise.manifest_path="<path to noise manifest>" \
     model.validation_ds.manifest_filepath=["<path to val manifest>","<path to test manifest>"] \
     trainer.devices=2 \
     trainer.accelerator="gpu" \
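The usage line added above covers the command-line route. For completeness, the same overrides can also be applied programmatically before training; this is a hedged sketch, not part of the change, and every manifest path in it is a placeholder:

from omegaconf import OmegaConf

# Load the MarbleNet config changed above and fill in the fields that are
# now mandatory or left null by default. Paths are illustrative only.
cfg = OmegaConf.load("examples/asr/conf/marblenet/marblenet_3x2x64_20ms.yaml")

cfg.model.train_ds.manifest_filepath = "/data/manifests/train.json"
cfg.model.train_ds.augmentor.noise.manifest_path = "/data/manifests/noise.json"
cfg.model.validation_ds.manifest_filepath = ["/data/manifests/val.json"]

print(OmegaConf.to_yaml(cfg.model.train_ds.augmentor.noise))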
nemo/collections/asr/models/classification_models.py (56 changes: 54 additions & 2 deletions)
@@ -1041,7 +1041,47 @@
         OmegaConf.set_struct(cfg, True)
 
 
-class EncDecFrameClassificationModel(EncDecClassificationModel):
+class EncDecFrameClassificationModel(_EncDecBaseModel):
+
+    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
+
+        if cfg.get("is_regression_task", False):
+            raise ValueError("EncDecFrameClassificationModel requires the flag is_regression_task to be set to False")
+
+        super().__init__(cfg=cfg, trainer=trainer)
+
+    def _setup_preprocessor(self):
+        return EncDecClassificationModel.from_config_dict(self._cfg.preprocessor)
+
+    def _setup_encoder(self):
+        return EncDecClassificationModel.from_config_dict(self._cfg.encoder)
+
+    def _setup_decoder(self):
+        return EncDecClassificationModel.from_config_dict(self._cfg.decoder)
+
+    def _setup_loss(self):
+        return CrossEntropyLoss()
+
+    def _setup_metrics(self):
+        self._accuracy = TopKClassificationAccuracy(dist_sync_on_step=True)
+
+    def _update_decoder_config(self, labels, cfg):
+        """
+        Update the number of classes in the decoder based on labels provided.
+
+        Args:
+            labels: The current labels of the model
+            cfg: The config of the decoder which will be updated.
+        """
+        OmegaConf.set_struct(cfg, False)
+
+        if 'params' in cfg:
+            cfg.params.num_classes = len(labels)
+        else:
+            cfg.num_classes = len(labels)
+
+        OmegaConf.set_struct(cfg, True)
+
     @property
     def output_types(self) -> Optional[Dict[str, NeuralType]]:
         return {"outputs": NeuralType(('B', 'T', 'C'), LogitsType())}
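The output_types declared above are the key difference from the utterance-level EncDecClassificationModel: the decoder now emits one logit vector per frame, giving logits of shape (B, T, C). The class itself relies on NeMo's CrossEntropyLoss and TopKClassificationAccuracy wrappers; the plain-PyTorch sketch below only illustrates the tensor shapes involved and is not the model's actual code path:

import torch
import torch.nn.functional as F

# Frame-level classification shapes implied by output_types: (B, T, C).
B, T, C = 4, 100, 2                   # batch, frames, classes (e.g. non-speech / speech)
logits = torch.randn(B, T, C)         # decoder output, one logit vector per frame
labels = torch.randint(0, C, (B, T))  # one integer label per frame

# torch's cross_entropy expects the class dimension second, hence the transpose.
loss = F.cross_entropy(logits.transpose(1, 2), labels)
preds = logits.argmax(dim=-1)         # per-frame predictions, shape (B, T)
print(loss.item(), preds.shape)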
@@ -1226,14 +1266,26 @@
         self._macro_accuracy.update(preds=metric_logits, target=metric_labels)
         stats = self._macro_accuracy._final_state()
 
-        return {
+        output = {
             f'{tag}_loss': loss_value,
             f'{tag}_correct_counts': correct_counts,
             f'{tag}_total_counts': total_counts,
             f'{tag}_acc_micro': acc,
             f'{tag}_acc_stats': stats,
         }
+
+        if tag == 'val':
+            if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
+                self.validation_step_outputs[dataloader_idx].append(output)
+            else:
+                self.validation_step_outputs.append(output)
+        else:
+            if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
+                self.test_step_outputs[dataloader_idx].append(output)
+            else:
+                self.test_step_outputs.append(output)
+        return output
 
     def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'):
         val_loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean()
         correct_counts = torch.stack([x[f'{tag}_correct_counts'] for x in outputs]).sum(axis=0)
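Context for the validation_step_outputs / test_step_outputs bookkeeping added in the last hunk: PyTorch Lightning 2.x no longer hands step outputs to the epoch-end hooks, so each step stores its own result (one list per dataloader when several are configured) and the epoch-end hook reduces and clears the list. The sketch below shows that pattern in plain Lightning with a single dataloader; it is illustrative and not NeMo's implementation:

import torch
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)
        # With more than one validation dataloader this would be a list of lists,
        # indexed by dataloader_idx as in the change above.
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        output = {"val_loss": loss}
        self.validation_step_outputs.append(output)  # mirrors the appended `output` above
        return output

    def on_validation_epoch_end(self):
        val_loss_mean = torch.stack([o["val_loss"] for o in self.validation_step_outputs]).mean()
        self.log("val_loss", val_loss_mean)
        self.validation_step_outputs.clear()  # free memory before the next epoch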