
Commit 43bfbf2

remove confidence ensemble models (#14343)
* rm confidence ensemple models
  Signed-off-by: lilithgrigoryan <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: lilithgrigoryan <[email protected]>
* clean up
  Signed-off-by: lilithgrigoryan <[email protected]>
---------
Signed-off-by: lilithgrigoryan <[email protected]>
Signed-off-by: lilithgrigoryan <[email protected]>
Co-authored-by: lilithgrigoryan <[email protected]>
1 parent fd3ee74 commit 43bfbf2

File tree: 9 files changed (+1, -1241 lines)


docs/source/asr/api.rst

Lines changed: 0 additions & 6 deletions
@@ -39,12 +39,6 @@ Model Classes
     :show-inheritance:
     :members: from_asr_config, from_pretrained_models, save_asr_model_to, setup_training_data
 
-.. _confidence-ensembles-api:
-
-.. autoclass:: nemo.collections.asr.models.confidence_ensemble.ConfidenceEnsembleModel
-    :show-inheritance:
-    :members: transcribe
-
 .. _asr-api-modules:
 
 Modules

docs/source/asr/models.rst

Lines changed: 0 additions & 32 deletions
@@ -309,38 +309,6 @@ For the detailed information see:
 * :ref:`Text-only dataset <Hybrid-ASR-TTS_model__Text-Only-Data>` preparation
 * :ref:`Configs and training <Hybrid-ASR-TTS_model__Config>`
 
-
-.. _Confidence-Ensembles:
-
-Confidence-based Ensembles
---------------------------
-
-Confidence-based ensemble is a simple way to combine multiple models into a single system by only retaining the
-output of the most confident model. Below is a schematic illustration of how such ensembles work.
-
-.. image:: images/conf-ensembles-overview.png
-    :align: center
-    :alt: confidence-based ensembles
-    :scale: 50%
-
-For more details about this model, see the `paper <https://arxiv.org/abs/2306.15824>`_
-or read our `tutorial <https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/Confidence_Ensembles.ipynb>`_.
-
-NeMo support Confidence-based Ensembles through the
-:ref:`nemo.collections.asr.models.confidence_ensemble.ConfidenceEnsembleModel <confidence-ensembles-api>` class.
-
-A typical workflow to create and use the ensemble is like this
-
-1. Run `scripts/confidence_ensembles/build_ensemble.py <https://github.com/NVIDIA/NeMo/blob/main/scripts/confidence_ensembles/build_ensemble.py>`_
-   script to create ensemble from existing models. See the documentation inside the script for usage examples
-   and description of all the supported functionality.
-2. The script outputs a checkpoint that combines all the models in an ensemble. It can be directly used to transcribe
-   speech by calling ``.trascribe()`` method or using
-   `examples/asr/transcribe_speech.py <https://github.com/NVIDIA/NeMo/blob/main/examples/asr/transcribe_speech.py>`_.
-
-Note that the ensemble cannot be modified after construction (e.g. it does not support finetuning) and only
-transcribe functionality is supported (e.g., ``.forward()`` is not properly defined).
-
 .. _Jasper_model:
 
 Jasper
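
For reference, the two-step workflow described in the deleted section translates to a short sketch. This is not part of the diff: the checkpoint and audio file names below are hypothetical, the build_ensemble.py arguments are placeholders (the real ones are documented inside that script), and it only runs on NeMo versions that still ship ConfidenceEnsembleModel:

    # Hypothetical usage of the removed API (pre-removal NeMo versions only).
    # Step 1 (shell): build a combined checkpoint from existing models, e.g.
    #   python scripts/confidence_ensembles/build_ensemble.py ... output_path=ensemble.nemo
    from nemo.collections.asr.models.confidence_ensemble import ConfidenceEnsembleModel

    # Step 2: restore the combined checkpoint and transcribe. Only
    # .transcribe() is supported; the ensemble cannot be fine-tuned.
    ensemble = ConfidenceEnsembleModel.restore_from("ensemble.nemo", map_location="cpu")
    texts = ensemble.transcribe(paths2audio_files=["sample.wav"], batch_size=4)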

docs/source/starthere/tutorials.rst

Lines changed: 0 additions & 3 deletions
@@ -152,9 +152,6 @@ Tutorial Overview
    * - ASR
      - ASR Confidence Estimation
      - `ASR Confidence Estimation <https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/ASR_Confidence_Estimation.ipynb>`_
-   * - ASR
-     - Confidence-based Ensembles
-     - `Confidence-based Ensembles <https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/Confidence_Ensembles.ipynb>`_
 
 .. list-table:: **Text-to-Speech (TTS) Tutorials**
    :widths: 15 35 50

nemo/collections/asr/models/confidence_ensemble.py

Lines changed: 0 additions & 198 deletions
@@ -17,33 +17,23 @@
 import pickle
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
 
 try:
     from joblib.numpy_pickle_utils import _read_fileobject as _validate_joblib_file
 except ImportError:
     from joblib.numpy_pickle_utils import _validate_fileobject_and_memmap as _validate_joblib_file
-import numpy as np
 import torch
-from lightning.pytorch import Trainer
-from omegaconf import DictConfig, open_dict
 from sklearn.linear_model import LogisticRegression
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import StandardScaler
 
-from nemo.collections.asr.models.asr_model import ASRModel
-from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
-from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType
 from nemo.collections.asr.parts.utils.asr_confidence_utils import (
     ConfidenceConfig,
     ConfidenceMethodConfig,
     get_confidence_aggregation_bank,
     get_confidence_measure_bank,
 )
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
-from nemo.core.classes import ModelPT
-from nemo.utils import model_utils
-from nemo.utils.decorators import deprecated
 
 
 # frozen is required to allow hashing of this class and use it
# frozen is required to allow hashing of this class and use it
@@ -241,191 +231,3 @@ class SecurityError(Exception):
     """Custom exception for security-related errors."""
 
     pass
-
-
-@deprecated(version='v2.1.0')
-class ConfidenceEnsembleModel(ModelPT):
-    """Implementation of the confidence ensemble model.
-
-    See https://arxiv.org/abs/2306.15824 for details.
-
-    .. note::
-        Currently this class only support `transcribe` method as it requires
-        full-utterance confidence scores to operate.
-    """
-
-    def __init__(
-        self,
-        cfg: DictConfig,
-        trainer: 'Trainer' = None,
-    ):
-        super().__init__(cfg=cfg, trainer=trainer)
-
-        # either we load all models from ``load_models`` cfg parameter
-        # or all of them are specified in the config as modelX alongside the num_models key
-        #
-        # ideally, we'd like to directly store all models in a list, but that
-        # is not currently supported by the submodule logic
-        # so to access all the models, we do something like
-        #
-        # for model_idx in range(self.num_models):
-        #     model = getattr(self, f"model{model_idx}")
-
-        if 'num_models' in self.cfg:
-            self.num_models = self.cfg.num_models
-            for idx in range(self.num_models):
-                cfg_field = f"model{idx}"
-                model_cfg = self.cfg[cfg_field]
-                model_class = model_utils.import_class_by_path(model_cfg['target'])
-                self.register_nemo_submodule(
-                    name=cfg_field,
-                    config_field=cfg_field,
-                    model=model_class(model_cfg, trainer=trainer),
-                )
-        else:
-            self.num_models = len(cfg.load_models)
-            with open_dict(self.cfg):
-                self.cfg.num_models = self.num_models
-            for idx, model in enumerate(cfg.load_models):
-                cfg_field = f"model{idx}"
-                if model.endswith(".nemo"):
-                    self.register_nemo_submodule(
-                        name=cfg_field,
-                        config_field=cfg_field,
-                        model=ASRModel.restore_from(model, trainer=trainer, map_location="cpu"),
-                    )
-                else:
-                    self.register_nemo_submodule(
-                        cfg_field,
-                        config_field=cfg_field,
-                        model=ASRModel.from_pretrained(model, map_location="cpu"),
-                    )
-
-        # registering model selection block - this is expected to be a joblib-saved
-        # pretrained sklearn pipeline containing standardization + logistic regression
-        # trained to predict "most-confident" model index from the confidence scores of all models
-        model_selection_block_path = self.register_artifact("model_selection_block", cfg.model_selection_block)
-        try:
-            self.model_selection_block = safe_joblib_load(model_selection_block_path)
-        except SecurityError as e:
-            raise RuntimeError(f"Security error loading model selection block: {str(e)}")
-        except Exception as e:
-            raise RuntimeError(f"Error loading model selection block: {str(e)}")
-
-        self.confidence_cfg = ConfidenceConfig(**self.cfg.confidence)
-
-        # making sure each model has correct temperature setting in the decoder strategy
-        for model_idx in range(self.num_models):
-            model = getattr(self, f"model{model_idx}")
-            # for now we assume users are direclty responsible for matching
-            # decoder type when building ensemble with inference type
-            # TODO: add automatic checks for errors
-            if isinstance(model, EncDecHybridRNNTCTCModel):
-                self.update_decoding_parameters(model.cfg.decoding)
-                model.change_decoding_strategy(model.cfg.decoding, decoder_type="rnnt")
-                self.update_decoding_parameters(model.cfg.aux_ctc.decoding)
-                model.change_decoding_strategy(model.cfg.aux_ctc.decoding, decoder_type="ctc")
-            else:
-                self.update_decoding_parameters(model.cfg.decoding)
-                model.change_decoding_strategy(model.cfg.decoding)
-
-    def update_decoding_parameters(self, decoding_cfg: DictConfig):
-        """Updating temperature/preserve_alignment parameters of the config."""
-        with open_dict(decoding_cfg):
-            decoding_cfg.temperature = self.cfg.temperature
-            decoding_cfg.preserve_alignments = True
-
-    def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
-        """Pass-through to the ensemble models.
-
-        Note that training is not actually supported for this class!
-        """
-        for model_idx in range(self.num_models):
-            getattr(self, f"model{model_idx}").setup_training_data(train_data_config)
-
-    def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
-        """Pass-through to the ensemble models."""
-        for model_idx in range(self.num_models):
-            getattr(self, f"model{model_idx}").setup_validation_data(val_data_config)
-
-    def change_attention_model(
-        self, self_attention_model: str = None, att_context_size: List[int] = None, update_config: bool = True
-    ):
-        """Pass-through to the ensemble models."""
-        for model_idx in range(self.num_models):
-            getattr(self, f"model{model_idx}").change_attention_model(
-                self_attention_model, att_context_size, update_config
-            )
-
-    def change_decoding_strategy(self, decoding_cfg: Optional[DictConfig] = None, decoder_type: str = None):
-        """Pass-through to the ensemble models.
-
-        The only change here is that we always require expected temperature
-        to be set as well as ``decoding_cfg.preserve_alignments = True``
-        """
-        self.update_decoding_parameters(decoding_cfg)
-        for model_idx in range(self.num_models):
-            model = getattr(self, f"model{model_idx}")
-            if isinstance(model, EncDecHybridRNNTCTCModel):
-                model.change_decoding_strategy(decoding_cfg, decoder_type=decoder_type)
-            else:
-                model.change_decoding_strategy(decoding_cfg)
-
-    @torch.no_grad()
-    def transcribe(
-        self,
-        paths2audio_files: List[str],
-        batch_size: int = 4,
-        return_hypotheses: bool = False,
-        num_workers: int = 0,
-        channel_selector: Optional[ChannelSelectorType] = None,
-        augmentor: DictConfig = None,
-        verbose: bool = True,
-        **kwargs,  # any other model specific parameters are passed directly
-    ) -> List[str]:
-        """Confidence-ensemble transcribe method.
-
-        Consists of the following steps:
-
-        1. Run all models (TODO: in parallel)
-        2. Compute confidence for each model
-        3. Use logistic regression to pick the "most confident" model
-        4. Return the output of that model
-        """
-        confidences = []
-        all_transcriptions = []
-        # always requiring to return hypothesis
-        # TODO: make sure to return text only if was False originally
-        return_hypotheses = True
-        for model_idx in range(self.num_models):
-            model = getattr(self, f"model{model_idx}")
-            transcriptions = model.transcribe(
-                paths2audio_files=paths2audio_files,
-                batch_size=batch_size,
-                return_hypotheses=return_hypotheses,
-                num_workers=num_workers,
-                channel_selector=channel_selector,
-                augmentor=augmentor,
-                verbose=verbose,
-                **kwargs,
-            )
-            if isinstance(transcriptions, tuple):  # transducers return a tuple
-                transcriptions = transcriptions[0]
-
-            model_confidences = []
-            for transcription in transcriptions:
-                model_confidences.append(compute_confidence(transcription, self.confidence_cfg))
-            confidences.append(model_confidences)
-            all_transcriptions.append(transcriptions)
-
-        # transposing with zip(*list)
-        features = np.array(list(zip(*confidences)))
-        model_indices = self.model_selection_block.predict(features)
-        final_transcriptions = []
-        for transcrption_idx in range(len(all_transcriptions[0])):
-            final_transcriptions.append(all_transcriptions[model_indices[transcrption_idx]][transcrption_idx])
-
-        return final_transcriptions
-
-    def list_available_models(self):
-        return []
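
The selection mechanism the removed `transcribe` relied on is plain scikit-learn: a Pipeline of StandardScaler plus LogisticRegression that maps one confidence score per model to the index of the model to trust, which is exactly what the `model_selection_block` artifact stored. A minimal, self-contained sketch of that idea, using synthetic confidence values rather than NeMo code:

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    rng = np.random.default_rng(0)
    # Synthetic training set: one row per utterance, one confidence
    # score per model (two models here).
    train_features = rng.random((200, 2))
    # Toy labels: the "correct" model is simply the more confident one.
    train_labels = train_features.argmax(axis=1)

    # Same structure as the removed model_selection_block artifact:
    # standardization followed by logistic regression.
    selector = Pipeline([("scaler", StandardScaler()), ("lr", LogisticRegression())])
    selector.fit(train_features, train_labels)

    # At inference, transcribe with every model, stack the per-utterance
    # confidences, and keep each utterance's transcription from the
    # predicted model index.
    test_features = np.array([[0.9, 0.4], [0.2, 0.7]])
    print(selector.predict(test_features))  # expected: [0 1]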
