
Commit 2128559

mori360 authored
Fix optimizer_in_backward at loading opt_state_dict in distributed recipes (#2390)
Co-authored-by: Salman Mohammadi <[email protected]> Co-authored-by: Evan Smothers <[email protected]>
1 parent cb83655 commit 2128559
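
The change is the same in all four recipes touched below: when optimizer_in_bwd is enabled, optimizer state is restored through `self._optim_ckpt_wrapper.optim_map[param]` (the live per-parameter optimizer object) instead of `self._optim_ckpt_wrapper.state_dict()[param]` (an already-serialized fragment). The toy sketch below only illustrates that distinction; `ToyOptimInBwdWrapper` is a hypothetical stand-in rather than the torchtune class, and the real recipes go through `training.load_from_full_optimizer_state_dict` rather than calling `load_state_dict` directly.

    import torch
    from torch import nn

    # Hypothetical stand-in for an optimizer-in-backward wrapper; not the torchtune class.
    class ToyOptimInBwdWrapper:
        def __init__(self, model: nn.Module):
            # One optimizer per parameter, keyed by parameter name.
            self.optim_map = {
                name: torch.optim.AdamW([p]) for name, p in model.named_parameters()
            }

        def state_dict(self):
            # Serialized view: plain nested dicts, not optimizer objects.
            return {name: opt.state_dict() for name, opt in self.optim_map.items()}

    model = nn.Linear(4, 4)
    wrapper = ToyOptimInBwdWrapper(model)
    saved = wrapper.state_dict()  # e.g. what a checkpoint would hand back on resume

    for name, opt_sd in saved.items():
        # Correct target: the live per-parameter optimizer held in optim_map.
        wrapper.optim_map[name].load_state_dict(opt_sd)
        # Wrong target: wrapper.state_dict()[name] is only a serialized fragment
        # (a plain dict), not an object that state can be loaded into.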

6 files changed (+28 lines, -8 lines)

recipes/dev/early_exit_finetune_distributed.py

Lines changed: 1 addition & 1 deletion
@@ -610,7 +610,7 @@ def _setup_optimizer(
                 for param in opt_state_dict.keys():
                     try:
                         training.load_from_full_optimizer_state_dict(
-                            self._optim_ckpt_wrapper.state_dict()[param],
+                            self._optim_ckpt_wrapper.optim_map[param],
                             opt_state_dict[param],
                             self._device,
                         )

recipes/full_dpo_distributed.py

Lines changed: 1 addition & 1 deletion
@@ -619,7 +619,7 @@ def _setup_optimizer(
                     try:
                         training.load_from_full_optimizer_state_dict(
                             self._model,
-                            self._optim_ckpt_wrapper.state_dict()[param],
+                            self._optim_ckpt_wrapper.optim_map[param],
                             opt_state_dict[param],
                             self._device,
                         )

recipes/full_finetune_distributed.py

Lines changed: 1 addition & 1 deletion
@@ -656,7 +656,7 @@ def _setup_optimizer(
                     try:
                         training.load_from_full_optimizer_state_dict(
                             self._model,
-                            self._optim_ckpt_wrapper.state_dict()[param],
+                            self._optim_ckpt_wrapper.optim_map[param],
                             opt_state_dict[param],
                             self._device,
                         )

recipes/qat_distributed.py

Lines changed: 1 addition & 1 deletion
@@ -564,7 +564,7 @@ def _setup_optimizer(
                     try:
                         training.load_from_full_optimizer_state_dict(
                             self._model,
-                            self._optim_ckpt_wrapper.state_dict()[param],
+                            self._optim_ckpt_wrapper.optim_map[param],
                             opt_state_dict[param],
                             self._device,
                         )

tests/recipes/test_full_finetune_distributed.py

Lines changed: 17 additions & 3 deletions
@@ -267,6 +267,7 @@ def test_loss_single_rank(
         "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
         [
             ("llama3/8B_full", "llama3", "tune", 1, 4, False),
+            ("llama3/8B_full", "llama3", "tune", 4, 1, True),
         ],
     )
     @gpu_test(gpu_count=2)
@@ -306,9 +307,17 @@ def test_training_state_on_resume(
             checkpointer.model_type={model_type.upper()} \
             tokenizer.path='{tokenizer_path}' \
             tokenizer.prompt_template=null \
-            clip_grad_norm=100 \
         """.split()

+        # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing
+        # wrong grad_norm, so we only test one of them each time. But loss values
+        # should be the same.
+        if not optim_in_bwd:
+            cmd_1.append("clip_grad_norm=100")
+            cmd_1.append("optimizer_in_bwd=False")
+        else:
+            cmd_1.append("optimizer_in_bwd=True")
+
         model_config = MODEL_TEST_CONFIGS[model_type]
         cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config

@@ -337,12 +346,17 @@ def test_training_state_on_resume(
             tokenizer.path='{tokenizer_path}' \
             tokenizer.prompt_template=null \
             resume_from_checkpoint=True \
-            metric_logger.filename={log_file} \
-            clip_grad_norm=100 \
+            metric_logger.filename={log_file}
         """.split()

         cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config

+        if not optim_in_bwd:
+            cmd_2.append("clip_grad_norm=100")
+            cmd_2.append("optimizer_in_bwd=False")
+        else:
+            cmd_2.append("optimizer_in_bwd=True")
+
         monkeypatch.setattr(sys, "argv", cmd_2)
         runpy.run_path(TUNE_PATH, run_name="__main__")

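The test comment above spells out why clip_grad_norm and optimizer_in_bwd are exercised separately: with the optimizer running inside the backward pass, each parameter's gradient is stepped on and freed as soon as it has accumulated, so a post-backward gradient-norm clip has nothing left to operate on. A minimal sketch of that interaction, using PyTorch's register_post_accumulate_grad_hook (the generic PyTorch pattern, not torchtune's wrapper):

    import torch
    from torch import nn

    # Minimal sketch of optimizer-in-backward (PyTorch >= 2.1); not torchtune's wrapper.
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
    optim_map = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters()}

    def step_in_backward(param: torch.Tensor) -> None:
        # Fires once this parameter's gradient is fully accumulated:
        # take the optimizer step immediately, then free the gradient.
        optim_map[param].step()
        optim_map[param].zero_grad()

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(step_in_backward)

    loss = model(torch.randn(2, 8)).sum()
    loss.backward()

    # By the time backward() returns, every .grad has already been released, so a
    # post-hoc clip has nothing to clip -- hence the tests enable only one of
    # optimizer_in_bwd / clip_grad_norm per run.
    print(all(p.grad is None for p in model.parameters()))          # True
    print(torch.nn.utils.clip_grad_norm_(model.parameters(), 100))  # tensor(0.)
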
tests/recipes/test_full_finetune_single_device.py

Lines changed: 7 additions & 1 deletion
@@ -137,7 +137,11 @@ def test_loss(
         )

     @pytest.mark.integration_test
-    def test_training_state_on_resume(self, tmpdir, monkeypatch):
+    @pytest.mark.parametrize(
+        "optimizer_in_bwd",
+        [True, False],
+    )
+    def test_training_state_on_resume(self, tmpdir, monkeypatch, optimizer_in_bwd):
         """Test whether the recipe state is correctly updated on resume. Since this
         is model agnostic, we should run this on the small model only. The test
         consists of three stages:
@@ -169,6 +173,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
             checkpointer.model_type=LLAMA2 \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
+            optimizer_in_bwd={optimizer_in_bwd} \
         """.split()

         model_config = MODEL_TEST_CONFIGS["llama2"]
@@ -200,6 +205,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
             tokenizer.prompt_template=null \
             resume_from_checkpoint=True \
             metric_logger.filename={log_file} \
+            optimizer_in_bwd={optimizer_in_bwd} \
         """.split()

         cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config

0 commit comments
