Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions nemo/collections/llm/modelopt/speculative/model_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,19 @@

from nemo.collections.llm import GPTModel
from nemo.utils import logging
from nemo.utils.import_utils import safe_import
from nemo.utils.import_utils import UnavailableError, safe_import
from nemo.utils.model_utils import unwrap_model

mto, HAVE_MODELOPT = safe_import("modelopt.torch.opt")
mtsp, _ = safe_import("modelopt.torch.speculative")


ALGORITHMS = {
"eagle3": mtsp.EAGLE3_DEFAULT_CFG,
# more TBD
}
try:
ALGORITHMS = {
"eagle3": mtsp.EAGLE3_DEFAULT_CFG,
# more TBD
}
except UnavailableError:
ALGORITHMS = {}


def apply_speculative_decoding(model: nn.Module, algorithm: str = "eagle3") -> nn.Module:
Expand Down
4 changes: 3 additions & 1 deletion nemo/lightning/nemo_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint
from lightning.pytorch.loggers import Logger, TensorBoardLogger, WandbLogger
from nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback import LocalCheckpointCallback

from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.import_utils import safe_import

LocalCheckpointCallback, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback')


@dataclass
Expand Down
Loading