Skip to content

Commit 863fefe

Browse files
hoangtch-namitech and Amir Hussein
authored and committed
Fix decoding with ngpu-lm when training (NVIDIA-NeMo#13994) (NVIDIA-NeMo#13995)
* Fix decoding with ngpu-lm when training (NVIDIA-NeMo#13994) Signed-off-by: Hoang Tran <hoang.tch@namitech.io> * code_format Signed-off-by: Hoang Tran <hoang.tch@namitech.io> --------- Signed-off-by: Hoang Tran <hoang.tch@namitech.io> Signed-off-by: Amir Hussein <amhussein@nvidia.com>
1 parent 2bd5a4b commit 863fefe

File tree

4 files changed

+40
-4
lines changed

4 files changed

+40
-4
lines changed

nemo/collections/asr/parts/submodules/ctc_beam_decoding.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from nemo.collections.asr.parts.submodules.ngram_lm import DEFAULT_TOKEN_OFFSET
2727
from nemo.collections.asr.parts.submodules.wfst_decoder import RivaDecoderConfig, WfstNbestHypothesis
2828
from nemo.collections.asr.parts.utils import rnnt_utils
29+
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
2930
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
3031
from nemo.core.classes import Typing, typecheck
3132
from nemo.core.neural_types import HypothesisType, LengthsType, LogprobsType, NeuralType
@@ -878,7 +879,7 @@ def _k2_decoding(self, x: torch.Tensor, out_len: torch.Tensor) -> List['WfstNbes
878879
return self.k2_decoder.decode(x.to(device=self.device), out_len.to(device=self.device))
879880

880881

881-
class BeamBatchedCTCInfer(AbstractBeamCTCInfer):
882+
class BeamBatchedCTCInfer(AbstractBeamCTCInfer, WithOptionalCudaGraphs):
882883
"""
883884
A batched beam CTC decoder.
884885
@@ -945,6 +946,16 @@ def __init__(
945946
allow_cuda_graphs=allow_cuda_graphs,
946947
)
947948

949+
def disable_cuda_graphs(self):
950+
"""Disable CUDA graphs (e.g., for decoding in training)"""
951+
if isinstance(self.search_algorithm, WithOptionalCudaGraphs):
952+
self.search_algorithm.disable_cuda_graphs()
953+
954+
def maybe_enable_cuda_graphs(self):
955+
"""Enable CUDA graphs (if allowed)"""
956+
if isinstance(self.search_algorithm, WithOptionalCudaGraphs):
957+
self.search_algorithm.maybe_enable_cuda_graphs()
958+
948959
@typecheck()
949960
def forward(
950961
self,

nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
2424
from nemo.collections.asr.parts.utils import rnnt_utils
2525
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
26+
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
2627
from nemo.core.classes import Typing, typecheck
2728
from nemo.core.neural_types import HypothesisType, LengthsType, LogprobsType, NeuralType
2829
from nemo.core.utils.cuda_python_utils import (
@@ -389,7 +390,7 @@ def __call__(self, *args, **kwargs):
389390
return self.forward(*args, **kwargs)
390391

391392

392-
class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin):
393+
class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin, WithOptionalCudaGraphs):
393394
"""A vectorized greedy CTC decoder.
394395
395396
This is basically always faster than GreedyCTCInfer, and supports
@@ -500,6 +501,8 @@ def __init__(
500501
self.ngram_lm_alpha = ngram_lm_alpha
501502
self.state: CTCDecoderCudaGraphsState | None = None
502503
else:
504+
self.allow_cuda_graphs = False
505+
self.cuda_graphs_mode = None
503506
self.ngram_lm_batch = None
504507

505508
@typecheck()

nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
is_prefix,
4949
select_k_expansions,
5050
)
51+
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
5152
from nemo.core.classes import Typing, typecheck
5253
from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
5354
from nemo.utils import logging
@@ -1526,7 +1527,7 @@ def set_decoding_type(self, decoding_type: str):
15261527
self.token_offset = DEFAULT_TOKEN_OFFSET
15271528

15281529

1529-
class BeamBatchedRNNTInfer(Typing, ConfidenceMethodMixin):
1530+
class BeamBatchedRNNTInfer(Typing, ConfidenceMethodMixin, WithOptionalCudaGraphs):
15301531
@property
15311532
def input_types(self):
15321533
"""Returns definitions of module input ports."""
@@ -1636,6 +1637,16 @@ def __init__(
16361637
allow_cuda_graphs=allow_cuda_graphs,
16371638
)
16381639

1640+
def disable_cuda_graphs(self):
1641+
"""Disable CUDA graphs (e.g., for decoding in training)"""
1642+
if isinstance(self._decoding_computer, WithOptionalCudaGraphs):
1643+
self._decoding_computer.disable_cuda_graphs()
1644+
1645+
def maybe_enable_cuda_graphs(self):
1646+
"""Enable CUDA graphs (if allowed)"""
1647+
if isinstance(self._decoding_computer, WithOptionalCudaGraphs):
1648+
self._decoding_computer.maybe_enable_cuda_graphs()
1649+
16391650
@property
16401651
def output_types(self):
16411652
"""Returns definitions of module output ports."""

nemo/collections/asr/parts/submodules/tdt_beam_decoding.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
4040
from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BlankLMScoreMode, PruningMode
4141
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses, is_prefix
42+
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
4243
from nemo.core.classes import Typing, typecheck
4344
from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
4445
from nemo.utils import logging
@@ -829,7 +830,7 @@ def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
829830
return sorted(hyps, key=lambda x: x.score, reverse=True)
830831

831832

832-
class BeamBatchedTDTInfer(Typing, ConfidenceMethodMixin):
833+
class BeamBatchedTDTInfer(Typing, ConfidenceMethodMixin, WithOptionalCudaGraphs):
833834
@property
834835
def input_types(self):
835836
"""Returns definitions of module input ports."""
@@ -910,6 +911,16 @@ def __init__(
910911
else:
911912
raise Exception(f"Decoding strategy {search_type} nor implemented.")
912913

914+
def disable_cuda_graphs(self):
915+
"""Disable CUDA graphs (e.g., for decoding in training)"""
916+
if isinstance(self._decoding_computer, WithOptionalCudaGraphs):
917+
self._decoding_computer.disable_cuda_graphs()
918+
919+
def maybe_enable_cuda_graphs(self):
920+
"""Enable CUDA graphs (if allowed)"""
921+
if isinstance(self._decoding_computer, WithOptionalCudaGraphs):
922+
self._decoding_computer.maybe_enable_cuda_graphs()
923+
913924
@property
914925
def output_types(self):
915926
"""Returns definitions of module output ports."""

0 commit comments

Comments (0)