diff --git a/nemo/collections/asr/parts/submodules/jasper.py b/nemo/collections/asr/parts/submodules/jasper.py
index ec0def1b3ebb..f29332ea522b 100644
--- a/nemo/collections/asr/parts/submodules/jasper.py
+++ b/nemo/collections/asr/parts/submodules/jasper.py
@@ -477,7 +477,13 @@ def forward_for_export(self, x, lengths):
         # Create sample mask - 1 represents value, 0 represents pad
         mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device)
         mask = ~mask  # 0 represents value, 1 represents pad
-        x = x.float()  # For stable AMP, SE must be computed at fp32.
+
+        # Ensure SE runs in FP32: cast fc weights and activations to float32
+        if self.fc[0].weight.dtype != torch.float32:
+            self.fc.float()
+        if x.dtype != torch.float32:
+            x = x.float()
+
         x = x.masked_fill(mask, 0.0)  # mask padded values explicitly to 0
         y = self._se_pool_step(x, mask)  # [B, C, 1]
         y = y.transpose(1, -1)  # [B, 1, C]
@@ -490,6 +496,8 @@ def forward_for_export(self, x, lengths):
 
         y = torch.sigmoid(y)
         y = x * y
+        # Cast back to original dtype for downstream consistency
+        y = y.to(dtype)
         return y, lengths
 
     def _se_pool_step(self, x, mask):
diff --git a/tutorials/llm/reasoning/Reasoning-SFT.ipynb b/tutorials/llm/reasoning/Reasoning-SFT.ipynb
index 78e67bbf2541..f0a2fd8dc609 100644
--- a/tutorials/llm/reasoning/Reasoning-SFT.ipynb
+++ b/tutorials/llm/reasoning/Reasoning-SFT.ipynb
@@ -32,7 +32,7 @@
     "* A valid Hugging Face API token with access to the [Meta LLaMa 3.1-8B Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model (since this is a gated model).\n",
     "\n",
     "### Dataset\n",
-    "To follow along, you would need an appropriate reasoning dataset. Checkout the tutorial on [curating the Llama Nemotron Reasoning Dataset with NVIDIA NeMo Curator](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/llama-nemotron-data-curation).\n",
+    "To follow along, you would need an appropriate reasoning dataset. Checkout the tutorial on [curating the Llama Nemotron Reasoning Dataset with NVIDIA NeMo Curator](https://github.com/NVIDIA-NeMo/Curator/tree/dask/tutorials/llama-nemotron-data-curation).\n",
     "You will need the output from that tutorial as the training set input to this playbook!\n",
     "\n",
     "### Hardware Requirements\n",
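
Note on the jasper.py change: the first hunk pins the SE block to FP32 under AMP, and the second hunk hands the result back in the caller's original dtype. The standalone sketch below illustrates that cast-to-FP32-and-back pattern under simplified assumptions; ToyFP32Block, its single-layer fc, and the tensor shapes are hypothetical stand-ins rather than the actual NeMo SqueezeExcite module, and the patch itself assumes `dtype` was captured earlier in forward_for_export.

# Minimal sketch of the "compute sensitive ops in FP32, cast back" pattern.
# Assumption: ToyFP32Block is an illustrative stand-in, not NeMo code.
import torch
import torch.nn as nn


class ToyFP32Block(nn.Module):
    """Hypothetical block that forces its matmul/sigmoid to run in float32."""

    def __init__(self, channels: int):
        super().__init__()
        # Stand-in for the SE `fc` stack referenced in the patch above.
        self.fc = nn.Sequential(nn.Linear(channels, channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype  # remember the caller's dtype (e.g. float16 under AMP)

        # Mirror the first hunk: make sure both weights and activations are FP32.
        if self.fc[0].weight.dtype != torch.float32:
            self.fc.float()
        if x.dtype != torch.float32:
            x = x.float()

        y = torch.sigmoid(self.fc(x))  # numerically stable in FP32
        y = x * y

        # Mirror the second hunk: return the result in the original dtype.
        return y.to(dtype)


if __name__ == "__main__":
    block = ToyFP32Block(8).half()          # weights start in FP16
    out = block(torch.randn(2, 8).half())   # FP16 input
    print(out.dtype)                        # torch.float16 - original dtype restored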