Streaming offloading in (q)lora single device #1443
Memory optimizations documentation (glossary):

@@ -83,6 +83,33 @@ and in most cases training can slow-down quite a bit as a result of this activat
 To enable activation checkpointing, use the ``enable_activation_checkpointing`` config entry or flag
 in any of our recipes, e.g. ``enable_activation_checkpointing=True``.
+
+.. _glossary_act_off:
+
+Activation Offloading
+---------------------
+
+*What's going on here?*
+
+You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory
+efficiency technique that allows saving GPU VRAM by temporarily moving activations to CPU and bringing
+them back when needed in the backward pass.
+
+See the `PyTorch autograd hook tutorial <https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#saving-tensors-to-cpu>`_
+for more details about how this is implemented through saved_tensors_hooks.
+
+This setting is especially helpful for larger batch sizes, or longer context lengths when you're memory constrained.
+However, these savings in memory can come at the cost of training speed (i.e. tokens per second), as it takes runtime
+and resources to move Tensors from GPU to CPU and back. The implementation in torchtune uses multiple CUDA streams
+in order to overlap the extra communication with the computation to hide the extra runtime. As the communication
+workload varies with the number and size of tensors being offloaded, it is common not to offload every
+single activation. In fact, one can use offloading in conjunction with activation checkpointing, where all
+activations will either be recomputed later in the backward pass or brought back from the CPU.
+
+*Sounds great! How do I use it?*
+
+To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag
+in our LoRA finetuning single device recipe, e.g. ``enable_activation_offloading=True``.

 .. _glossary_grad_accm:

 Gradient Accumulation
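A note for readers: the saved_tensors_hooks mechanism referenced in the new glossary entry boils down to a pack/unpack pair of autograd hooks. Below is a minimal sketch of plain CPU offloading in that style; it is only an illustration of the idea from the linked tutorial, not torchtune's ``OffloadActivations`` implementation, and it omits the stream overlap this PR adds.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

def pack_to_cpu(tensor: torch.Tensor):
    # Forward pass: called for every tensor autograd saves; stash it on the CPU.
    return tensor.device, tensor.to("cpu")

def unpack_from_cpu(packed):
    # Backward pass: called when the saved tensor is needed again; bring it back.
    device, cpu_tensor = packed
    return cpu_tensor.to(device)

model = torch.nn.Linear(1024, 1024, device="cuda")
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    out = model(x).relu()

# The saved activations are brought back from the CPU as the backward pass needs them.
out.sum().backward()
```

``OffloadActivations`` is built on this same hook mechanism; with ``use_streams=True`` the GPU-to-CPU copies are additionally issued on a separate CUDA stream so they can overlap with compute.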
pyproject.toml:

@@ -25,6 +25,7 @@ dependencies = [
     "numpy<=1.26.4",  # Pin here until https://github.com/tensorflow/tensorboard/issues/6869 is addressed
     "tqdm",
     "omegaconf",
+    "psutil",

Contributor (Author): @ebsmothers is this okay? This is a new requirement, as we use psutil to check CPU RAM usage and warn on too much usage.

Contributor (Author): FYI: psutil does not pull in other deps!

 ]
 dynamic = ["version"]
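Regarding the psutil comment above: a rough sketch of the kind of host-RAM check it enables. The helper name and threshold here are made up for illustration; this is not the exact check added in this PR.

```python
import logging

import psutil

log = logging.getLogger(__name__)

def warn_on_high_cpu_ram(threshold_pct: float = 80.0) -> None:
    # Hypothetical helper: warn if offloaded activations push host RAM usage too high,
    # since swapping (or the OOM killer) would hurt far more than slower training.
    used_pct = psutil.virtual_memory().percent
    if used_pct > threshold_pct:
        log.warning(
            "CPU RAM usage is at %.1f%%; offloading more activations may cause "
            "the host to swap or the process to be killed.",
            used_pct,
        )
```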
LoRA finetune single device recipe:

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import contextlib
 import sys
 import time

@@ -30,8 +31,12 @@
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training import (
+    DummyProfiler,
+    NoOpManager,
+    OffloadActivations,
+    PROFILER_KEY,
+)
 from tqdm import tqdm

 log = utils.get_logger("DEBUG")

@@ -43,13 +48,22 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
     for single GPU training. Training on CPU is not supported.

     Features:
-        - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
+        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
          flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
          activations in memory and instead recompute them during the backward pass. This is especially
          helpful for larger batch sizes when you're memory constrained. But these savings in memory
          come at the cost of training performance. In most cases training can slow-down quite a bit as
          a result of this activation recomputation.
+
+        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+          flag. Activation offloading is a technique similar to activation checkpointing that helps
+          reduce the memory footprint to prevent OOMs and enable bigger batches. Where activation
+          checkpointing drops the activation in the forward pass to recompute it later in the backward
+          pass, activation offloading moves the activation in the forward pass to the CPU and brings it
+          back during the backward pass. As always, there is a tradeoff: these savings in memory can
+          come at the cost of training performance and CPU resources. Activation offloading
+          can be used in conjunction with activation checkpointing.

        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
          flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
          most cases this should halve the memory footprint of full precision (fp32) training, without

@@ -222,6 +236,8 @@ def setup(self, cfg: DictConfig) -> None:
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+            enable_activation_offloading=cfg.get("enable_activation_offloading", False),
+            offload_with_streams=cfg.get("offload_with_streams", False),
             compile_model=cfg.compile,
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
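Since the two new arguments are read with ``cfg.get(..., False)``, existing configs that don't mention them keep the previous behavior. A quick illustration with a hypothetical minimal config (not a real recipe config):

```python
from omegaconf import OmegaConf

# Hypothetical minimal config: neither of the new keys is present.
cfg = OmegaConf.create({"enable_activation_checkpointing": True})

# cfg.get falls back to the provided default, so offloading stays off by default.
assert cfg.get("enable_activation_offloading", False) is False
assert cfg.get("offload_with_streams", False) is False
```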
@@ -367,6 +383,8 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
+        offload_with_streams: bool,
         compile_model: bool,
         base_model_state_dict: Dict[str, Any],
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,

@@ -420,6 +438,22 @@ def _setup_model(
             self.adapter_params.items(), dtype=self._dtype
         )

+        self.activations_handling_ctx = contextlib.nullcontext()
+        if enable_activation_offloading:
+            self.activations_handling_ctx = OffloadActivations(
+                use_streams=offload_with_streams
+            )
+
+            # Below is our hack to disable offloading the last output Linear in every
+            # step, as the cost for offloading the activation and then soon after bringing
+            # it back is expensive. Moreover, due to heuristics in our streaming API,
+            # we actually use more memory if we offload it as it interferes with chunkedCE.
+            noop_ctx = NoOpManager()
+            model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
+            model.output.register_forward_hook(
+                lambda *args: noop_ctx.__exit__(), always_call=True
+            )
+
         log.info(f"Model is initialized with precision {self._dtype}.")

         if self._device.type == "cuda":
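On ``use_streams=offload_with_streams``: the point of streaming offloading is to issue the device-to-host copy on a side CUDA stream so it overlaps with compute still running on the default stream. A rough sketch of that general pattern, assuming pinned host buffers (an illustration of the stream-overlap idea only, not the ``OffloadActivations`` internals):

```python
import torch

# A side stream dedicated to device-to-host copies.
d2h_stream = torch.cuda.Stream()

def offload_async(gpu_tensor: torch.Tensor) -> torch.Tensor:
    # Pinned host memory is required for the copy to be truly asynchronous.
    cpu_buffer = torch.empty(
        gpu_tensor.shape, dtype=gpu_tensor.dtype, device="cpu", pin_memory=True
    )
    # Don't start copying until the producer kernel on the current stream has finished.
    d2h_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(d2h_stream):
        cpu_buffer.copy_(gpu_tensor, non_blocking=True)
        # Tell the caching allocator that the side stream still uses this memory,
        # so it is not reused before the copy completes.
        gpu_tensor.record_stream(d2h_stream)
    return cpu_buffer
```

The backward pass does the mirror image (host-to-device copies issued ahead of use), which is also why the final output Linear is excluded above: its activation would be offloaded and then brought right back.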
@@ -576,7 +610,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         input_pos = batch.get("input_pos", None)  # shape [b, s]

         # run model
-        logits = self._model(tokens, mask=mask, input_pos=input_pos)
+        with self.activations_handling_ctx:
+            logits = self._model(tokens, mask=mask, input_pos=input_pos)

         # Shift labels to compute loss
         # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
Review discussion:

Comment: As far as I remember, this only works if activation_checkpointing is True. Is that still right? If so, we should probably update this doc and add a check to the recipes to raise an error or set AC=True automatically.
Another option, which I would prefer, is to investigate allowing offloading without AC, since streaming seems promising.
nit: I believe you meant "one can use offloading" instead of "once can use offloading".

Reply: What do you mean by "this" in the first sentence? Activation offloading works when AC is false as well, it's just super slow.

Reply: I remember trying to use only offloading, with AC=False, and it broke. Maybe that's not the case anymore?

Reply: Yeah, it shouldn't break! It should just be reaaaaally slow.