
Commit 5508dfb

chtruong814 and guyueh1 authored and committed
Fix "Safely import optional python packages (NVIDIA-NeMo#13936)" (NVIDIA-NeMo#14198)
* Revert "Revert "Safely import optional python packages (NVIDIA-NeMo#13936)" (NVIDIA-NeMo#14197)"

  This reverts commit 808845b.

  Signed-off-by: Charlie Truong <[email protected]>

* Fix LocalCheckpointCallback safe import

  Signed-off-by: Charlie Truong <[email protected]>

---------

Signed-off-by: Charlie Truong <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>
1 parent 05ab161 commit 5508dfb
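
For context on the pattern this commit relies on: `safe_import` from `nemo.utils.import_utils` returns a `(module, available)` pair instead of raising `ImportError` when an optional dependency is missing, as the diffs below show. A minimal usage sketch, assuming only that two-value contract (the package name and helper function here are illustrative, not NeMo code):

```python
# Minimal sketch of the safe_import consumption pattern used in this commit.
# "some_optional_pkg" and run_feature() are illustrative names, not NeMo code.
from nemo.utils.import_utils import safe_import

# Returns (module, True) if importable, otherwise (placeholder, False);
# crucially, this line itself does not raise ImportError.
pkg, HAVE_PKG = safe_import("some_optional_pkg")


def run_feature():
    if not HAVE_PKG:
        raise RuntimeError("some_optional_pkg is required for this feature")
    # Only touch attributes of the module once we know it really imported.
    return pkg.do_something()
```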

2 files changed: +13 −9 lines changed

nemo/collections/llm/modelopt/speculative/model_transform.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -16,17 +16,19 @@
 
 from nemo.collections.llm import GPTModel
 from nemo.utils import logging
-from nemo.utils.import_utils import safe_import
+from nemo.utils.import_utils import UnavailableError, safe_import
 from nemo.utils.model_utils import unwrap_model
 
 mto, HAVE_MODELOPT = safe_import("modelopt.torch.opt")
 mtsp, _ = safe_import("modelopt.torch.speculative")
 
-
-ALGORITHMS = {
-    "eagle3": mtsp.EAGLE3_DEFAULT_CFG,
-    # more TBD
-}
+try:
+    ALGORITHMS = {
+        "eagle3": mtsp.EAGLE3_DEFAULT_CFG,
+        # more TBD
+    }
+except UnavailableError:
+    ALGORITHMS = {}
 
 
 def apply_speculative_decoding(model: nn.Module, algorithm: str = "eagle3") -> nn.Module:
```
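
The try/except `UnavailableError` added above matters because, when modelopt is not installed, `mtsp` is a placeholder object and reading `mtsp.EAGLE3_DEFAULT_CFG` raises `UnavailableError` rather than returning a config. A toy reproduction of that behaviour, written from scratch (`_MissingModule` only mimics the assumed placeholder; it is not NeMo's implementation):

```python
# Toy illustration of the failure mode the new try/except guards against.
# _MissingModule only mimics the assumed behaviour of safe_import's placeholder.
class UnavailableError(Exception):
    """Stand-in for nemo.utils.import_utils.UnavailableError."""


class _MissingModule:
    def __init__(self, name: str):
        self._name = name

    def __getattr__(self, attr: str):
        raise UnavailableError(f"{self._name} is not installed; cannot read {attr!r}")


mtsp = _MissingModule("modelopt.torch.speculative")

try:
    ALGORITHMS = {"eagle3": mtsp.EAGLE3_DEFAULT_CFG}  # raises UnavailableError
except UnavailableError:
    ALGORITHMS = {}  # fall back to an empty registry, as the patch does

print(ALGORITHMS)  # -> {}
```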

nemo/lightning/nemo_logger.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -22,12 +22,14 @@
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint
 from lightning.pytorch.loggers import Logger, TensorBoardLogger, WandbLogger
-from nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback import LocalCheckpointCallback
 
 from nemo.lightning.io.mixin import IOMixin
 from nemo.lightning.pytorch.callbacks import ModelCheckpoint
 from nemo.utils import logging
 from nemo.utils.app_state import AppState
+from nemo.utils.import_utils import safe_import
+
+lcp, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback')
 
 
 @dataclass
@@ -203,7 +205,7 @@ def _setup_trainer_model_checkpoint(self, trainer, log_dir, ckpt=None):
         if ckpt:
             _overwrite_i = None
             for i, callback in enumerate(trainer.callbacks):
-                if isinstance(callback, PTLModelCheckpoint) and not isinstance(callback, LocalCheckpointCallback):
+                if isinstance(callback, PTLModelCheckpoint) and not isinstance(callback, lcp.LocalCheckpointCallback):
                     logging.warning(
                         "The Trainer already contains a ModelCheckpoint callback. " "This will be overwritten."
                     )
@@ -248,7 +250,7 @@ def _setup_trainer_model_checkpoint(self, trainer, log_dir, ckpt=None):
         from nemo.lightning import MegatronStrategy
 
         for callback in trainer.callbacks:
-            if isinstance(callback, PTLModelCheckpoint) and not isinstance(callback, LocalCheckpointCallback):
+            if isinstance(callback, PTLModelCheckpoint) and not isinstance(callback, lcp.LocalCheckpointCallback):
                 if callback.dirpath is None:
                     callback.dirpath = Path(log_dir / "checkpoints")
                 if callback.filename is None:
```
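
In `nemo_logger.py` the hard import is replaced the same way: `lcp` is the module handle (or placeholder) and `HAVE_RES` records whether `nvidia_resiliency_ext` actually resolved, with `LocalCheckpointCallback` now reached as `lcp.LocalCheckpointCallback`. A hedged sketch of one way such a pair can be consumed; the helper below is illustrative and not part of the patched file, which accesses `lcp.LocalCheckpointCallback` directly inside its isinstance checks:

```python
# Illustrative helper (not in nemo_logger.py) combining the HAVE_RES flag with
# the lcp handle produced by safe_import in this commit.
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint

from nemo.utils.import_utils import safe_import

lcp, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback')


def overwritable_checkpoints(callbacks: list) -> list:
    """Return the ModelCheckpoint callbacks NeMo may overwrite, skipping
    local-checkpoint callbacks when the resiliency extension is available."""
    if not HAVE_RES:
        # Package absent: no callback can be a LocalCheckpointCallback,
        # so the plain ModelCheckpoint check is enough.
        return [cb for cb in callbacks if isinstance(cb, PTLModelCheckpoint)]
    return [
        cb
        for cb in callbacks
        if isinstance(cb, PTLModelCheckpoint)
        and not isinstance(cb, lcp.LocalCheckpointCallback)
    ]
```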

0 commit comments