@@ -242,7 +242,7 @@ def optim_step(param) -> None:
         p.register_post_accumulate_grad_hook(optim_step)
 
 
-BYTE_TO_GIB = 1024 ** 3
+_BYTES_IN_GIB = 1024 ** 3
 
 
 def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
@@ -267,19 +267,19 @@ def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
         raise ValueError("Logging memory stats is not supported on CPU devices")
 
     if device.type == "mps":
-        peak_memory_active = torch.mps.current_allocated_memory() / BYTE_TO_GIB
-        peak_memory_alloc = torch.mps.driver_allocated_memory() / BYTE_TO_GIB
+        peak_memory_active = torch.mps.current_allocated_memory() / _BYTES_IN_GIB
+        peak_memory_alloc = torch.mps.driver_allocated_memory() / _BYTES_IN_GIB
         memory_stats = {
             "peak_memory_active": peak_memory_active,
             "peak_memory_alloc": peak_memory_alloc,
         }
     else:
         torch_device = get_torch_device_namespace()
         peak_memory_active = (
-            torch_device.memory_stats().get("active_bytes.all.peak", 0) / BYTE_TO_GIB
+            torch_device.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
         )
-        peak_memory_alloc = torch_device.max_memory_allocated(device) / BYTE_TO_GIB
-        peak_memory_reserved = torch_device.max_memory_reserved(device) / BYTE_TO_GIB
+        peak_memory_alloc = torch_device.max_memory_allocated(device) / _BYTES_IN_GIB
+        peak_memory_reserved = torch_device.max_memory_reserved(device) / _BYTES_IN_GIB
         memory_stats = {
             "peak_memory_active": peak_memory_active,
             "peak_memory_alloc": peak_memory_alloc,