We just hit an OOM, which revealed that by default torchtune does not use torch.compile and does not yet use a fused linear cross-entropy loss...
I found the following report (and discussion) from 2024:
- https://www.reddit.com/r/LocalLLaMA/comments/1di0fhv/torchtune_vs_axolotl_vs_unsloth_trainer/
- https://wandb.ai/augmxnt/train-bench/reports/Trainer-performance-comparison-torchtune-vs-axolotl-vs-Unsloth---Vmlldzo4MzU3NTAx
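Just to illustrate what a fused/chunked linear + cross-entropy buys, here is a minimal sketch (this is not torchtune's implementation; the function name, signature, and chunk size are made up for illustration): the last hidden states are projected to the vocabulary in chunks, so the full [num_tokens, vocab_size] logits tensor, and its fp32 upcast, is never materialized at once. As I understand it, real fused kernels go further and compute the gradient inside the kernel so the logits are never stored at all.

```python
import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(hidden, out_weight, labels,
                                 chunk_size=4096, ignore_index=-100):
    """Project to the vocab and compute CE chunk-by-chunk so the full
    [num_tokens, vocab_size] logits tensor is never materialized at once.

    hidden:     [num_tokens, hidden_dim]  last hidden states
    out_weight: [vocab_size, hidden_dim]  output-projection weight
    labels:     [num_tokens]              target token ids (ignore_index = padding)
    """
    total_loss = hidden.new_zeros((), dtype=torch.float32)
    total_tokens = 0
    for start in range(0, hidden.size(0), chunk_size):
        h = hidden[start:start + chunk_size]
        y = labels[start:start + chunk_size]
        # Only one chunk of (fp32) logits is alive at a time.
        logits = (h @ out_weight.T).float()
        total_loss = total_loss + F.cross_entropy(
            logits, y, ignore_index=ignore_index, reduction="sum"
        )
        total_tokens += int((y != ignore_index).sum())
    return total_loss / max(total_tokens, 1)
```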
Are there any plans to make torchtune excellent at peak GPU memory usage and practical OOM handling?
When fine-tuning on long chain-of-thought data, OOMs caused by a few overly long examples that were not filtered out are a real problem.
Could torchtune be made OOM-crash-safe out of the box? For example, robustly skip a batch at runtime if an OOM occurs once in a while, or aggressively CPU-offload some of the huge activation tensors when an OOM hits and retry after offloading...
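A rough sketch of the "skip the offending batch" idea, assuming a generic training loop (`loss_step` and `optimizer` are placeholders here, not torchtune's actual recipe API):

```python
import torch

def run_step_with_oom_skip(batch, loss_step, optimizer):
    """Run one training step; on CUDA OOM, drop the batch instead of crashing."""
    try:
        loss = loss_step(batch)  # forward + loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        return loss.detach()
    except torch.OutOfMemoryError:
        # Free the partially built autograd graph and cached allocator blocks,
        # then let the caller count this as a skipped batch.
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
        return None
```

The tricky part is distributed training: if one rank skips while the others continue, FSDP collectives desynchronize, so all ranks would need to agree to skip (e.g. by all-reducing an OOM flag) before moving on.

For reference, the traceback we hit: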
File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/recipes/full_finetune_distributed.py", line 919, in train
[rank6]: current_loss = self._loss_step(batch) * current_num_tokens
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/recipes/full_finetune_distributed.py", line 822, in _loss_step
[rank6]: loss = self._loss_fn(outputs, labels)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank6]: return self._call_impl(*args, **kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank6]: return forward_call(*args, **kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torchtune/modules/loss/cross_entropy_loss.py", line 136, in forward
[rank6]: total_loss += self.compute_cross_entropy(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torchtune/modules/loss/cross_entropy_loss.py", line 105, in compute_cross_entropy
[rank6]: return F.cross_entropy(
[rank6]: ^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/functional.py", line 3494, in cross_entropy
[rank6]: return torch._C._nn.cross_entropy_loss(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 5.15 GiB. GPU 6 has a total capacity of 79.10 GiB of which 3.79 GiB is free. Including non-PyTorch memory, this process has 75.30 GiB memory in use. Of the allocated memory 58.28 GiB is allocated by PyTorch, and 15.06 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management
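In the meantime we will try the allocator hint from the message above. A minimal sketch of one way to apply it (the variable has to be picked up before the first CUDA allocation in each worker process; for `tune run` launches it would more naturally be exported in the shell environment):

```python
import os

# Must be set before the first CUDA allocation in the process.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch  # noqa: E402  (imported after setting the variable, to be safe)
```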