Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tests/torchtune/modules/test_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
# LICENSE file in the root directory of this source tree.

import pytest

import torch

from tests.test_utils import assert_expected
from torch.nn.functional import normalize

from torchtune.modules.rms_norm import RMSNorm
from torchtune.training.seed import set_seed

Expand Down Expand Up @@ -66,6 +64,7 @@ def test_forward_fp16(self, rms_norm, input_random_fp16, dim) -> None:

# convert input to float since rms_norm computes in fp32
expected_fp16 = normalize(input_random_fp16.float(), p=2, dim=-1) * (dim**0.5)
expected_fp16 = expected_fp16.to(torch.float16)

assert_expected(output_fp16, expected_fp16, atol=1e-7, rtol=1e-3)
assert output_fp16.dtype == torch.float32
assert output_fp16.dtype == torch.float16

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why this wasn't failing before?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the model wasn't cast to fp16, which means the scale parameter was still fp32. And since `x_normed * self.scale` occurred after the cast back to fp16, multiplying the fp16 tensor by the fp32 scale promoted the result, so the output ended up in fp32.

21 changes: 10 additions & 11 deletions torchtune/modules/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@
# LICENSE file in the root directory of this source tree.

import torch

import torch.nn.functional as F
from torch import nn


class RMSNorm(nn.Module):
"""
Implements Root Mean Square Normalization introduced in
https://arxiv.org/abs/1910.07467.
Root Mean Square Normalization in fp32.

Reference implementation (used for correctness verification)
can be found here:
https://github.com/facebookresearch/llama/blob/main/llama/model.py
See: https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html

Args:
dim (int): embedding size
Expand All @@ -25,6 +22,7 @@ class RMSNorm(nn.Module):

def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
self.normalized_shape = (dim,)
self.eps = eps
self.scale = nn.Parameter(torch.ones(dim))

Expand All @@ -37,8 +35,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor: The normalized and scaled tensor having the same shape as ``x``.
"""
# computation is in fp32
x_fp32 = x.float()
x_normed = (
x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
).type_as(x)
return x_normed * self.scale
return F.rms_norm(
x.float(),
normalized_shape=self.normalized_shape,
weight=self.scale,
eps=self.eps,
Comment on lines +41 to +42

@felipemello1 felipemello1 Nov 22, 2024

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noob question: when we load the model in bf16, will self.eps and self.scale also become bf16 or do they stay float32?

If they are cast to bf16, it might be worth digging into this a bit.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eps is just a Python float, not a PyTorch tensor, so it simply takes on the dtype of whichever tensor it is applied to.

scale will become bf16, but F.rms_norm will cast it to fp32 when it gets multiplied by the fp32 output.

).to(x.dtype)