Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion nemo/collections/asr/parts/submodules/jasper.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,13 @@ def forward_for_export(self, x, lengths):
# Create sample mask - 1 represents value, 0 represents pad
mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device)
mask = ~mask # 0 represents value, 1 represents pad
x = x.float() # For stable AMP, SE must be computed at fp32.

# Ensure SE runs in FP32: cast fc weights and activations to float32
if self.fc[0].weight.dtype != torch.float32:
self.fc.float()
if x.dtype != torch.float32:
x = x.float()

x = x.masked_fill(mask, 0.0) # mask padded values explicitly to 0
y = self._se_pool_step(x, mask) # [B, C, 1]
y = y.transpose(1, -1) # [B, 1, C]
Expand All @@ -490,6 +496,8 @@ def forward_for_export(self, x, lengths):

y = torch.sigmoid(y)
y = x * y
# Cast back to original dtype for downstream consistency
y = y.to(dtype)
return y, lengths

def _se_pool_step(self, x, mask):
Expand Down
2 changes: 1 addition & 1 deletion tutorials/llm/reasoning/Reasoning-SFT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"* A valid Hugging Face API token with access to the [Meta LLaMa 3.1-8B Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model (since this is a gated model).\n",
"\n",
"### Dataset\n",
"To follow along, you would need an appropriate reasoning dataset. Checkout the tutorial on [curating the Llama Nemotron Reasoning Dataset with NVIDIA NeMo Curator](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/llama-nemotron-data-curation).\n",
"To follow along, you would need an appropriate reasoning dataset. Checkout the tutorial on [curating the Llama Nemotron Reasoning Dataset with NVIDIA NeMo Curator](https://github.com/NVIDIA-NeMo/Curator/tree/dask/tutorials/llama-nemotron-data-curation).\n",
"You will need the output from that tutorial as the training set input to this playbook!\n",
"\n",
"### Hardware Requirements\n",
Expand Down
Loading