Skip to content

Commit 1ef06be

Browse files
committed
Make constant private and rename
1 parent 59fbdb4 commit 1ef06be

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

torchtune/training/memory.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def optim_step(param) -> None:
242242
p.register_post_accumulate_grad_hook(optim_step)
243243

244244

245-
BYTE_TO_GIB = 1024**3
245+
_BYTES_IN_GIB = 1024**3
246246

247247

248248
def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
@@ -267,19 +267,19 @@ def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
267267
raise ValueError("Logging memory stats is not supported on CPU devices")
268268

269269
if device.type == "mps":
270-
peak_memory_active = torch.mps.current_allocated_memory() / BYTE_TO_GIB
271-
peak_memory_alloc = torch.mps.driver_allocated_memory() / BYTE_TO_GIB
270+
peak_memory_active = torch.mps.current_allocated_memory() / _BYTES_IN_GIB
271+
peak_memory_alloc = torch.mps.driver_allocated_memory() / _BYTES_IN_GIB
272272
memory_stats = {
273273
"peak_memory_active": peak_memory_active,
274274
"peak_memory_alloc": peak_memory_alloc,
275275
}
276276
else:
277277
torch_device = get_torch_device_namespace()
278278
peak_memory_active = (
279-
torch_device.memory_stats().get("active_bytes.all.peak", 0) / BYTE_TO_GIB
279+
torch_device.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
280280
)
281-
peak_memory_alloc = torch_device.max_memory_allocated(device) / BYTE_TO_GIB
282-
peak_memory_reserved = torch_device.max_memory_reserved(device) / BYTE_TO_GIB
281+
peak_memory_alloc = torch_device.max_memory_allocated(device) / _BYTES_IN_GIB
282+
peak_memory_reserved = torch_device.max_memory_reserved(device) / _BYTES_IN_GIB
283283
memory_stats = {
284284
"peak_memory_active": peak_memory_active,
285285
"peak_memory_alloc": peak_memory_alloc,

0 commit comments

Comments
 (0)