Merged
Changes from 16 commits
torchtune/training/memory.py: 30 additions & 19 deletions
@@ -242,6 +242,9 @@ def optim_step(param) -> None:
         p.register_post_accumulate_grad_hook(optim_step)
 
 
+_BYTES_IN_GIB = 1024**3
+
+
 def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
     """
     Computes a memory summary for the passed in device. If ``reset_stats`` is ``True``, this will
@@ -250,33 +253,41 @@ def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
     individual sections of training.
 
     Args:
-        device (torch.device): Device to get memory summary for. Only CUDA devices are supported.
+        device (torch.device): Device to get memory summary for. Supports CUDA and MPS devices.
         reset_stats (bool): Whether to reset CUDA's peak memory tracking.
 
     Returns:
         Dict[str, float]: A dictionary containing the peak memory active, peak memory allocated,
         and peak memory reserved. This dict is useful for logging memory stats.
 
     Raises:
-        ValueError: If the passed-in device is not CUDA.
+        ValueError: If the passed-in device is CPU.
     """
     if device.type == "cpu":
         raise ValueError("Logging memory stats is not supported on CPU devices")
 
-    torch_device = get_torch_device_namespace()
-    peak_memory_active = torch_device.memory_stats().get("active_bytes.all.peak", 0) / (
-        1024**3
-    )
-    peak_mem_alloc = torch_device.max_memory_allocated(device) / (1024**3)
-    peak_mem_reserved = torch_device.max_memory_reserved(device) / (1024**3)
-    if reset_stats:
-        torch_device.reset_peak_memory_stats(device)
-
-    memory_stats = {
-        "peak_memory_active": peak_memory_active,
-        "peak_memory_alloc": peak_mem_alloc,
-        "peak_memory_reserved": peak_mem_reserved,
-    }
+    if device.type == "mps":
+        peak_memory_active = torch.mps.current_allocated_memory() / _BYTES_IN_GIB
+        peak_memory_alloc = torch.mps.driver_allocated_memory() / _BYTES_IN_GIB
+        memory_stats = {
+            "peak_memory_active": peak_memory_active,
+            "peak_memory_alloc": peak_memory_alloc,
+        }
+    else:
+        torch_device = get_torch_device_namespace()
+        peak_memory_active = (
+            torch_device.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
+        )
+        peak_memory_alloc = torch_device.max_memory_allocated(device) / _BYTES_IN_GIB
+        peak_memory_reserved = torch_device.max_memory_reserved(device) / _BYTES_IN_GIB
+        memory_stats = {
+            "peak_memory_active": peak_memory_active,
+            "peak_memory_alloc": peak_memory_alloc,
+            "peak_memory_reserved": peak_memory_reserved,
+        }
+        if reset_stats:
+            torch_device.reset_peak_memory_stats(device)
 
     return memory_stats
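For orientation, here is a minimal usage sketch of the updated helper. This is not code from the PR: the device-selection line and the placement in a training script are illustrative assumptions; only `get_memory_stats`, `log_memory_stats`, and their parameters come from the diff above.

```python
# Illustrative usage sketch, not code from this PR. Assumes torchtune is
# installed and that a CUDA or MPS device is available (CPU raises ValueError).
import torch
from torchtune import training

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")

# ... model init / forward pass would run here ...

# On CUDA this returns peak_memory_active, peak_memory_alloc, and
# peak_memory_reserved; on MPS only the first two are reported.
stats = training.get_memory_stats(device=device, reset_stats=True)
training.log_memory_stats(stats, message="Memory stats after model init:")
```

Note that with ``reset_stats=True`` the CUDA branch clears the peak counters after each call, so successive calls isolate the peak usage of each training phase; the MPS branch performs no equivalent reset.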


Expand All @@ -288,19 +299,19 @@ def log_memory_stats(
) -> None:
"""
Logs a dict containing memory stats to the logger. ``stats`` should contain the fields
``peak_memory_active``, ``peak_memory_alloc``, and ``peak_memory_reserved`` as
``peak_memory_active``, ``peak_memory_alloc``, and ``peak_memory_reserved`` (optional) as
returned by :func:`torchtune.training.get_memory_stats`.

Args:
stats (Dict[str, float]): A dictionary containing the peak memory active, peak memory
allocated, and peak memory reserved stats.
allocated, and peak memory reserved (optional) stats.
message (str): An optional message to prepend to the log output.
Defaults to "Memory stats after model init:"
"""
device_support = get_device_support()
_log.info(
Member:

Let's loop over the dictionary items so we log the key and the value. This way, we can omit the peak_memory_reserved if it doesn't exist at all.

Contributor Author:

Done. New output for MPS:

INFO:torchtune.utils._logging:Memory stats after model init:
        CPU peak memory active: 2.45 GiB
        CPU peak memory alloc: 3.10 GiB
f"{message}"
f"\n\t{device_support.device_name} peak memory allocation: {stats['peak_memory_alloc']:.2f} GiB"
f"\n\t{device_support.device_name} peak memory reserved: {stats['peak_memory_reserved']:.2f} GiB"
f"\n\t{device_support.device_name} peak memory reserved: {stats.get('peak_memory_reserved', 0):.2f} GiB"
f"\n\t{device_support.device_name} peak memory active: {stats['peak_memory_active']:.2f} GiB"
)
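The reviewer's loop-over-items suggestion could look roughly like the sketch below. This is a hedged reconstruction, not necessarily the code that was eventually merged: the standalone wrapper, the logger setup, and the hardcoded device name are assumptions for illustration.

```python
# Hypothetical sketch of the reviewer's suggestion: loop over the stats dict
# so keys a backend does not report (e.g. peak_memory_reserved on MPS) are
# simply omitted from the log. Not necessarily the merged implementation.
import logging

logging.basicConfig(level=logging.INFO)
_log = logging.getLogger("torchtune.utils._logging")


def log_memory_stats(stats: dict, message: str = "Memory stats after model init:") -> None:
    device_name = "CPU"  # stands in for device_support.device_name in the diff
    lines = "".join(
        f"\n\t{device_name} {key.replace('_', ' ')}: {value:.2f} GiB"
        for key, value in stats.items()
    )
    _log.info(f"{message}{lines}")


# An MPS-style dict with no peak_memory_reserved entry, matching the
# contributor's sample output above.
log_memory_stats({"peak_memory_active": 2.45, "peak_memory_alloc": 3.10})
```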