-
Notifications
You must be signed in to change notification settings - Fork 683
Add feature ligerceloss #2741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add feature ligerceloss #2741
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2741
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 2 Cancelled JobsAs of commit a7a5fcb with merge base c7a92e4 ( NEW FAILURE - The following job has failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
Hi @mananchawla2005! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at [email protected]. Thanks! |
|
CLA filled! |
|
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for this contribution! Really clean first pass at the problem. I left some comments around testing and questions about how you're handling DTensors. For linting, you can checkout our contributing guide on how to setup precommit hooks.
| # Validate the results are close enough | ||
| assert_expected(fused_loss, standard_loss, rtol=1e-2, atol=1e-2) | ||
|
|
||
| def test_liger_fused_cross_entropy_loss_with_reshape(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For our SFTLoss type we can assume the input is "[bsz, seq_len, emb_dim]", so I don't think we need this second test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should have a second test, but it should be a distributed test. Same as the first test but with 4 gpus required and FSDP size 2 and TP size 2. If you need help on how to initialize the model that way I can give you the code.
|
|
||
|
|
||
| class TestLigerFusedCrossEntropyLoss: | ||
| def test_liger_fused_cross_entropy_loss(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this test requires cuda you should add the "@gpu_test(gpu_count=1)" decorator from "from tests.test_utils import gpu_test". Along with testing the loss value, I think it would be good to test a single forward and backward pass with opt step to ensure all the gradients are propagating back correctly too. You can use "fixed_init_model" (also from test_utils) as well to make it easier to initialize the model the same way each time
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be good to add @pytest.mark.parametrize("compile", [False, True]) to the test and pass in compile as an argument on whether to call apply_compile_strategy on the loss
| orig_w = self.linear_projection.weight | ||
| if isinstance(orig_w, DTensor): | ||
| mesh, placements = orig_w.device_mesh, orig_w.placements | ||
| w = orig_w.full_tensor().detach().clone().requires_grad_(True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you tell me more about what you're doing here? Does the liger loss require you to detach the weight? Why detach it only to manually register that gradients get reapplied? Also, I don't think you'd want to register a hook every forward pass.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2741 +/- ##
==========================================
+ Coverage 60.64% 62.74% +2.09%
==========================================
Files 428 431 +3
Lines 26091 26479 +388
==========================================
+ Hits 15823 16613 +790
+ Misses 10268 9866 -402 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
@pbontrager Hey thanks for an indepth review! I have tried to resolve most of the issues that were coming earlier however for the distributed test would love if you can help a little with the code as well as testing it cause I dont have a distributed gpu setup. |
| # self.forward = torch.compile( | ||
| # self.forward, *args, **kwargs | ||
| # ) | ||
| return self |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need to compile a liger kernel at all?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey it was added in response to @pbontrager
I think it would be good to add @pytest.mark.parametrize("compile", [False, True]) to the test and pass in compile as an argument on whether to call apply_compile_strategy on the loss
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, ok, now I see 🙂
So the problem is that the kernel is automatically compiled.
It was obvious from the README of their repo - there was no mentioning that a user needs to call torch.compile manually, which leaves only one option. Yes, Liger provides custom optimized triton kernels, but without compilation they won't work.
So, after digging a bit of their codebase, here is how it works:
LigerFusedLinearCrossEntropyFunctionhas custom forward and backward implementations. (Let's focus on forward variant for now.)- Inside it, a fused_linear_cross_entropy_forward is called ...
- ... which calls a triton kernel, that has triton.jit wrapper.
Of course, you can control even that with
with torch._dynamo.disable():and a flag that is disabled by default and enabled in apply_compile_strategy, but since Triton kernel code is not valid Python code for direct execution on a GPU without a compilation, there is no much sense in it.
Perhaps maybe control compilation of service code inside a loss code around this kernel, but I believe it won't make that much difference 🤔.
In other words, since it was never intended to use liger without a compilation, perhaps just skip this method without any warnings?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have removed the warning and added a doc string that this is JIT compiled
| try: | ||
| import liger_kernel.ops.fused_linear_cross_entropy | ||
|
|
||
| self.fused_linear_ce = liger_kernel.ops.fused_linear_cross_entropy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A dummy question: why did you deciced to go this route instead of a loss class as described in README: https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#3-compose-your-own-model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, you are right to point out that! I chose the lower-level ops approach for bias handling as one under transformers don't have a option for bias parameter and during distributed training we need to handle DTensor. If thats not required I can replace it with the one in readme.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But it looks like forward method of the class also accepts bias, which by default is None.
Basically, there is nothing wrong with your approach, just the one with the class looks, at least to me, slightly cleaner.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the class version was just a thin wrapper around LigerFusedLinearCrossEntropyFunction, and our loss class is also a thin wrapper around the same functionality, it feels more right to me that we just directly call LigerFusedLinearCrossEntropyFunction and operate at the same abstraction level as the nn.Module you linked to.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, fair enough.
Linear loss calls F.cross_entropy and this one calls a function - a lil bit of uniformity.
|
I think it makes more sense to reshape I've tested a distributed LoRA finetune of llama3.1-8B with those changes, and it seems to work fine, amount of reserved memory was reduced and the difference in loss was minimal. |
|
@intervitens Hey thanks for checking out the distributed training! Glad that it works well. I have incorporated your changes. It would also be very helpful if you can provide an initial starting for adding the distributed test. |
| batch_size, seq_len, emb_dim = hidden_states.shape | ||
| hidden_states = hidden_states.reshape( | ||
| -1, emb_dim | ||
| ) # [batch_size*seq_len, emb_dim] | ||
| targets = targets.reshape(-1) # [batch_size*seq_len] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feels like since you don't reuse B, T, C values anywhere, it could be done simpler:
| batch_size, seq_len, emb_dim = hidden_states.shape | |
| hidden_states = hidden_states.reshape( | |
| -1, emb_dim | |
| ) # [batch_size*seq_len, emb_dim] | |
| targets = targets.reshape(-1) # [batch_size*seq_len] | |
| hidden_states = hidden_states.flatten(0, 1) # (batch_size*seq_len, hidden_size) | |
| targets = targets.flatten() # (batch_size*seq_len) |
| ) | ||
| if total_elements == 0: | ||
| return loss | ||
| return loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, basically, return loss regardless? 🙂
pbontrager
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for doing all the updates! I think it's close now. I'm going to help with the unit tests and then once you have a chance to make any changes based on Andrei's comments, we should be good to land.
pyproject.toml
Outdated
| "wandb", | ||
| "expecttest", | ||
| # Triton: | ||
| "triton>=2.3.1 ; platform_system != 'Windows'", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this dependency explicitly needed? Pytorch already includes triton I believe.
| b.register_hook(_scatter_b) | ||
| self._b_hook_registered = True | ||
|
|
||
| loss, _ = self.fused_linear_ce.LigerFusedLinearCrossEntropyFunction.apply( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: couldn't self.fused_linear_ce = LigerFusedLinearCrossEntropyFunction in your init and then here you'd just have self.fused_linear_ce.apply(...)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more nit: could we simplify the arguments list, since most of the values that are provided are actually equal to the default ones?
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
@staticmethod
@amp_custom_fwd
def forward(
ctx,
_input,
weight,
target,
bias=None,
ce_weight=None,
ignore_index=-100,
lse_square_scale=0.0,
label_smoothing=0.0,
reduction="mean",
softcap=None,
return_z_loss: bool = False,
):| if isinstance(w, DTensor): | ||
| mesh, placements = w.device_mesh, w.placements | ||
| w = w.full_tensor() | ||
| if not hasattr(self, "_w_hook_registered"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think "full_tensor" handles gradient placement and we don't need to do this ourselves link. I can test removing this though.
| try: | ||
| import liger_kernel.ops.fused_linear_cross_entropy | ||
|
|
||
| self.fused_linear_ce = liger_kernel.ops.fused_linear_cross_entropy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the class version was just a thin wrapper around LigerFusedLinearCrossEntropyFunction, and our loss class is also a thin wrapper around the same functionality, it feels more right to me that we just directly call LigerFusedLinearCrossEntropyFunction and operate at the same abstraction level as the nn.Module you linked to.
| from torchtune.training.seed import set_seed | ||
|
|
||
|
|
||
| @gpu_test(gpu_count=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me edit and push some changes to these tests. I can run them myself then for you.
|
@pbontrager I have simplifed distributed handling as well done the refactoring as per suggested changes. Would be greatful if you can help me add tests for distributed. |
|
Hey any updates? |
|
@mananchawla2005 I've gotten the unit tests working to test FSDP + TP for a single training step with the Liger loss. I haven't pushed the changes to your PR yet because the DTensor backward hook doesn't seem to work correctly so I'm trying to fix that and get the test to pass. I should be able to get back to this and get something to you by the end of this week. |
|
I pushed the tests with some changes here but the tests aren't passing in the distributed case. I've tried playing around with getting the full tensor in a full TP setting and full FSDP setting (the default test is a mix of both) but I'm still getting numerical differences. @ebsmothers do you have any ideas here? |
ebsmothers
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @mananchawla2005 for getting this most of the way! I have a few more small comments based on @pbontrager's changes, then this should be good to go. @pbontrager please make sure to test on post-6/4 nightlies given the changes in pytorch/pytorch#154704. Stamping to unblock
|
|
||
| dist.destroy_process_group() | ||
|
|
||
| @gpu_test(gpu_count=WORLD_SIZE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: just say 4 explicitly here. I like to be able to look directly at the test and see how many GPUs it needs
|
|
||
| # Verify: | ||
| # 1. Validate the results are close enough | ||
| assert_expected(fused_loss, standard_loss, rtol=1e-2, atol=1e-2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this as close as we can get?
| """Memory efficient Cross-entropy loss that uses fused CUDA kernels to compute the loss. | ||
| Combines the linear projection with the cross-entropy calculation for better performance | ||
| and memory efficiency. This is an approximation of CrossEntropyLoss and may have small | ||
| numerical differences compared to the standard implementation. This is a wrapper around | ||
| `LigerFusedLinearCrossEntropyFunction` from the `liger_kernel` package. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a reference to the Liger paper or repo here?
| plan = { | ||
| "output": ColwiseParallel( | ||
| input_layouts=Replicate(), output_layouts=Replicate() | ||
| ) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this actually a nontrivial application of TP? (Like should we at least have some shard for layer or something?)
| raise RuntimeError("Must call set_model_output() before forward()") | ||
|
|
||
| if isinstance(hidden_states, DTensor): | ||
| hidden_states = hidden_states.full_tensor() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the compute isn't replicated across TP workers unnecessarily, could the hidden states and targets both be sharded on the sequence dimension, losses calculated per token, then correctly reduced after? (Basically to make better uses of the devices in a TP group).
Could be a follow-up PR, not MVP for liger. Just raising while I work on adding loss parallel for the standard LinearCrossEntropyLoss https://github.com/pytorch/torchtune/pull/2782/files
(EDIT: I can probably look to add this in future)
|
Any plans for contributing this great fused linear cross entropy loss impl directly into PyTorch core?
As probably the same question is relevant for torchtitan, so would be best to have it upstreamed? (although, maybe a copy of Liger Triton (or CUDA?) kernel would probably need to be upstreamed, this PR is a wrapper around Liger package) And this way, it would also provide a good baseline for the Inductor codegen I think it's been now quite some time that fused linear cross entropy has proved it's actually useful for warranting its inclusion in core (especially in the meanwhile Inductor underperforms codegen for this op) |
|
@vadimkantorov actually this is something that @ngimel has brought up as well, she may have a better idea on the latest status of things here.
|
|
I wonder if both triton-code and triton-produced cubin / ptx can be included in core distribution for populating local artifact cache . Like so we can both preserve triton's hackability and zero wait time for eager mode... |
|
Hey @vadimkantorov,
Wanted to let you know that a fused linear cross entropy in core is on our roadmap, we plan to work on it in the next month or so |
|
Any plans for a road forward for inclusion in core directly triton codes and torch.compile-d codes? |
Not sure if this is what you're referring to, but I believe there is already infra to include triton code in core https://github.com/pytorch/pytorch/blob/a69e27ca5ad4287add73972ef1b34b469e3c7d23/torch/cuda/__init__.py#L1669-L1702 that is used for a very small number of ops As to the triton-produced cubin / ptx I'm not sure how that would look like but I'll investigate when we look at upstreaming the fused linear cross entropy |
|
Exactly, I wonder what is needed for Triton code to be used for more ops in core. If it's cold start time for eager, then shipping pre-cached / pre-generated / pre-compiled ptx could be the solution (and still allow hackability if one wants to copy-paste and modify the triton code) |
|
Precompiled triton is not really an option, triton is doing too much specialization, so startup time/unexpected reompiles will always be a problem for using triton in eager. |
|
Is it possible to somehow precompile / pregenerate from triton some version of ptx which at least would run on all relevant hardwares? Or are some new features in triton needed to control how it specializes / force it to specialize less? (I guess ideally one would like to have some sort of okay AOT ptx generation from triton which would at least run on many hardwares - even if not reaching the best possible perf. The PyTorch could ship ptx for a few popular important hardwares, and for others force the user to run triton compilation in their own machine if they want extra perf. This would allow users to write and maintain triton core where now they have to rewrite in cuda ) Or is it possible to lower triton into the cuda/c++ code? |
|
@mikaylagawarecki If there are any public issues/PRs that we could follow to track progress for a fused linear loss in Core, that would be great. I don't see it currently in the 2.9.0 milestones. |
@mikaylagawarecki If maybe vllm's impls could be upstreamed, vllm could transition to use more of core modules. Hopefully, vllm/sglang can transition to be using these new pytorch core modules (I wonder if this should inform the design/numerics in any ways) and thus reducing fragmentation... Yet another recent impl of fused linear cross entropy (indicating fragmentation): https://github.com/volcengine/verl/blob/main/verl/utils/kernel/kernels.py Also, a tutorial is needed for implementing forward+backward losses fusing Linear + some cross-entropy-like loss (e.g. important case is GRPO) |
|
Also curious if the pattern of fusing chunked Linear and some loss computation can be also implemented as a more generic/compilable higher-order op in PyTorch core. Seems the same pattern arises / is needed for GRPO loss and reducing peak VRAM for super-long reasoning traces (and maybe can be done with other loss variants as well)... |
PR: Add LigerFusedCrossEntropyLoss
Context
What is the purpose of this PR? Is it to
Closes #2692
Changelog
LigerFusedCrossEntropyLossclass that provides memory-efficient cross entropy loss using fused CUDA kernels usingliger-kernelsTest plan
F.cross_entropyUX
Example usage in docstring:
The implementation provides better performance and memory efficiency compared to the chunked
LinearCrossEntropyLossby:All tests verify numerical correctness against PyTorch's native cross entropy within expected tolerances.
Hi PyTorch Team,
This is my first PR to a machine learning project, and I’ve tried to ensure the code is correct and well-structured. I’ve implemented the functionality for LigerCeLoss and included a test case that verifies its behavior with both masked and reshaped inputs.
Due to hardware limitations, I wasn't able to fully validate the distributed functionality or run the entire test suite across multiple GPUs. However, I’ve implemented support for DTensor-based weights, hidden states, and targets, and included the logic for gather → fuse → scatter to handle gradient computation for sharded weights in distributed settings.
Looking forward to your feedback! @pbontrager @joecummings