Enable loss parallel, Ungate FP8 #2782
Conversation
Signed-off-by: Nathan Azrak <[email protected]>
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2782
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit e275c88 with merge base 5b2e881.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ebsmothers left a comment
Thanks for the PR @nathan-az, these memory savings are really impressive! Lmk if any of the comments are unclear
```diff
-# TODO: expose this once tested
-def _fp8_llama_tp_plan() -> dict[str, ParallelStyle]:
+def fp8_llama_tp_plan(
```
I might be missing something here, but are we actually enabling this now? Like if I set enable_fp8_training=True and tensor_parallel_plan=base_llama_tp_plan, where do we actually hook this up?
I've been manually changing tensor_parallel_plan to torchtune.models.llama3.fp8_llama_tp_plan. Good call though - that's bad UX and we can just select the correct plan based on if fp8 is enabled. Will fix.
This commit addresses this. Required a bit of a refactor. Have also applied to LLaMA-4 so it raises an error more elegantly if someone tries FP8 + LLaMA-4 training.
Resolve this comment if you're happy with the solution :)
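As a sketch of the plan-selection idea being discussed, something like the following could pick the plan from the fp8 flag (all names and plan contents here are illustrative placeholders, not torchtune's actual API):

```python
# Hypothetical sketch: select the TP plan from the fp8 flag so users don't
# have to swap tensor_parallel_plan manually. Plan entries are placeholders.
base_llama_tp_plan = {"output": "ColwiseParallel"}
fp8_llama_tp_plan = {"output": "Float8ColwiseParallel"}


def resolve_tp_plan(enable_fp8_training: bool) -> dict:
    """Return the fp8-aware plan when fp8 training is enabled."""
    return fp8_llama_tp_plan if enable_fp8_training else base_llama_tp_plan
```

This keeps the UX to a single `enable_fp8_training` flag, with the plan chosen internally.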
```python
    reduction="sum",
    ignore_index=self.ignore_index,
)
# the all-reduce later complains if a DTensor is returned
```
That's exactly the allreduce I'm referring to. And yes, the loss parallel CE loss does return a Replicate DTensor. Without the full_tensor, when we reach the all-reduce in the recipe, we get:
```
[rank0]: AssertionError: found no DeviceMesh from dtensor args for c10d.allreduce_.default!
```
The loss that comes out is a DTensor, but it's a Replicate, so I saw no difference between using full_tensor and using to_local in the TP case, versus not using loss parallel at all. Differences from DP (which I use as a baseline) are very small. This is without packing, and with effective batch size adjusted.
I did notice that DP+TP+CP appears to have more drastic differences between grad norms and losses versus DP. I expect this is probably due to the grad norm scaling, not the loss parallelism, although CP with/without loss parallel does exhibit slight differences.
note: I tried to keep the run names short. dp refers to dp8, cp or tp use both dim 2, with shard dim and batch size adjusted accordingly.
Suggestion: Unless you have a clearer idea, I'd lean towards leaving this full_tensor in, then follow up to investigate the CP difference.
I won't resolve this comment yet, but please do so if you're happy to leave this as-is for now.
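A minimal sketch of the idea, duck-typed here so it runs without a distributed setup (`materialize_loss` and the stub class are hypothetical, not torchtune's actual helper):

```python
# Hypothetical helper: materialize a replicated DTensor loss before the
# cross-rank all-reduce; plain scalars/tensors pass through unchanged.
def materialize_loss(loss):
    # DTensor exposes .full_tensor(); a plain tensor does not.
    return loss.full_tensor() if hasattr(loss, "full_tensor") else loss


class FakeReplicateDTensor:
    """Stand-in for a Replicate-placement DTensor in this sketch."""

    def __init__(self, value):
        self._value = value

    def full_tensor(self):
        # For a Replicate placement this is just the local value.
        return self._value
```

Since the placement is Replicate, `full_tensor()` involves no extra communication, which matches the observation that it behaves the same as `to_local()` here.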
```python
layerwise_colwise_parallel_cls: type[ParallelStyle] = ColwiseParallel,
layerwise_rowwise_parallel_cls: type[ParallelStyle] = RowwiseParallel,
layerwise_prepare_module_input_cls: type[ParallelStyle] = PrepareModuleInput,
loss_parallel: bool = False,
```
I've been really trying to avoid this with the new loss functions and instead been having the loss functions modify the models directly. Otherwise we have to maintain a special TP plan for every kind of loss (see ligerloss). For the liger loss case you can just call full_tensor but I believe you can modify a model's output_layout after the fact too.
Mm yes this is a reasonable point. The current pattern was taken from torchtitan but we will soon support more loss functions than them.
We have two slightly tricky requirements:
- Our training step needs to be aware of whether to enter the loss parallel context manager. For losses that use the standard PyTorch cross entropy loss, this is the case. For ones like liger, probably not, even if they support loss parallelism.
- We should probably modify the TP plan based on loss parallelism rather than using `redistribute` in the forward step, since this would introduce an additional collective (PyTorch might be smart enough to optimise out consecutive redistributes, but I'm not sure).
I think this is the antithesis of what you're suggesting, but I wonder about relying more on the loss function here, using:
- a method `patch_tp_plan`, which makes any required modifications to the base TP plan prior to parallelisation in the model setup (as simple as a `dict.update`), defaulting to a no-op
- a property `use_loss_parallel_context_manager`, defaulted to false, which indicates whether to enter the context manager in the loss parallel case, used in the training loop
I'll address simple comments first, then implement this so it's easy to review, and simple to revert if you don't like it.
My rationale is that it turns loss parallelism into a first class citizen, and centralises any loss parallel-specific functionality to the loss class itself. Modifying the model I think will end up less straightforward depending on how loss is parallelised.
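A rough sketch of those two hooks (the classes and plan entries below are illustrative, not torchtune's actual SFTLoss):

```python
class SFTLossBase:
    # Defaults: no loss-parallel context manager and no TP-plan changes.
    use_loss_parallel_context_manager: bool = False

    def patch_tp_plan(self, tp_plan: dict) -> dict:
        # No-op by default; subclasses override to adjust the base plan.
        return tp_plan


class LossParallelCE(SFTLossBase):
    # Standard PyTorch CE under loss parallel needs the context manager.
    use_loss_parallel_context_manager = True

    def patch_tp_plan(self, tp_plan: dict) -> dict:
        # e.g. keep the output projection's result sharded so the parallel
        # CE can consume local shards (the value here is a placeholder).
        patched = dict(tp_plan)
        patched["output"] = "ColwiseParallel(use_local_output=False)"
        return patched
```

The recipe would call `loss.patch_tp_plan(base_plan)` before parallelising the model, and check the property when deciding whether to enter the context manager.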
Converting to draft, will update according to #2824 once it merges, then request review again.
I've made some changes to enable loss compilation in the non-loss parallel case. This did require taking some logic out of the function that is compiled, but it is only masking and slicing which I believe should be pretty trivial anyway, from both the memory and compute perspectives. I've confirmed the gains in memory utilisation against the standard PyTorch cross-entropy loss, as well as aligning loss curves.
Above is testing with LLaMA 3.1 8B, seq len 2**13, adjusting batch size for DP size, with 16 chunks (except "baseline", which does not use chunking). Very happy to see notable improvements in memory usage from the chunking, and more with loss parallelism, as well as minor tokens-per-second improvements from compile (in the DP case) and from loss parallelism in the TP case. We can experiment in the future with getting compile working in the TP case, but at least for now this is improved from the current state where it is fully disabled (in the DP case, I see 20% reduced peak memory usage currently).

p.s. this PR now takes heavily from @felipemello1's work in #2824. Would be good to get Felipe's thoughts too.
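As a pure-Python sketch of that split (the eager masking/slicing done outside the function that would be compiled; all names here are hypothetical):

```python
# Illustrative only: the cheap eager part of a chunked loss. Ignored
# positions are dropped and the remainder is split into chunks; the heavy
# per-chunk cross entropy would then run inside the compiled function.
def mask_and_chunk(logits, labels, ignore_index=-100, num_chunks=4):
    """Drop ignored positions, then split into at most num_chunks chunks."""
    kept = [(lg, lb) for lg, lb in zip(logits, labels) if lb != ignore_index]
    size = max(1, -(-len(kept) // num_chunks))  # ceil division
    return [kept[i : i + size] for i in range(0, len(kept), size)]
```

The point is that only trivial indexing lives outside the compiled region, so moving it out shouldn't cost anything in memory or compute.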
hey Nathan, thanks for the PR. I will leave the parallelism implementation to @ebsmothers and @pbontrager, since they were already on top of it. Regarding the loss, i like that you were able to port the other PR! Thanks for making it more modular and enabling compile. My two cents:
- In which scenarios would we want to have `mask_ignored_tokens=False`? If there isn't a strong one, maybe we should remove the flag.
- I personally don't like how SFTLoss has many args related to parallelism:
  a) `supports_loss_parallel`
  b) `loss_parallel_enabled`
  c) `loss_parallel_requires_ctx_manager`
  d) `use_loss_parallel_ctx_manager`

Maybe there is a way to simplify it? e.g. merge a) and c), and make b) and d) always True if the input is TP? I am OK with having only one way of doing things if it works for most cases.
@felipemello1 fair, I erred on the side of leaving options to the users.
Basically to have an option to enable compile, in cases where very few tokens are masked on average. When packing to very high seq len, if samples aren't very long, packing should be very effective, and very few tokens could be masked. Thus, the user may see more gain by enabling compile. Once we have

Agreed, this is a lot, again for the sake of adding options for users. You're right: there's no clear downside to LP, so we can just default to it with TP. I've just pushed another commit. It reduces the above to only one class-level property:
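A sketch of how a single class-level flag could drive the training step's choice of context manager (the flag name, class, and helper are hypothetical; in real code the nullcontext would be replaced by PyTorch's `loss_parallel()` when enabled):

```python
import contextlib


# Illustrative only: one class-level flag, from which the training step
# derives whether to enter the loss-parallel context manager.
class SFTLoss:
    tp_requires_loss_parallel_ctx: bool = True


def training_loss_ctx(loss_fn, tp_enabled: bool):
    """Return the context manager for the loss computation. With TP off, or
    for losses that don't need it, fall back to a no-op context."""
    use_lp = tp_enabled and getattr(loss_fn, "tp_requires_loss_parallel_ctx", False)
    # Placeholder values stand in for the real context managers.
    return contextlib.nullcontext("loss_parallel" if use_lp else "plain")
```

This keeps the decision in one place: losses declare the requirement, and the recipe never branches on loss type directly.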
ebsmothers left a comment
OK no remaining concerns from my side, this looks good to me. We need to get CI fixed (@felipemello1 is working on this in #2841 and elsewhere), once we can get a clean CI run in this should be good to merge.
please merge main so tests can pass

Context
What is the purpose of this PR? Is it to
Changelog
What are the changes made in this PR?
Loss parallelism is the main feature. Loss curves look healthy, but memory utilisation is significantly lower now. This is also compatible (tested) with FP8.
Peak active memory usage scales aggressively (linearly) as tensor parallelism increases without loss parallelism
Peak active memory usage scales much more generously (still linear) with loss parallelism enabled
In the tp8 case (bs=8, seq len=8192), loss parallelism decreases active memory peak by about 30GB, about 35%. TPS remains the same.
Compiling autograd doesn't work (removed autograd compile for now). Unfortunately compiling autograd doesn't seem to work and throws an FSDP error that I don't know how to debug. The error seems to indicate it's caused by some graph break during checkpointing. We can remove this from the PR, or leave it in for future debugging, defaulted to `false`. Running with `TORCHDYNAMO_VERBOSE=1 TORCHINDUCTOR_AUTOGRAD_CACHE=0` yielded:

Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- run unit tests via `pytest tests`
- run recipe tests via `pytest tests -m integration_test`

UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example