Skip to content

Commit 70b2ddf

Browse files
authored
fix for MCore dist ckpt loading (#14229)
Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>
1 parent c6c71de commit 70b2ddf

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

nemo/collections/speechlm/models/speech_to_text_llm_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def _freeze_module(self, module: Optional[nn.Module] = None) -> None:
195195
for param in module.parameters():
196196
param.requires_grad = False
197197

198-
def _maybe_load_pretrained_llm(self, model: MCoreGPTModel) -> MCoreGPTModel:
198+
def _maybe_load_pretrained_llm(self, model: MCoreGPTModel, strict: bool = False) -> MCoreGPTModel:
199199
if not self.language_model_from_pretrained:
200200
return model
201201

@@ -224,6 +224,7 @@ def _maybe_load_pretrained_llm(self, model: MCoreGPTModel) -> MCoreGPTModel:
224224
sharded_state_dict=sharded_state_dict,
225225
checkpoint_dir=ckpt_to_weights_subdir(ckpt_path, is_saving=False),
226226
validate_access_integrity=False,
227+
**({"strict": "log_all"} if not strict else {}),
227228
)
228229
loaded_state_dict = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()}
229230
model.load_state_dict(loaded_state_dict)

0 commit comments

Comments
 (0)