
OOM handling and recovery #2830

@vadimkantorov

Description


We just hit an OOM, which revealed that by default torchtune does not use torch.compile and does not yet use a fused linear cross-entropy loss...

I found the following report from 2024:

Are there any plans to make torchtune excel at peak GPU memory usage and at practical OOM handling?

For fine-tuning on long chain-of-thought data, OOMs coming from occasional long examples that were not filtered out are an issue.

Could torchtune somehow be made OOM-crash-safe out of the box? E.g. robustly skip a batch at runtime if an OOM occurs once in a while, or aggressively CPU-offload some huge activation tensors when an OOM hits and retry after offloading... (a rough sketch of the batch-skipping idea is below, after the traceback).

File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/recipes/full_finetune_distributed.py", line 919, in train                
[rank6]:     current_loss = self._loss_step(batch) * current_num_tokens                                                                       
[rank6]:                    ^^^^^^^^^^^^^^^^^^^^^^                                                                                            
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/recipes/full_finetune_distributed.py", line 822, in _loss_step           
[rank6]:     loss = self._loss_fn(outputs, labels)                                                                                            
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                            
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl            
[rank6]:     return self._call_impl(*args, **kwargs)                                                                                          
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                          
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl                    
[rank6]:     return forward_call(*args, **kwargs)                                                                                             
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                             
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torchtune/modules/loss/cross_entropy_loss.py", line 136, in forward      
[rank6]:     total_loss += self.compute_cross_entropy(                                                                                        
[rank6]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                        
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torchtune/modules/loss/cross_entropy_loss.py", line 105, in compute_cross
_entropy                                                                                                                                      
[rank6]:     return F.cross_entropy(                                                                                                          
[rank6]:            ^^^^^^^^^^^^^^^^                                                                                                          
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/functional.py", line 3494, in cross_entropy                     
[rank6]:     return torch._C._nn.cross_entropy_loss(                                                                                          
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                          
[rank6]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 5.15 GiB. GPU 6 has a total capacity of 79.10 GiB of which 3.79 GiB is 
free. Including non-PyTorch memory, this process has 75.30 GiB memory in use. Of the allocated memory 58.28 GiB is allocated by PyTorch, and 1
5.06 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_se
gments:True to avoid fragmentation.  See documentation for Memory Management
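
For reference, here is the kind of batch-skipping guard meant above, as a minimal sketch (this is not torchtune's actual recipe code; loss_step and optimizer are placeholders, and distributed synchronization is deliberately left out):

import torch

def guarded_train_step(batch, loss_step, optimizer):
    # Run one training step, skipping the batch if it triggers CUDA OOM.
    # Minimal single-process sketch: a real recipe would also need to handle
    # OOM raised during backward, keep all ranks in sync under FSDP/DDP
    # (every rank must agree to skip), and log how often batches get dropped.
    try:
        loss = loss_step(batch)  # forward pass + loss (stand-in for the recipe's _loss_step)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        return loss.detach()
    except torch.OutOfMemoryError:
        # Drop the offending batch, free whatever was allocated, and continue.
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
        return None

The distributed case is the hard part, since every rank has to agree to skip the same step, which is why having this built into the recipes would help.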
