diff --git a/nemo/collections/llm/modelopt/model_utils.py b/nemo/collections/llm/modelopt/model_utils.py index 86f981baa628..70209580323f 100644 --- a/nemo/collections/llm/modelopt/model_utils.py +++ b/nemo/collections/llm/modelopt/model_utils.py @@ -234,6 +234,7 @@ def restore_modelopt_state( mto.plugins.restore_sharded_modelopt_state( [core_model], ckpt_to_weights_subdir(path, is_saving=False), + prefix="module.", ) if mto.ModeloptStateManager.is_converted(core_model): logging.info("Restored Model Optimizer state from checkpoint.")