diff --git a/docs/source/tutorials/memory_optimizations.rst b/docs/source/tutorials/memory_optimizations.rst index aab23d2a0e..69295e7264 100644 --- a/docs/source/tutorials/memory_optimizations.rst +++ b/docs/source/tutorials/memory_optimizations.rst @@ -83,6 +83,35 @@ and in most cases training can slow-down quite a bit as a result of this activat To enable activation checkpointing, use the ``enable_activation_checkpointing`` config entry or flag in any of our recipes, e.g. ``enable_activation_checkpointing=True``. +.. _glossary_act_off: + +Activation Offloading +--------------------- + +*What's going on here?* + +You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory +efficiency technique that saves GPU VRAM by temporarily moving activations to CPU and bringing +them back when needed in the backward pass. + +See the `PyTorch autograd hook tutorial <https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html>`_ +for more details about how this is implemented through saved_tensors_hooks. + +This setting is especially helpful for larger batch sizes or longer context lengths when you're memory constrained. +However, these savings in memory can come at the cost of training speed (i.e. tokens per second), as it takes runtime +and resources to move Tensors from GPU to CPU and back. The implementation in torchtune has the ``offload_with_streams`` +option to use multiple CUDA streams in order to overlap the extra communication with the computation and hide the extra +runtime. As the communication workload varies with the number and size of tensors being offloaded, it is +common not to offload every single activation. In fact, one can use offloading in conjunction with activation +checkpointing, where all activations will either be recomputed later in the backward or brought back from the CPU. + +*Sounds great! How do I use it?* + +To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag +in our LoRA fine-tuning single-device recipe, e.g. ``enable_activation_offloading=True``. To allow +usage of streams, make sure you are on PyTorch 2.5.0.dev20240907 or later and +specify ``offload_with_streams=True``. + ..
_glossary_grad_accm: Gradient Accumulation diff --git a/pyproject.toml b/pyproject.toml index 322f4238c9..58a0e4ce33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "numpy<=1.26.4", # Pin here until https://github.com/tensorflow/tensorboard/issues/6869 is addressed "tqdm", "omegaconf", + "psutil", ] dynamic = ["version"] diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index 854010aee5..2f57caa0b3 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -73,6 +73,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False dtype: bf16 # Logging diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index 7c635d9ff0..a0207a4ca8 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -73,6 +73,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False dtype: bf16 # Logging diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml index 74ff398ab3..c26314a98f 100644 --- a/recipes/configs/gemma/2B_lora_single_device.yaml +++ b/recipes/configs/gemma/2B_lora_single_device.yaml @@ -70,6 +70,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml index 7099112988..aa667f7c16 100644 --- a/recipes/configs/gemma/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma/2B_qlora_single_device.yaml @@ -70,6 +70,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml index 0a25bae846..805f21b5bb 100644 --- a/recipes/configs/gemma/7B_lora_single_device.yaml +++ b/recipes/configs/gemma/7B_lora_single_device.yaml @@ -72,6 +72,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml index 3d63a548d1..17c641d241 100644 --- a/recipes/configs/gemma/7B_qlora_single_device.yaml +++ b/recipes/configs/gemma/7B_qlora_single_device.yaml @@ -72,6 +72,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/13B_qlora_single_device.yaml b/recipes/configs/llama2/13B_qlora_single_device.yaml index 5f0995c911..ec823ebd24 100644 --- a/recipes/configs/llama2/13B_qlora_single_device.yaml +++ b/recipes/configs/llama2/13B_qlora_single_device.yaml @@ -80,7 +80,9 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml 
b/recipes/configs/llama2/7B_lora_single_device.yaml index 3ace543928..18cfb54add 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -80,7 +80,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml index bf710df03c..b6170c507e 100644 --- a/recipes/configs/llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -79,7 +79,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index 7a7c64a4a0..3fd0e2e745 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -79,7 +79,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml index c50fedc243..3adcf11330 100644 --- a/recipes/configs/llama3/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -78,7 +78,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: True # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index 4121a4e51f..602cd05fb6 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -82,7 +82,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/recipes/configs/llama3_1/8B_qlora_single_device.yaml index 7a4baf5ca7..8b8f4c58a8 100644 --- a/recipes/configs/llama3_1/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_qlora_single_device.yaml @@ -81,7 +81,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Offloading enable_activation_checkpointing: True +enable_activation_offloading: False # Profiler (disabled) profiler: diff --git a/recipes/configs/mistral/7B_lora_single_device.yaml b/recipes/configs/mistral/7B_lora_single_device.yaml index f42d2d5fc3..e21d7d2669 100644 --- a/recipes/configs/mistral/7B_lora_single_device.yaml +++ b/recipes/configs/mistral/7B_lora_single_device.yaml @@ -76,6 +76,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_qlora_single_device.yaml b/recipes/configs/mistral/7B_qlora_single_device.yaml 
index 29929d150f..3e504b98ff 100644 --- a/recipes/configs/mistral/7B_qlora_single_device.yaml +++ b/recipes/configs/mistral/7B_qlora_single_device.yaml @@ -77,6 +77,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/phi3/mini_lora_single_device.yaml b/recipes/configs/phi3/mini_lora_single_device.yaml index b5783d6f7e..775a47e0ee 100644 --- a/recipes/configs/phi3/mini_lora_single_device.yaml +++ b/recipes/configs/phi3/mini_lora_single_device.yaml @@ -71,6 +71,9 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision dtype: bf16 # Logging diff --git a/recipes/configs/phi3/mini_qlora_single_device.yaml b/recipes/configs/phi3/mini_qlora_single_device.yaml index b5e031745f..a41f91b5d4 100644 --- a/recipes/configs/phi3/mini_qlora_single_device.yaml +++ b/recipes/configs/phi3/mini_qlora_single_device.yaml @@ -71,6 +71,9 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision dtype: bf16 # Logging diff --git a/recipes/configs/qwen2/0.5B_lora_single_device.yaml b/recipes/configs/qwen2/0.5B_lora_single_device.yaml index fff48bb626..bb364e9ef9 100644 --- a/recipes/configs/qwen2/0.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_lora_single_device.yaml @@ -79,7 +79,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/1.5B_lora_single_device.yaml b/recipes/configs/qwen2/1.5B_lora_single_device.yaml index c55a485dd3..dccd577742 100644 --- a/recipes/configs/qwen2/1.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_lora_single_device.yaml @@ -77,7 +77,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/7B_lora_single_device.yaml b/recipes/configs/qwen2/7B_lora_single_device.yaml index 0425850cd1..301d85747e 100644 --- a/recipes/configs/qwen2/7B_lora_single_device.yaml +++ b/recipes/configs/qwen2/7B_lora_single_device.yaml @@ -81,7 +81,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Offloading enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 0862675a77..c6cfeefe18 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import contextlib import sys import time @@ -30,8 +31,12 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY - +from torchtune.training import ( + DummyProfiler, + NoOpManager, + OffloadActivations, + PROFILER_KEY, +) from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -43,13 +48,25 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): for single GPU training. Training on CPU is not supported. Features: - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907 + or later and will be enabled by default if an acceptable torch version is found. Activation + offloading can be used in conjunction with activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -101,6 +118,7 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): Raises: ValueError: If ``dtype`` is set to fp16. RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. 
""" @@ -138,6 +156,13 @@ def __init__(self, cfg: DictConfig) -> None: self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._clip_grad_norm = cfg.get("clip_grad_norm", None) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading and self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be enabled for training on CUDA" + ) def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ @@ -222,6 +247,7 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, compile_model=cfg.compile, base_model_state_dict=checkpoint_dict[training.MODEL_KEY], lora_weights_state_dict=( @@ -367,6 +393,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, compile_model: bool, base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, @@ -420,6 +447,23 @@ def _setup_model( self.adapter_params.items(), dtype=self._dtype ) + self.activations_handling_ctx = contextlib.nullcontext() + if enable_activation_offloading: + self.activations_handling_ctx = OffloadActivations() + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. Moreover, due to heuristics in our streaming API, + # we actually use more memory if we offload it as it interferes with chunkedCE. + if hasattr(model, "output") and isinstance(model.output, nn.Module): + noop_ctx = NoOpManager() + model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + log.info(f"Model is initialized with precision {self._dtype}.") if self._device.type == "cuda": @@ -576,7 +620,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: input_pos = batch.get("input_pos", None) # shape [b, s] # run model - logits = self._model(tokens, mask=mask, input_pos=input_pos) + with self.activations_handling_ctx: + logits = self._model(tokens, mask=mask, input_pos=input_pos) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index 241326cfce..7a1c4651cc 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from torchtune.training._activation_offloading import NoOpManager, OffloadActivations from torchtune.training._compile import compile_loss, compile_model from torchtune.training._distributed import ( contains_fsdp, @@ -122,4 +123,6 @@ "setup_torch_profiler", "compile_loss", "compile_model", + "NoOpManager", + "OffloadActivations", ] diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py new file mode 100644 index 0000000000..5156281aa8 --- /dev/null +++ b/torchtune/training/_activation_offloading.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional +from warnings import warn + +import psutil +import torch +import torchao +from torch.autograd.graph import saved_tensors_hooks +from torchao.dtypes.nf4tensor import NF4Tensor + + +class OffloadActivations(saved_tensors_hooks): + """Context manager under which activation tensors created in the forward pass will be offloaded. + + Enable the memory efficiency technique of activation offloading, where activations bigger than + min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward. + This is in contrast to maintaining the activation on GPU VRAM throughout the program. + + This manager contains the option of using one additional CUDA stream to handle the communication + between CUDA and CPU, which is intended to overlap with the default computation stream to improve + runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between + runtime vs memory usage. + + Args: + use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned + memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly + but is a limited resource. Default: True. + + use_streams (Optional[bool]): Whether or not to use streams for performance optimization where + the communications get overlapped with the computation. Requires a torch build + after torch-2.5.0.dev20240907. Default: True if a later torch build is found, else False. + + max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of + consecutive activations to keep alive during the forward pass. This number must be at + least 1. Keeping alive more activations will potentially allow more overlap between the + communication and compute streams at the cost of increasing memory usage. Keeping alive + fewer activations will conserve memory, but may cause poor overlap between the streams, + increasing runtime. Default: 5. + + min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify + for offloading. If the tensor is too small, we do not want to waste bandwidth and resources + moving it to CPU and back. Default: 1024 bytes. + + Raises: + ValueError: if max_fwd_stash_size is not at least 1. + RuntimeError: if use_streams but torch installation is earlier than torch-2.5.0.dev20240907 + + Example: + >>> with OffloadActivations(): + >>> logits = model(inputs) + >>> loss = ... 
+ >>> loss.backward() + """ + + def __init__( + self, + use_pin_memory: bool = True, + use_streams: Optional[bool] = None, + max_fwd_stash_size: int = 5, + min_offload_size: int = 1024, + ) -> None: + if use_streams is None: + # Default to True if an acceptable torch is installed (later nightly/version or from source) + self.use_streams = torch.__version__ >= "2.5.0.dev20240907" + else: + self.use_streams = use_streams + + self.min_tensor_size_bytes = ( + min_offload_size # we don't want to bother with small tensors + ) + self.tracker = ( + {} + ) # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where + self.tensor_id: int = 0 + self.is_first_forward_call = True + self.is_first_backward_call = True + self.is_first_forward_pass = True + + # managing cpu memory + self.use_pin_memory: bool = use_pin_memory + self.virtual_memory_safe_pct = ( + 60 # we should not exceed this percentage of memory + ) + + self.s0 = torch.cuda.default_stream() # comp stream + + # for streaming + if self.use_streams: + if torch.__version__ < "2.5.0.dev20240907": + raise RuntimeError( + "OffloadActivations with use_streams=True requires PyTorch 2.5.0.dev20240907 or later." + ) + self.s1 = torch.cuda.Stream() # comms stream + self.fwd_stash = {} # tensor_id => (activation, ev1) + if max_fwd_stash_size < 1: + raise ValueError( + f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}" + ) + self.max_fwd_stash_size = max_fwd_stash_size + self.bwd_tensor_stash = {} # tensor_id => activation + self.bwd_ev_stash = {} # tensor_id => ev0 + self.curr_graph_id = None + self.curr_autograd_node = None + + # -------- platform util functions -------- # + def verify_sufficient_virtual_memory(): + curr_pct = get_cpu_ram_pct() + if curr_pct > self.virtual_memory_safe_pct: + warn( + f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used" + ) + + def get_cpu_ram_pct() -> float: + # get the percentage of memory used by the system + return psutil.virtual_memory().percent + + def get_tensor_id() -> int: + # create a unique id for each tensor we are managing + self.tensor_id += 1 + return self.tensor_id + + def get_num_bytes_tensor(x: torch.Tensor) -> int: + # get the number of bytes in a tensor, for memory management purposes + return ( + x.element_size() * x.nelement() + ) # x.element_size() * x._base_storage().nbytes() + + # -------- core pack / unpack work -------- # + def pack_tensor(activation: torch.Tensor) -> int: + # activations are passed in during forward pass - from here we take over and return a unique id + if self.is_first_forward_call: + assert ( + len(self.tracker) == 0 + ), "backward pass should have cleared tracker of all tensors" + + # set training phase trackers + self.is_first_forward_call = False + self.is_first_backward_call = True + + # query for basic tensor info + num_bytes = get_num_bytes_tensor(activation) + tensor_id = get_tensor_id() + + # only offload hefty bois + if num_bytes >= self.min_tensor_size_bytes: + if self.use_streams: + # First, sync back and dereference previously offloaded tensors + # as the offloading should be done sufficiently long ago. 
+ for id in [k for k in self.fwd_stash.keys()]: + if id <= tensor_id - self.max_fwd_stash_size: + _, ev = self.fwd_stash[id] + self.s0.wait_event(ev) + del self.fwd_stash[id] + else: + break + + # Sync in, offload, and add an event to sync back later + self.s1.wait_stream(self.s0) + + stream = self.s1 if self.use_streams else self.s0 + with torch.cuda.stream(stream): + try: + cpu_tensor = torch.empty_like( + activation, pin_memory=self.use_pin_memory, device="cpu" + ) + except NotImplementedError as e: + if ( + isinstance(activation, NF4Tensor) + and torchao.__version__ < "0.6.0.dev20240917" + ): + raise RuntimeError( + "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later" + ) from e + raise e + cpu_tensor.copy_(activation, non_blocking=True) + self.tracker[tensor_id] = ( + cpu_tensor, + True, + ) # True = (in future) modified + + if self.use_streams: + event = self.s1.record_event() + + # Stash to keep activation alive til s1 is done + self.fwd_stash[tensor_id] = (activation, event) + else: + self.tracker[tensor_id] = ( + activation, + False, + ) # False = not modified, tensor is as is + + return tensor_id + + def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor: + # backward pass - we are called with the tensor_id, which + # we will use to retrieve the saved/offloaded tensor + if self.is_first_backward_call: + if self.is_first_forward_pass: + self.is_first_forward_pass = False + if self.use_pin_memory: + verify_sufficient_virtual_memory() + + self.is_first_backward_call = False + self.is_first_forward_call = True + + assert ( + unpack_tensor_id in self.tracker + ), f"untracked tensor with id {unpack_tensor_id}" + + maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id] + if modified: + gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True) + maybe_gpu_tensor = gpu_tensor + + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + return maybe_gpu_tensor + + def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor: + # backward pass - we are called with the tensor_id, which + # we will use to retrieve the saved/offloaded tensor + if self.is_first_backward_call: + self.curr_graph_id = torch._C._current_graph_task_id() + + def wait_and_del_remaining_references() -> None: + for id in [k for k in self.bwd_tensor_stash.keys()]: + event = self.bwd_ev_stash[id] + self.s1.wait_event(event) + del self.bwd_tensor_stash[id] + + # Register a callback to the end of autograd to clean everything up + torch.autograd.variable.Variable._execution_engine.queue_callback( + wait_and_del_remaining_references + ) + + if self.is_first_forward_pass: + self.is_first_forward_pass = False + if self.use_pin_memory: + verify_sufficient_virtual_memory() + + self.is_first_backward_call = False + self.is_first_forward_call = True + + assert ( + unpack_tensor_id in self.tracker + ), f"untracked tensor with id {unpack_tensor_id}" + + maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id] + if modified: + # Get data on the current autograd node + graph_id = torch._C._current_graph_task_id() + node = torch._C._current_autograd_node() + prev_node_ids = [] + + # If we're on a new node, mark prev node's tensors to be freed later + if graph_id == self.curr_graph_id and self.curr_autograd_node != node: + self.curr_autograd_node = node + prev_node_ids = [id for id in self.bwd_tensor_stash.keys()] + + brought_back_from_cpu = True + if unpack_tensor_id in self.fwd_stash: + maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0] + brought_back_from_cpu = False + 
else: + # Kick off the process to bring tensors back + with torch.cuda.stream(self.s1): + gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True) + maybe_gpu_tensor = gpu_tensor + + # Tell comp stream to wait for the info to be loaded before executing + self.s0.wait_stream(self.s1) + + # Stash the tensor to keep memory alive until compute stream is complete + self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor + + def hook(outputs, inputs): + # create events for the current node inputs/outputs if they were streamed in + if brought_back_from_cpu: + event = self.s0.record_event() + self.bwd_ev_stash[unpack_tensor_id] = event + + # if there are still things in the fwd_stash, get rid of them as we're in bwd now + for id in [k for k in self.fwd_stash.keys()]: + _, ev = self.fwd_stash[id] + self.s0.wait_event(ev) + del self.fwd_stash[id] + + # wait on prev node's events and del those + for id in prev_node_ids: + event = self.bwd_ev_stash[id] + self.s1.wait_event(event) + del self.bwd_tensor_stash[id] + + return outputs + + node.register_hook(hook) + + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + return maybe_gpu_tensor + + unpack_tensor = ( + unpack_tensor_with_streams + if self.use_streams + else unpack_tensor_single_stream + ) + super().__init__(pack_tensor, unpack_tensor) + + +class NoOpManager(saved_tensors_hooks): + """ + A saved_tensors_hook manager used to disable any other saved_tensors_hook manager + applied before. This relies on the behavior that only the most recently registered + saved_tensors_hook will run. + + One example usage is to opt a local region of code out of activations offloading, + which is usually applied globally to best track state. + """ + + def __init__(self) -> None: + def noop(tensor): + return tensor + + super().__init__(noop, noop)
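As a usage reference for the new APIs above, here is a minimal, self-contained sketch of how ``OffloadActivations`` and ``NoOpManager`` can be combined outside of the recipe. The toy model, tensor shapes, and loss are illustrative assumptions rather than part of this change; the hook pattern on the final layer mirrors the one registered in ``lora_finetune_single_device.py``, and a CUDA device is required, matching the recipe's check::

    # Minimal sketch -- the model, shapes, and loss below are illustrative only.
    import torch
    from torch import nn

    from torchtune.training import NoOpManager, OffloadActivations

    device = "cuda"  # offloading requires CUDA, as enforced by the recipe
    model = nn.Sequential(
        nn.Linear(1024, 1024),
        nn.ReLU(),
        nn.Linear(1024, 8),  # stand-in for a model's final output projection
    ).to(device)

    # Mirror the recipe's trick: opt the last output layer out of offloading,
    # since moving its activation to CPU only to bring it right back is wasteful.
    noop_ctx = NoOpManager()
    model[-1].register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
    model[-1].register_forward_hook(
        lambda *args: noop_ctx.__exit__(), always_call=True
    )

    x = torch.randn(4, 1024, device=device, requires_grad=True)

    # use_streams is auto-detected from the installed torch version when unset
    with OffloadActivations():
        out = model(x)

    loss = out.sum()
    loss.backward()  # offloaded activations are brought back from CPU here

Within torchtune itself, none of this wiring is needed: setting ``enable_activation_offloading=True`` in a single-device LoRA config routes the model forward through the same context, and ``use_streams`` defaults based on the installed torch version.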