|
17 | 17 | import pickle |
18 | 18 | import warnings |
19 | 19 | from dataclasses import dataclass |
20 | | -from typing import Dict, List, Optional, Union |
21 | 20 |
|
22 | 21 | try: |
23 | 22 | from joblib.numpy_pickle_utils import _read_fileobject as _validate_joblib_file |
24 | 23 | except ImportError: |
25 | 24 | from joblib.numpy_pickle_utils import _validate_fileobject_and_memmap as _validate_joblib_file |
26 | | -import numpy as np |
27 | 25 | import torch |
28 | | -from lightning.pytorch import Trainer |
29 | | -from omegaconf import DictConfig, open_dict |
30 | 26 | from sklearn.linear_model import LogisticRegression |
31 | 27 | from sklearn.pipeline import Pipeline |
32 | 28 | from sklearn.preprocessing import StandardScaler |
33 | 29 |
|
34 | | -from nemo.collections.asr.models.asr_model import ASRModel |
35 | | -from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel |
36 | | -from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType |
37 | 30 | from nemo.collections.asr.parts.utils.asr_confidence_utils import ( |
38 | 31 | ConfidenceConfig, |
39 | 32 | ConfidenceMethodConfig, |
40 | 33 | get_confidence_aggregation_bank, |
41 | 34 | get_confidence_measure_bank, |
42 | 35 | ) |
43 | 36 | from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis |
44 | | -from nemo.core.classes import ModelPT |
45 | | -from nemo.utils import model_utils |
46 | | -from nemo.utils.decorators import deprecated |
47 | 37 |
|
48 | 38 |
|
49 | 39 | # frozen is required to allow hashing of this class and use it |
@@ -241,191 +231,3 @@ class SecurityError(Exception): |
241 | 231 | """Custom exception for security-related errors.""" |
242 | 232 |
|
243 | 233 | pass |
244 | | - |
245 | | - |
246 | | -@deprecated(version='v2.1.0') |
247 | | -class ConfidenceEnsembleModel(ModelPT): |
248 | | - """Implementation of the confidence ensemble model. |
249 | | -
|
250 | | - See https://arxiv.org/abs/2306.15824 for details. |
251 | | -
|
252 | | - .. note:: |
253 | | - Currently this class only support `transcribe` method as it requires |
254 | | - full-utterance confidence scores to operate. |
255 | | - """ |
256 | | - |
257 | | - def __init__( |
258 | | - self, |
259 | | - cfg: DictConfig, |
260 | | - trainer: 'Trainer' = None, |
261 | | - ): |
262 | | - super().__init__(cfg=cfg, trainer=trainer) |
263 | | - |
264 | | - # either we load all models from ``load_models`` cfg parameter |
265 | | - # or all of them are specified in the config as modelX alongside the num_models key |
266 | | - # |
267 | | - # ideally, we'd like to directly store all models in a list, but that |
268 | | - # is not currently supported by the submodule logic |
269 | | - # so to access all the models, we do something like |
270 | | - # |
271 | | - # for model_idx in range(self.num_models): |
272 | | - # model = getattr(self, f"model{model_idx}") |
273 | | - |
274 | | - if 'num_models' in self.cfg: |
275 | | - self.num_models = self.cfg.num_models |
276 | | - for idx in range(self.num_models): |
277 | | - cfg_field = f"model{idx}" |
278 | | - model_cfg = self.cfg[cfg_field] |
279 | | - model_class = model_utils.import_class_by_path(model_cfg['target']) |
280 | | - self.register_nemo_submodule( |
281 | | - name=cfg_field, |
282 | | - config_field=cfg_field, |
283 | | - model=model_class(model_cfg, trainer=trainer), |
284 | | - ) |
285 | | - else: |
286 | | - self.num_models = len(cfg.load_models) |
287 | | - with open_dict(self.cfg): |
288 | | - self.cfg.num_models = self.num_models |
289 | | - for idx, model in enumerate(cfg.load_models): |
290 | | - cfg_field = f"model{idx}" |
291 | | - if model.endswith(".nemo"): |
292 | | - self.register_nemo_submodule( |
293 | | - name=cfg_field, |
294 | | - config_field=cfg_field, |
295 | | - model=ASRModel.restore_from(model, trainer=trainer, map_location="cpu"), |
296 | | - ) |
297 | | - else: |
298 | | - self.register_nemo_submodule( |
299 | | - cfg_field, |
300 | | - config_field=cfg_field, |
301 | | - model=ASRModel.from_pretrained(model, map_location="cpu"), |
302 | | - ) |
303 | | - |
304 | | - # registering model selection block - this is expected to be a joblib-saved |
305 | | - # pretrained sklearn pipeline containing standardization + logistic regression |
306 | | - # trained to predict "most-confident" model index from the confidence scores of all models |
307 | | - model_selection_block_path = self.register_artifact("model_selection_block", cfg.model_selection_block) |
308 | | - try: |
309 | | - self.model_selection_block = safe_joblib_load(model_selection_block_path) |
310 | | - except SecurityError as e: |
311 | | - raise RuntimeError(f"Security error loading model selection block: {str(e)}") |
312 | | - except Exception as e: |
313 | | - raise RuntimeError(f"Error loading model selection block: {str(e)}") |
314 | | - |
315 | | - self.confidence_cfg = ConfidenceConfig(**self.cfg.confidence) |
316 | | - |
317 | | - # making sure each model has correct temperature setting in the decoder strategy |
318 | | - for model_idx in range(self.num_models): |
319 | | - model = getattr(self, f"model{model_idx}") |
320 | | - # for now we assume users are direclty responsible for matching |
321 | | - # decoder type when building ensemble with inference type |
322 | | - # TODO: add automatic checks for errors |
323 | | - if isinstance(model, EncDecHybridRNNTCTCModel): |
324 | | - self.update_decoding_parameters(model.cfg.decoding) |
325 | | - model.change_decoding_strategy(model.cfg.decoding, decoder_type="rnnt") |
326 | | - self.update_decoding_parameters(model.cfg.aux_ctc.decoding) |
327 | | - model.change_decoding_strategy(model.cfg.aux_ctc.decoding, decoder_type="ctc") |
328 | | - else: |
329 | | - self.update_decoding_parameters(model.cfg.decoding) |
330 | | - model.change_decoding_strategy(model.cfg.decoding) |
331 | | - |
332 | | - def update_decoding_parameters(self, decoding_cfg: DictConfig): |
333 | | - """Updating temperature/preserve_alignment parameters of the config.""" |
334 | | - with open_dict(decoding_cfg): |
335 | | - decoding_cfg.temperature = self.cfg.temperature |
336 | | - decoding_cfg.preserve_alignments = True |
337 | | - |
338 | | - def setup_training_data(self, train_data_config: Union[DictConfig, Dict]): |
339 | | - """Pass-through to the ensemble models. |
340 | | -
|
341 | | - Note that training is not actually supported for this class! |
342 | | - """ |
343 | | - for model_idx in range(self.num_models): |
344 | | - getattr(self, f"model{model_idx}").setup_training_data(train_data_config) |
345 | | - |
346 | | - def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]): |
347 | | - """Pass-through to the ensemble models.""" |
348 | | - for model_idx in range(self.num_models): |
349 | | - getattr(self, f"model{model_idx}").setup_validation_data(val_data_config) |
350 | | - |
351 | | - def change_attention_model( |
352 | | - self, self_attention_model: str = None, att_context_size: List[int] = None, update_config: bool = True |
353 | | - ): |
354 | | - """Pass-through to the ensemble models.""" |
355 | | - for model_idx in range(self.num_models): |
356 | | - getattr(self, f"model{model_idx}").change_attention_model( |
357 | | - self_attention_model, att_context_size, update_config |
358 | | - ) |
359 | | - |
360 | | - def change_decoding_strategy(self, decoding_cfg: Optional[DictConfig] = None, decoder_type: str = None): |
361 | | - """Pass-through to the ensemble models. |
362 | | -
|
363 | | - The only change here is that we always require expected temperature |
364 | | - to be set as well as ``decoding_cfg.preserve_alignments = True`` |
365 | | - """ |
366 | | - self.update_decoding_parameters(decoding_cfg) |
367 | | - for model_idx in range(self.num_models): |
368 | | - model = getattr(self, f"model{model_idx}") |
369 | | - if isinstance(model, EncDecHybridRNNTCTCModel): |
370 | | - model.change_decoding_strategy(decoding_cfg, decoder_type=decoder_type) |
371 | | - else: |
372 | | - model.change_decoding_strategy(decoding_cfg) |
373 | | - |
374 | | - @torch.no_grad() |
375 | | - def transcribe( |
376 | | - self, |
377 | | - paths2audio_files: List[str], |
378 | | - batch_size: int = 4, |
379 | | - return_hypotheses: bool = False, |
380 | | - num_workers: int = 0, |
381 | | - channel_selector: Optional[ChannelSelectorType] = None, |
382 | | - augmentor: DictConfig = None, |
383 | | - verbose: bool = True, |
384 | | - **kwargs, # any other model specific parameters are passed directly |
385 | | - ) -> List[str]: |
386 | | - """Confidence-ensemble transcribe method. |
387 | | -
|
388 | | - Consists of the following steps: |
389 | | -
|
390 | | - 1. Run all models (TODO: in parallel) |
391 | | - 2. Compute confidence for each model |
392 | | - 3. Use logistic regression to pick the "most confident" model |
393 | | - 4. Return the output of that model |
394 | | - """ |
395 | | - confidences = [] |
396 | | - all_transcriptions = [] |
397 | | - # always requiring to return hypothesis |
398 | | - # TODO: make sure to return text only if was False originally |
399 | | - return_hypotheses = True |
400 | | - for model_idx in range(self.num_models): |
401 | | - model = getattr(self, f"model{model_idx}") |
402 | | - transcriptions = model.transcribe( |
403 | | - paths2audio_files=paths2audio_files, |
404 | | - batch_size=batch_size, |
405 | | - return_hypotheses=return_hypotheses, |
406 | | - num_workers=num_workers, |
407 | | - channel_selector=channel_selector, |
408 | | - augmentor=augmentor, |
409 | | - verbose=verbose, |
410 | | - **kwargs, |
411 | | - ) |
412 | | - if isinstance(transcriptions, tuple): # transducers return a tuple |
413 | | - transcriptions = transcriptions[0] |
414 | | - |
415 | | - model_confidences = [] |
416 | | - for transcription in transcriptions: |
417 | | - model_confidences.append(compute_confidence(transcription, self.confidence_cfg)) |
418 | | - confidences.append(model_confidences) |
419 | | - all_transcriptions.append(transcriptions) |
420 | | - |
421 | | - # transposing with zip(*list) |
422 | | - features = np.array(list(zip(*confidences))) |
423 | | - model_indices = self.model_selection_block.predict(features) |
424 | | - final_transcriptions = [] |
425 | | - for transcrption_idx in range(len(all_transcriptions[0])): |
426 | | - final_transcriptions.append(all_transcriptions[model_indices[transcrption_idx]][transcrption_idx]) |
427 | | - |
428 | | - return final_transcriptions |
429 | | - |
430 | | - def list_available_models(self): |
431 | | - return [] |
0 commit comments