diff --git a/nemo/collections/llm/modelopt/speculative/model_transform.py b/nemo/collections/llm/modelopt/speculative/model_transform.py index c30ab1a759f1..9a165d8d87f0 100644 --- a/nemo/collections/llm/modelopt/speculative/model_transform.py +++ b/nemo/collections/llm/modelopt/speculative/model_transform.py @@ -16,17 +16,19 @@ from nemo.collections.llm import GPTModel from nemo.utils import logging -from nemo.utils.import_utils import safe_import +from nemo.utils.import_utils import UnavailableError, safe_import from nemo.utils.model_utils import unwrap_model mto, HAVE_MODELOPT = safe_import("modelopt.torch.opt") mtsp, _ = safe_import("modelopt.torch.speculative") - -ALGORITHMS = { - "eagle3": mtsp.EAGLE3_DEFAULT_CFG, - # more TBD -} +try: + ALGORITHMS = { + "eagle3": mtsp.EAGLE3_DEFAULT_CFG, + # more TBD + } +except UnavailableError: + ALGORITHMS = {} def apply_speculative_decoding(model: nn.Module, algorithm: str = "eagle3") -> nn.Module: diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py index 02cc39975d2a..2fa6873b0d5d 100644 --- a/nemo/lightning/nemo_logger.py +++ b/nemo/lightning/nemo_logger.py @@ -22,12 +22,14 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint from lightning.pytorch.loggers import Logger, TensorBoardLogger, WandbLogger -from nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback import LocalCheckpointCallback from nemo.lightning.io.mixin import IOMixin from nemo.lightning.pytorch.callbacks import ModelCheckpoint from nemo.utils import logging from nemo.utils.app_state import AppState +from nemo.utils.import_utils import safe_import + +LocalCheckpointCallback, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback') @dataclass