Skip to content

Commit ca0e9ec

Browse files
chtruong814roclark
authored andcommitted
Revert "Safely import optional python packages (NVIDIA-NeMo#13936)" (NVIDIA-NeMo#14197)
This reverts commit e04ee42. Signed-off-by: Charlie Truong <[email protected]> Co-authored-by: Robert Clark <[email protected]> Signed-off-by: Amir Hussein <[email protected]>
1 parent 5af22b1 commit ca0e9ec

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

nemo/collections/llm/modelopt/speculative/model_transform.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,17 @@
1616

1717
from nemo.collections.llm import GPTModel
1818
from nemo.utils import logging
19-
from nemo.utils.import_utils import UnavailableError, safe_import
19+
from nemo.utils.import_utils import safe_import
2020
from nemo.utils.model_utils import unwrap_model
2121

2222
mto, HAVE_MODELOPT = safe_import("modelopt.torch.opt")
2323
mtsp, _ = safe_import("modelopt.torch.speculative")
2424

25-
try:
26-
ALGORITHMS = {
27-
"eagle3": mtsp.EAGLE3_DEFAULT_CFG,
28-
# more TBD
29-
}
30-
except UnavailableError:
31-
ALGORITHMS = {}
25+
26+
ALGORITHMS = {
27+
"eagle3": mtsp.EAGLE3_DEFAULT_CFG,
28+
# more TBD
29+
}
3230

3331

3432
def apply_speculative_decoding(model: nn.Module, algorithm: str = "eagle3") -> nn.Module:

nemo/lightning/nemo_logger.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,12 @@
2222
import lightning.pytorch as pl
2323
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint
2424
from lightning.pytorch.loggers import Logger, TensorBoardLogger, WandbLogger
25+
from nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback import LocalCheckpointCallback
2526

2627
from nemo.lightning.io.mixin import IOMixin
2728
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
2829
from nemo.utils import logging
2930
from nemo.utils.app_state import AppState
30-
from nemo.utils.import_utils import safe_import
31-
32-
LocalCheckpointCallback, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback')
3331

3432

3533
@dataclass

0 commit comments

Comments
 (0)