Add Parakeet Hybrid RNNT CTC BPE Model with target language support #13360
ealbasiri wants to merge 19 commits into NVIDIA-NeMo:main
Conversation
nithinraok
left a comment
Overall:
- I think the implementation can be simplified.
- The language-ID implementation can be generalized for other use cases as well.
- Unit tests are missing for the current implementation.
Instead of `ast`, it's better to align this PR along the lines of `prompt_transducer`: keep it in `asr_hybrid_transducer_ctc` but rename the file to `speech_to_text_hybrid_rnnt_ctc_bpe_prompt.py`. That means this model code needs to be generalized for other use cases as well.
```python
def output_types(self) -> Optional[Dict[str, NeuralType]]:
    return {
        'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
        'a_sig_length': NeuralType(tuple('B'), LengthsType()),
```
Keep the full name: `audio_signal_length`.
nemo/collections/asr/data/audio_to_text_lhotse_target_language.py
```python
# Convert to Hydra 1.0 compatible DictConfig
cfg = model_utils.convert_model_config_to_dict_config(cfg)
cfg = model_utils.maybe_update_config_version(cfg)
self._GLOBAL_LANG_MAP = {
```
Move this to the end of the file and keep it in global scope (for better readability).
```python
"""
# Preparing the Tokenizer for the dataset
Use the `process_asr_text_tokenizer.py` script under <NEMO_ROOT>/scripts/tokenizers/ in order to prepare the tokenizer.
```
Add a sample manifest line here showing how it would look with the added language-ID support.
The example config file is missing. Also show what a sample manifest file would look like.
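For illustration only, a manifest line for this model might look as follows; the exact field names (in particular `target_lang`) are an assumption, not the PR's final schema:

```python
import json

# Hypothetical manifest entry for language-ID training.
# "target_lang" is an assumed field name, not necessarily the PR's schema.
entry = {
    "audio_filepath": "/data/audio/sample_0001.wav",
    "duration": 4.72,
    "text": "hallo wie geht es dir",
    "target_lang": "de",  # language the model should decode into
}

# NeMo manifests are one JSON object per line
line = json.dumps(entry)
print(line)
```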
Based on your if-cases and everything: do you really need a new model for this? Why not adjust the existing hybrid model code? Other than BLEU support, I don't see much use for this new model file. Correct me otherwise.
However, if you adjust the code for general-purpose use — for example, letting users add a predicted token to the output — it would be great!
As discussed, I think leaving it separate is better for simplicity for now, since it's still at a developmental stage; eventually we may want to integrate it with target-speaker ASR so we have one target-speaker ASR/AST model. The second reason is that the idea of concatenating the target language ID as a one-hot vector to the ASR embedding and feeding that directly into the decoder could conceptually apply to other decoder types, such as a Transformer. So it may be easier to iterate on if it's kept separate from the current hybrid Parakeet integration.
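The concatenation idea described above can be sketched at the shape level — NumPy here instead of PyTorch, and all dimension sizes are illustrative, not the PR's actual values:

```python
import numpy as np

B, T, D = 2, 50, 512          # batch, encoder frames, encoder feature dim (illustrative)
num_languages = 4
lang_idx = np.array([1, 3])   # target language index per utterance

# Encoder output: (B, T, D)
encoded = np.random.randn(B, T, D).astype(np.float32)

# One-hot language vector per utterance, tiled across all T frames -> (B, T, num_languages)
one_hot = np.eye(num_languages, dtype=np.float32)[lang_idx]
target_lang_id = np.repeat(one_hot[:, None, :], T, axis=1)

# Mirrors torch.cat([encoded, target_lang_id], dim=-1) from the PR's diff
concat_enc_states = np.concatenate([encoded, target_lang_id], axis=-1)
print(concat_enc_states.shape)  # (2, 50, 516)
```

The decoder then sees a feature vector of size `D + num_languages` at every frame, which is why the language signal is present at each time step rather than only at the start.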
examples/asr/transcribe_speech.py
```python
# Special case for EncDecHybridRNNTCTCBPEModelTgtLangID, where the input
# manifest is passed directly into the model's transcribe() function
if isinstance(asr_model, EncDecHybridRNNTCTCBPEModelTgtLangID):
    filepaths = cfg.dataset_manifest
    assert (
        cfg.dataset_manifest is not None
    ), "dataset_manifest must be provided for EncDecHybridRNNTCTCBPEModelTgtLangID"
    sorted_manifest_path = None
else:
    filepaths, sorted_manifest_path = prepare_audio_data(cfg)
```
Remove this and instead add support in `.transcribe()` with a default `target_language_id`. See the AED models' code.
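A minimal sketch of the interface the reviewer is asking for — a `transcribe()` that takes a `target_language_id` keyword with a default, roughly in the spirit of the AED models. Everything here (class, method body, language codes) is hypothetical stand-in code, not NeMo's actual API:

```python
from typing import List, Optional


class PromptedASRModelSketch:
    """Hypothetical stand-in for the hybrid model; not the real NeMo class."""

    SUPPORTED_LANGS = {"en", "de", "fr", "es"}  # assumed set, for illustration

    def transcribe(
        self,
        audio: List[str],
        target_language_id: Optional[str] = "en",  # default prompt language
    ) -> List[str]:
        if target_language_id not in self.SUPPORTED_LANGS:
            raise ValueError(f"Unsupported target language: {target_language_id}")
        # A real implementation would build the one-hot prompt and run decoding;
        # here we just return a placeholder string per input file.
        return [f"[{target_language_id}] transcript of {path}" for path in audio]


model = PromptedASRModelSketch()
print(model.transcribe(["a.wav"], target_language_id="de"))
```

With this shape, `transcribe_speech.py` needs no `isinstance` special case: the manifest/audio handling stays shared and only the extra keyword is forwarded.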
nemo/collections/asr/data/audio_to_text_lhotse_target_language.py
```python
num_sample_per_mel_frame (int, optional): The number of samples per mel frame. Default is 160.
num_mel_frame_per_asr_frame (int, optional): The number of mel frames per ASR frame. Default is 8.
```
Better to rename these to `shift` and `subsampling_factor`.
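For context on what these two parameters do: with a shift (hop) of 160 samples per mel frame and a subsampling factor of 8 mel frames per ASR frame, the ASR frame count for an utterance can be computed roughly as below. This is a sketch — NeMo's exact rounding may differ:

```python
import math


def num_asr_frames(num_samples: int, shift: int = 160, subsampling: int = 8) -> int:
    """Approximate ASR frame count: samples -> mel frames -> subsampled ASR frames."""
    mel_frames = math.ceil(num_samples / shift)
    return math.ceil(mel_frames / subsampling)


# 1 second of 16 kHz audio = 16000 samples
print(num_asr_frames(16000))  # 16000/160 = 100 mel frames -> ceil(100/8) = 13 ASR frames
```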
Doesn't matter; update accordingly for ASR.
```python
lang_targets = [
    torch.transpose(
        torch.as_tensor(
            self.lang_to_target_lang(
                c,
                self.num_languages,
                self.num_sample_per_mel_frame,
                self.num_mel_frame_per_asr_frame,
            ),
            dtype=torch.float32,
        ),
        0,
        1,
    )
    for c in cuts
]

# Create final tensors
```
Why do this here and not as part of the training loop? That way you could avoid preparing these matrices and basically avoid this dataset altogether.
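The suggestion above amounts to expanding the per-utterance language index into the per-frame one-hot matrix inside the training step, instead of materializing it in the dataset. A NumPy sketch of that on-the-fly expansion (shapes illustrative, function name assumed):

```python
import numpy as np


def expand_lang_targets(lang_idx: np.ndarray, num_frames: int, num_languages: int) -> np.ndarray:
    """Build (B, T, num_languages) one-hot language targets from per-utterance indices."""
    one_hot = np.eye(num_languages, dtype=np.float32)[lang_idx]      # (B, num_languages)
    return np.repeat(one_hot[:, None, :], num_frames, axis=1)        # (B, T, num_languages)


# Only (lang_idx, num_frames) needs to reach the training step;
# the full matrix is never stored or collated by the dataloader.
targets = expand_lang_targets(np.array([0, 2]), num_frames=5, num_languages=3)
print(targets.shape)  # (2, 5, 3)
```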
nemo/collections/asr/data/audio_to_text_lhotse_target_language.py
This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

Revisions have been integrated. Please review. Thank you!
nithinraok
left a comment
- As mentioned before, unit tests are missing.
- I am pretty sure the current script doesn't run without errors with the provided config; correct that and add a CI/CD run for this script.
```python
@hydra_runner(
    config_path="../conf/conformer/hybrid_transducer_ctc/", config_name="conformer_hybrid_transducer_ctc_bpe"
```
Shouldn't the config be examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe_prompt.yaml?
I see it's not updated yet. Marking as unresolved.
...conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe_prompt.yaml
...conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe_prompt.yaml
```python
num_sample_per_mel_frame (int, optional): The number of samples per mel frame. Default is 160.
num_mel_frame_per_asr_frame (int, optional): The number of mel frames per ASR frame. Default is 8.
```
Doesn't matter; update accordingly for ASR.
```python
@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
    if hasattr(self.preprocessor, '_sample_rate'):
        input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
```
That's your branch? I don't see your point.
```python
target_lang_id = target_lang_id[:, : encoded.shape[1], :]

# Concatenate encoded states with language ID
concat_enc_states = torch.cat([encoded, target_lang_id], dim=-1)
```
> but we thought the language signal might be less evident at some time steps resulting in wrong language predictions.

Why?
```python
    return {'loss': loss_value}

def predict_step(self, batch, batch_idx, dataloader_idx=0):
    # TODO: add support for CTC decoding
```
Add support for CTC decoding; it was added for other hybrid models as well. See those classes for reference.
```python
self.bleu.update(
    predictions=encoded, predictions_lengths=encoded_len, targets=transcript, targets_lengths=transcript_len
)
bleu_metrics = self.bleu.compute(return_all_metrics=True, prefix="val_")
tensorboard_logs.update(
    {
        'val_bleu_num': bleu_metrics['val_bleu_num'],
        'val_bleu_denom': bleu_metrics['val_bleu_denom'],
        'val_bleu_pred_len': bleu_metrics['val_bleu_pred_len'],
        'val_bleu_target_len': bleu_metrics['val_bleu_target_len'],
        'val_bleu': bleu_metrics['val_bleu'],
    }
)
self.bleu.reset()
```
Why calculate BLEU scores when you perform an ASR task?
Because the model performs both ASR and AST without a specific prompt distinguishing between the two tasks.
Does that make sense? What is the alternative?
```python
class EncDecHybridRNNTCTCBPEModelTgtLangID(EncDecHybridRNNTCTCModel, ASRBPEMixin, ASRTranscriptionMixin):
    """Base class for encoder decoder RNNT-based models with auxiliary CTC decoder/loss and subword tokenization."""

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
```
Where is the `.transcribe()` method?
Also rebase with main for the latest changes.
Force-pushed: df8c0fb to 967eb94, then 967eb94 to 2afbd35.
...conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe_prompt.yaml
```yaml
trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: 1000
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  precision: bf16 # 16, 32, or bf16
  log_every_n_steps: 10 # interval of logging
  enable_progress_bar: True
  num_sanity_val_steps: 0 # number of sanity-check validation steps before training; 0 disables it
  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # provided by exp_manager
  logger: false # provided by exp_manager
  benchmark: false # needs to be false for models with variable-length speech input as it slows down training
```
Is this correct for a Lhotse-based dataloader? Double-check.
nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py
```python
from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models_prompt import EncDecHybridRNNTCTCBPEModelWithPrompt

# Global map of language codes to indices that existed in the old model
GLOBAL_LANG_MAP = {
```
Is this required now, since you added it to the cfg?
Apologies — this whole script should be deleted from the branch. I'll update it.
```python
# For backward compatibility
class EncDecHybridRNNTCTCBPEModelTgtLangID(EncDecHybridRNNTCTCBPEModelWithPrompt):
```
Remove this; we don't need it anymore, no?
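If the alias were kept for backward compatibility rather than removed, a common pattern is a thin subclass that warns on construction. This is a sketch with plain stand-in classes (the real NeMo classes need a full config to instantiate):

```python
import warnings


class EncDecHybridRNNTCTCBPEModelWithPromptStub:
    """Plain stand-in for the new model class; not the real NeMo class."""
    def __init__(self, *args, **kwargs):
        pass


class EncDecHybridRNNTCTCBPEModelTgtLangIDStub(EncDecHybridRNNTCTCBPEModelWithPromptStub):
    """Deprecated alias kept only so old checkpoints/configs keep loading."""
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "EncDecHybridRNNTCTCBPEModelTgtLangID is deprecated; "
            "use EncDecHybridRNNTCTCBPEModelWithPrompt instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    EncDecHybridRNNTCTCBPEModelTgtLangIDStub()
print(len(caught))  # 1
```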
```python
    return hypothesis


class BatchedFrameASRTDT(BatchedFrameASRRNNT):
```
nithinraok
left a comment
- Add `.transcribe()` tests
- Add a timestamp-check test
- Add a CI/CD run test
...conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe_prompt.yaml
```yaml
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: -1
max_steps: 10000
val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 1.0
precision: bf16 # 16, 32, or bf16
log_every_n_steps: 10 # interval of logging
enable_progress_bar: True
num_sanity_val_steps: 0 # number of sanity-check validation steps before training; 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # provided by exp_manager
logger: false # provided by exp_manager
benchmark: false # needs to be false for models with variable-length speech input as it slows down training
use_distributed_sampler: false
```
Lhotse parameters are not updated here yet; this will error out or get stuck.
```python
    not NUMBA_RNNT_LOSS_AVAILABLE,
    reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.with_downloads()
```
Why run `with_downloads()`? Which tests depend on the test folder?
Force-pushed: f9d5379 to a59043a, then 2047fb7 to c6b3215.
Signed-off-by: Enas Albasiri <[email protected]>
Force-pushed: caee120 to b9bd2b4.
nithinraok
left a comment
Overall this looks fine now.
A few things:
- Update integration test as mentioned
- Add documentation
```
@@ -0,0 +1,372 @@
# The model would have two decoders: RNNT (Transducer) and CTC
```
Add a manifest example line here showing how it should look for training this model.
```shell
python -c "from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModelWithPrompt" && \
NEMO_NUMBA_MINVER=0.53 CUDA_VISIBLE_DEVICES=0 \
coverage run -a --data-file=/workspace/.coverage --source=/workspace/ \
-m pytest tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe_prompt.py \
-v
```
This is not for running unit tests; this is for integration. See the example and run for a few steps only: https://github.com/ealbasiri/NeMo/blob/hybrid-parakeet-tgt-lang-apr30/tests/functional_tests/ASR_dev_run_Speech_to_Text.sh
This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

This PR was closed because it has been inactive for 7 days since being marked as stale.
Important
The Update branch button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do?
Collection: ASR
Changelog
This PR changes the following:
Usage
This model can be used for ASR/AST tasks as well as word-timestamp generation for downstream applications.
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
@nithinraok
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.
Additional Information