Skip to content

Commit ac9fd08

Browse files
dimapihtargenquan9
authored andcommitted
add support for parallel ckpt removal (NVIDIA-NeMo#15073)
* add support for parallel ckpt removal Signed-off-by: dimapihtar <dpihtar@gmail.com> * Apply isort and black reformatting Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com> --------- Signed-off-by: dimapihtar <dpihtar@gmail.com> Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com> Co-authored-by: dimapihtar <dimapihtar@users.noreply.github.com> Signed-off-by: genquan9 <genquan@google.com>
1 parent 8de0007 commit ac9fd08

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

nemo/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import re
1717
import shutil
18+
import threading
1819
from datetime import timedelta
1920
from pathlib import Path
2021
from typing import Any, Dict, Iterable, List, Literal, Optional, Union
@@ -707,7 +708,16 @@ def _remove_checkpoint(self, trainer: "lightning.pytorch.Trainer", filepath: str
707708
# if anything goes wrong during removal, we should be able to detect that data is incomplete.
708709
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
709710
try:
710-
super()._remove_checkpoint(trainer, filepath)
711+
if self.async_save:
712+
threading.Thread(
713+
target=super()._remove_checkpoint,
714+
args=(
715+
trainer,
716+
filepath,
717+
),
718+
).start()
719+
else:
720+
super()._remove_checkpoint(trainer, filepath)
711721
except Exception as e:
712722
logging.warning(
713723
f'Error removing checkpoint, common if doing manual cleanup and restarting: {filepath}: {e}'
@@ -718,7 +728,16 @@ def _remove_checkpoint(self, trainer: "lightning.pytorch.Trainer", filepath: str
718728

719729
filepath = self._ema_format_filepath(filepath)
720730
try:
721-
super()._remove_checkpoint(trainer, filepath)
731+
if self.async_save:
732+
threading.Thread(
733+
target=super()._remove_checkpoint,
734+
args=(
735+
trainer,
736+
filepath,
737+
),
738+
).start()
739+
else:
740+
super()._remove_checkpoint(trainer, filepath)
722741
except Exception as e:
723742
logging.warning(
724743
f'Error removing checkpoint, common if doing manual cleanup and restarting: {filepath}: {e}'

0 commit comments

Comments
 (0)