Switch to PyTorch's built-in RMSNorm#2054
  assert_expected(output_fp16, expected_fp16, atol=1e-7, rtol=1e-3)
- assert output_fp16.dtype == torch.float32
+ assert output_fp16.dtype == torch.float16
do you know why this wasn't failing before?
yeah, the model wasn't cast to fp16, which means the scale parameter was still fp32. And since x_normed * self.scale occurred after the cast back to fp16, type promotion meant the output ended up in fp32.
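A minimal sketch of the promotion described above, using stand-in tensors rather than the actual torchtune module (the names `x_normed` and `scale` mirror the comment; the shapes are arbitrary assumptions):

```python
import torch

# Activations after the cast back to fp16 (stand-in values).
x_normed = torch.ones(2, 4).to(torch.float16)

# Scale parameter left in fp32 because the model was never cast.
scale_fp32 = torch.ones(4, dtype=torch.float32)

# fp16 tensor * fp32 tensor promotes to the wider dtype.
out_mixed = x_normed * scale_fp32
print(out_mixed.dtype)

# Once the parameter is cast as well, the product stays in fp16.
scale_fp16 = scale_fp32.to(torch.float16)
out_fp16 = x_normed * scale_fp16
print(out_fp16.dtype)
```

This is why the old assertion on `torch.float32` passed: the test was exercising the mixed-dtype path, not a fully fp16 model.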
weight=self.scale,
eps=self.eps,
noob question: when we load the model in bf16, will self.eps and self.scale also become bf16, or do they stay float32?
If they are cast to bf16, it might be worth digging into a bit.
eps is just a Python float, not a PyTorch tensor, so it'll basically just adapt to whichever dtype it is being applied to.
scale will become bf16, but F.rms_norm will cast it to fp32 when it gets multiplied by the fp32 output.
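A small sketch of the two points above. `TinyNorm` is a hypothetical stand-in, not the torchtune class, and the snippet only illustrates PyTorch's general casting and type-promotion rules; it does not call `F.rms_norm` itself:

```python
import torch
from torch import nn


class TinyNorm(nn.Module):
    """Hypothetical module holding a float eps and a learnable scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # plain Python float, untouched by .to()
        self.scale = nn.Parameter(torch.ones(dim))


norm = TinyNorm(8).to(torch.bfloat16)

eps_type = type(norm.eps)        # still a Python float
scale_dtype = norm.scale.dtype   # parameters follow the module cast

# A Python scalar does not change the dtype of the tensor it is applied to.
bf16_plus_eps = torch.ones(8, dtype=torch.bfloat16) + norm.eps

# An fp32 tensor combined with the bf16 scale promotes to fp32.
promoted = torch.result_type(torch.ones(8, dtype=torch.float32), norm.scale)

print(eps_type, scale_dtype, bf16_plus_eps.dtype, promoted)
```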
nice finding! Approving to unblock, just left a few comments wanting to make sure that the test and the bf16 are ok.
also, fyi, I think using the torch.cuda event API would be the way to go for timing here:
import torch


class Time:
    """Context manager that times a GPU region using CUDA events."""

    def __init__(self, name):
        self.name = name
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        self.start_event.record()
        return self

    def __exit__(self, *args, **kwargs):
        self.end_event.record()
        torch.cuda.synchronize()  # Wait for the events to be recorded
        # elapsed_time() returns the time between the two events in milliseconds
        elapsed_time = self.start_event.elapsed_time(self.end_event)
        print(f"TIME_{self.name}: {elapsed_time:.3f} ms")
using shape = [8, 512, 4096], i got
TIME_inference_uncompiled1: 129.710 ms
TIME_inference_uncompiled2: 11.391 ms
inference uncompiled mse 0.0
TIME_inference_initial_compile1: 2834.852 ms
TIME_inference_initial_compile2: 403.543 ms
TIME_inference_compiled1: 1.052 ms
TIME_inference_compiled2: 0.869 ms
inference compiled mse 2.7247249363426818e-06
TIME_train_uncompiled1: 40.218 ms
TIME_train_uncompiled2: 10.865 ms
train uncompiled mse 0.0
TIME_train_initial_compile1: 928.825 ms
TIME_train_initial_compile2: 708.388 ms
TIME_train_compiled1: 6.275 ms
TIME_train_compiled2: 5.314 ms
train compiled mse 2.7247249363426818e-06
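For reference, the ratios implied by this log can be worked out as below. This is only arithmetic on the single-run timings printed above, reading suffix 1 as the original RMSNorm and 2 as the new one; the first uncompiled call may also include one-off warm-up cost, so the ratios are rough indications rather than a benchmark.

```python
# Timings in milliseconds, copied from the log above: (original, new).
timings_ms = {
    "inference_uncompiled": (129.710, 11.391),
    "inference_initial_compile": (2834.852, 403.543),
    "inference_compiled": (1.052, 0.869),
    "train_uncompiled": (40.218, 10.865),
    "train_initial_compile": (928.825, 708.388),
    "train_compiled": (6.275, 5.314),
}

# Speedup = original time / new time.
speedups = {name: old / new for name, (old, new) in timings_ms.items()}

for name, ratio in speedups.items():
    print(f"{name}: {ratio:.2f}x")
```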
This reverts commit 1450d61.
Context
What is the purpose of this PR? Is it to
Results
Inference
Uncompiled: 31x faster
Compiled: compilation and first batch = 5x faster, post-first batches = 1.5x faster
Train (forward + backward):
Uncompiled: 22x faster
Compiled: compilation and first batch = 1.5x faster, post-first batches = 1.2x faster
Parity:
Uncompiled: perfect
Compiled: 2.76e-6 MSE
Test plan
Test code (output is below):
1 = original RMSNorm, 2 = new RMSNorm
Output on 1xH100: