diff --git a/torchtune/training/memory.py b/torchtune/training/memory.py
index e5459798f4..19a8b09b90 100644
--- a/torchtune/training/memory.py
+++ b/torchtune/training/memory.py
@@ -242,6 +242,9 @@ def optim_step(param) -> None:
         p.register_post_accumulate_grad_hook(optim_step)
 
 
+_BYTES_IN_GIB = 1024**3
+
+
 def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
     """
     Computes a memory summary for the passed in device. If ``reset_stats`` is ``True``, this will
@@ -250,7 +253,7 @@ def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
     individual sections of training.
 
     Args:
-        device (torch.device): Device to get memory summary for. Only CUDA devices are supported.
+        device (torch.device): Device to get memory summary for. Supports CUDA and MPS devices.
         reset_stats (bool): Whether to reset CUDA's peak memory tracking.
 
     Returns:
@@ -258,25 +261,33 @@ def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
         and peak memory reserved. This dict is useful for logging memory stats.
 
     Raises:
-        ValueError: If the passed-in device is not CUDA.
+        ValueError: If the passed-in device is CPU.
     """
     if device.type == "cpu":
         raise ValueError("Logging memory stats is not supported on CPU devices")
 
-    torch_device = get_torch_device_namespace()
-    peak_memory_active = torch_device.memory_stats().get("active_bytes.all.peak", 0) / (
-        1024**3
-    )
-    peak_mem_alloc = torch_device.max_memory_allocated(device) / (1024**3)
-    peak_mem_reserved = torch_device.max_memory_reserved(device) / (1024**3)
-    if reset_stats:
-        torch_device.reset_peak_memory_stats(device)
-
-    memory_stats = {
-        "peak_memory_active": peak_memory_active,
-        "peak_memory_alloc": peak_mem_alloc,
-        "peak_memory_reserved": peak_mem_reserved,
-    }
+    if device.type == "mps":
+        peak_memory_active = torch.mps.current_allocated_memory() / _BYTES_IN_GIB
+        peak_memory_alloc = torch.mps.driver_allocated_memory() / _BYTES_IN_GIB
+        memory_stats = {
+            "peak_memory_active": peak_memory_active,
+            "peak_memory_alloc": peak_memory_alloc,
+        }
+    else:
+        torch_device = get_torch_device_namespace()
+        peak_memory_active = (
+            torch_device.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
+        )
+        peak_memory_alloc = torch_device.max_memory_allocated(device) / _BYTES_IN_GIB
+        peak_memory_reserved = torch_device.max_memory_reserved(device) / _BYTES_IN_GIB
+        memory_stats = {
+            "peak_memory_active": peak_memory_active,
+            "peak_memory_alloc": peak_memory_alloc,
+            "peak_memory_reserved": peak_memory_reserved,
+        }
+        if reset_stats:
+            torch_device.reset_peak_memory_stats(device)
+
     return memory_stats
 
 
@@ -288,19 +299,20 @@ def log_memory_stats(
 ) -> None:
     """
     Logs a dict containing memory stats to the logger. ``stats`` should contain the fields
-    ``peak_memory_active``, ``peak_memory_alloc``, and ``peak_memory_reserved`` as
+    ``peak_memory_active``, ``peak_memory_alloc``, and ``peak_memory_reserved`` (optional) as
     returned by :func:`torchtune.training.get_memory_stats`.
 
     Args:
         stats (Dict[str, float]): A dictionary containing the peak memory active, peak memory
-            allocated, and peak memory reserved stats.
+            allocated, and peak memory reserved (optional) stats.
         message (str): An optional message to prepend to the log output.
             Defaults to "Memory stats after model init:"
     """
     device_support = get_device_support()
     _log.info(
-        f"{message}"
-        f"\n\t{device_support.device_name} peak memory allocation: {stats['peak_memory_alloc']:.2f} GiB"
-        f"\n\t{device_support.device_name} peak memory reserved: {stats['peak_memory_reserved']:.2f} GiB"
-        f"\n\t{device_support.device_name} peak memory active: {stats['peak_memory_active']:.2f} GiB"
+        f"{message}\n"
+        + "\n".join(
+            f"\t{device_support.device_name} {key.replace('_', ' ')}: {value:.2f} GiB"
+            for key, value in stats.items()
+        )
     )
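
A minimal usage sketch of the patched helpers, assuming an Apple Silicon machine where torch.backends.mps.is_available() returns True; the device and message values below are illustrative, not part of the patch:

    import torch
    from torchtune import training

    device = torch.device("mps")

    # On MPS the returned dict has only two keys, because the MPS allocator
    # exposes current/driver allocation counters rather than CUDA-style peak
    # and reserved stats:
    #   {"peak_memory_active": ..., "peak_memory_alloc": ...}
    stats = training.get_memory_stats(device=device)

    # log_memory_stats now builds its output from whatever keys are present,
    # so the absent "peak_memory_reserved" entry no longer raises a KeyError
    # the way the old hard-coded f-string would have on MPS:
    training.log_memory_stats(stats, message="Memory stats after model init:")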