Commit 70d0d10

cast SE weights and activations to fp32
1 parent 1b43a27

File tree

1 file changed: +7 -3 lines

  • nemo/collections/asr/parts/submodules

nemo/collections/asr/parts/submodules/jasper.py

Lines changed: 7 additions & 3 deletions
@@ -478,9 +478,11 @@ def forward_for_export(self, x, lengths):
         mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device)
         mask = ~mask  # 0 represents value, 1 represents pad

-        # Commented out the below cast in v2.5.0 to fix dtype errors when running examples/asr/transcribe_speech.py on ASR models that use jasper.py encoder.
-        # Observed minimal changes in model outputs from this change.
-        # x = x.float()
+        # Ensure SE runs in FP32: cast fc weights and activations to float32
+        if self.fc[0].weight.dtype != torch.float32:
+            self.fc.float()
+        if x.dtype != torch.float32:
+            x = x.float()

         x = x.masked_fill(mask, 0.0)  # mask padded values explicitly to 0
         y = self._se_pool_step(x, mask)  # [B, C, 1]
@@ -494,6 +496,8 @@ def forward_for_export(self, x, lengths):

         y = torch.sigmoid(y)
         y = x * y
+        # Cast back to original dtype for downstream consistency
+        y = y.to(dtype)
         return y, lengths

     def _se_pool_step(self, x, mask):
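For readers skimming the change, here is a minimal, self-contained sketch of the pattern this commit applies: run the SE (squeeze-and-excitation) gating in float32 even when the surrounding model runs in half precision, then cast the result back to the caller's dtype. The SEBlock below is a hypothetical stand-in written for illustration, not NeMo's SqueezeExcite; only the dtype handling mirrors the diff (the real method restores dtype via a variable captured earlier in forward_for_export).

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Hypothetical SE block; only the fp32 casting pattern mirrors the commit."""

    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
        )

    def forward(self, x):  # x: [B, C, T]
        dtype = x.dtype  # remember the caller's dtype (e.g. torch.float16)
        # Cast weights and activations to fp32, as in the diff above.
        if self.fc[0].weight.dtype != torch.float32:
            self.fc.float()
        if x.dtype != torch.float32:
            x = x.float()
        y = x.mean(dim=-1, keepdim=True)                  # squeeze: [B, C, 1]
        y = self.fc(y.transpose(1, -1)).transpose(1, -1)  # excite: [B, C, 1]
        y = x * torch.sigmoid(y)                          # gate the activations
        # Cast back to the original dtype for downstream consistency.
        return y.to(dtype)

se = SEBlock(channels=16).half()           # fp16 parameters...
out = se(torch.randn(2, 16, 50).half())    # ...and fp16 input
print(out.dtype)                           # torch.float16: fp32 math inside, fp16 at the interface

Note the asymmetry: self.fc.float() converts the submodule's parameters in place, so later calls skip that cast, while the activations are cast to fp32 per call and restored on return.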
