
Commit bfb205b

kunlunl authored and ko3n1g committed
Load master weights from checkpoint (#14072)
* Load master weights from checkpoint
* Change the default behavior to not load the main parameters from the checkpoint.
* Add tests to cover all the added code
* Apply isort and black reformatting

Signed-off-by: kunlunl <[email protected]>
Signed-off-by: ko3n1g <[email protected]>
Co-authored-by: ko3n1g <[email protected]>
1 parent: 9882d12

File tree

3 files changed: +46 -4 lines changed

nemo/core/optim/mcore_optim.py

Lines changed: 5 additions & 2 deletions
@@ -64,11 +64,14 @@ def zero_grad(self, set_to_none: bool = True):
         """
         self.mcore_optimizer.zero_grad(set_to_none)
 
-    def reload_model_params(self):
+    def reload_model_params(self, state_dict=None):
         """
         Reloads model parameters from the optimizer.
         """
-        self.mcore_optimizer.reload_model_params()
+        if state_dict is None:
+            self.mcore_optimizer.reload_model_params()
+        else:
+            self.mcore_optimizer.reload_model_params(state_dict=state_dict)
 
     def state_dict(self):
         """

nemo/lightning/pytorch/strategies/megatron_strategy.py

Lines changed: 13 additions & 1 deletion
@@ -170,6 +170,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
         ckpt_type (TrainerCkptProtocol): Checkpoint type. Defaults to TrainerCheckpoint.
         ckpt_load_optimizer (bool): Load optimizer state from trainer.ckpt_path. Defaults to True.
         ckpt_save_optimizer (bool): Save optimizer states in checkpoint. Defaults to True.
+        ckpt_load_main_params (bool): Load main parameters from trainer.ckpt_path. Defaults to False.
         ddp (Union[DDPLiteral, DistributedDataParallelConfig]): DDP configuration. Defaults to "megatron".
         fsdp (Optional[FSDPLiteral]): Option of using torch FSDP2, select from ["megatron", "pytorch"].
             Defaults to None.
@@ -257,6 +258,7 @@ def __init__(
         find_unused_parameters: bool = False,
         ckpt_load_optimizer: bool = True,
         ckpt_save_optimizer: bool = True,
+        ckpt_load_main_params: bool = False,
         ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
         fsdp: Optional[FSDPLiteral] = None,
         lazy_init: bool = False,
@@ -319,6 +321,7 @@ def __init__(
         self.lazy_init = lazy_init
         self.ckpt_load_optimizer = ckpt_load_optimizer
         self.ckpt_save_optimizer = ckpt_save_optimizer
+        self.ckpt_load_main_params = ckpt_load_main_params
         self.ckpt_load_strictness = ckpt_load_strictness
         self.use_te_rng_tracker = use_te_rng_tracker
         self.use_sharp = use_sharp
@@ -391,6 +394,9 @@ def __init__(
         else:
             raise ValueError(f"Invalid DDP type: {ddp}")
 
+        if self.ckpt_load_optimizer and self.ckpt_load_main_params:
+            raise ValueError("ckpt_load_optimizer and ckpt_load_main_params cannot be both set to True.")
+
         if isinstance(self.ddp_config, DistributedDataParallelConfig):
             self.ddp_config.num_distributed_optimizer_instances = self.num_distributed_optimizer_instances
 
@@ -1052,7 +1058,13 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
 
         if not 'optimizer' in checkpoint:
             for opt in self.optimizers:
-                opt.reload_model_params()
+                if self.ckpt_load_main_params:
+                    if "state_dict" in checkpoint:
+                        opt.reload_model_params(checkpoint["state_dict"])
+                    else:
+                        opt.reload_model_params(checkpoint)
+                else:
+                    opt.reload_model_params()
 
     @property
     @override
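Taken together, a hedged usage sketch of the new option (only the two checkpoint flags are shown; the surrounding trainer wiring is assumed and may differ by NeMo version):

    from nemo.lightning.pytorch.strategies import MegatronStrategy

    # Resume master (fp32) weights from the checkpoint without restoring the full
    # optimizer state; the two flags are mutually exclusive per the check added above.
    strategy = MegatronStrategy(
        ckpt_load_optimizer=False,
        ckpt_load_main_params=True,
    )

With `ckpt_load_main_params=True`, `load_model_state_dict` passes the checkpoint's `state_dict` (or the checkpoint itself when that key is absent) to each optimizer's `reload_model_params`.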

tests/lightning/pytorch/strategies/test_megatron_strategy.py

Lines changed: 28 additions & 1 deletion
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
+
+import pytest
 
 from nemo.lightning.pytorch.strategies import MegatronStrategy
 
@@ -40,3 +42,28 @@ class Dummy: ...
 
         assert first_io != second_io
         assert second_io == strategy2.checkpoint_io
+
+    def test_ckpt_load_main_params_and_ckpt_load_optimizer_both_true(self):
+        # Make sure ckpt_load_optimizer and ckpt_load_main_params cannot be both set to True.
+        with pytest.raises(ValueError):
+            strategy = MegatronStrategy(ckpt_load_optimizer=True, ckpt_load_main_params=True)
+
+    def test_ckpt_load_main_params_with_state_dict(self):
+        # Test ckpt_load_main_params with "state_dict" key.
+        strategy = MegatronStrategy()
+        strategy.ckpt_load_main_params = True
+        strategy.megatron_parallel = MagicMock()
+        strategy.optimizers = [MagicMock()]
+        checkpoint = {"state_dict": {"param": 1}}
+        strategy.load_model_state_dict(checkpoint)
+        strategy.optimizers[0].reload_model_params.assert_called_once_with(checkpoint["state_dict"])
+
+    def test_ckpt_load_main_params_without_state_dict(self):
+        # Test ckpt_load_main_params without "state_dict" key.
+        strategy = MegatronStrategy()
+        strategy.ckpt_load_main_params = True
+        strategy.megatron_parallel = MagicMock()
+        strategy.optimizers = [MagicMock()]
+        checkpoint = {"other": 123}
+        strategy.load_model_state_dict(checkpoint)
+        strategy.optimizers[0].reload_model_params.assert_called_once_with(checkpoint)
