Commit b0648fd
Step based checkpointing and tests
1 parent c5093b7

5 files changed: 205 additions & 20 deletions


recipes/full_finetune_single_device.py

Lines changed: 24 additions & 6 deletions
@@ -23,6 +23,7 @@
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.checkpointing.constants import CURR_STEP_KEY
 from torchtune.training.lr_schedulers import get_lr
 
 from tqdm import tqdm
@@ -139,6 +140,7 @@ def __init__(self, cfg: DictConfig) -> None:
 
         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
+        self.save_every_n_steps = cfg.get("save_every_n_steps")
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.optimizer_in_bwd
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
@@ -324,6 +326,10 @@ def setup(self, cfg: DictConfig) -> None:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch
 
+        # For now, default to saving at epoch boundaries
+        if self.save_every_n_steps is None:
+            self.save_every_n_steps = self._steps_per_epoch
+
         # Setup lr scheduler
         self._lr_scheduler = self._setup_lr_scheduler(
             cfg_lr_scheduler=cfg.get("lr_scheduler", None),
@@ -596,30 +602,35 @@ def _setup_data(
 
         return sampler, dataloader
 
-    def save_checkpoint(self, epoch: int) -> None:
+    def save_checkpoint(self, *, epoch: int, step: int) -> None:
         """
         Save state dict to file. The recipe save_checkpoint method is responsible for
         correctly creating the checkpoint dict and passing to the checkpointer.
         """
         ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
-        # if training is in-progress, checkpoint the optimizer state as well
-        if epoch + 1 < self.total_epochs:
+
+        # If training is in-progress, checkpoint the optimizer state as well
+        is_intermediate = step < self._steps_per_epoch * self.total_epochs
+        if is_intermediate:
             ckpt_dict.update(
                 {
                     training.SEED_KEY: self.seed,
-                    training.EPOCHS_KEY: self.epochs_run,
+                    training.EPOCHS_KEY: epoch,
                     training.TOTAL_EPOCHS_KEY: self.total_epochs,
+                    CURR_STEP_KEY: step,
                     training.MAX_STEPS_KEY: self.max_steps_per_epoch,
                 }
             )
             if not self._optimizer_in_bwd:
                 ckpt_dict[training.OPT_KEY] = self._optimizer.state_dict()
             else:
                 ckpt_dict[training.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
+
         self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
-            intermediate_checkpoint=(epoch + 1 < self.total_epochs),
+            intermediate_checkpoint=is_intermediate,
+            step=step,
         )
 
     def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
@@ -753,6 +764,13 @@ def train(self) -> None:
                         step=self.global_step,
                     )
 
+                    # Save checkpoint if specified by user
+                    if (
+                        self.global_step > 0
+                        and self.global_step % self.save_every_n_steps == 0
+                    ):
+                        self.save_checkpoint(epoch=curr_epoch, step=self.global_step)
+
                     # Reset running stats for the next step
                     running_loss = 0
                     num_tokens = 0
@@ -776,8 +794,8 @@ def train(self) -> None:
             self._profiler.step()
 
             self.epochs_run += 1
-            self.save_checkpoint(epoch=curr_epoch)
 
+        self.save_checkpoint(epoch=curr_epoch, step=self.global_step)
         self._profiler.stop()
 
     def cleanup(self) -> None:
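To see the save cadence this diff produces, here is a minimal sketch. It is not part of the commit; the parameter names are stand-ins for the recipe attributes (self._steps_per_epoch, self.total_epochs, self.save_every_n_steps).

def checkpoint_steps(steps_per_epoch, total_epochs, save_every_n_steps):
    """Return the global steps at which the recipe would write a checkpoint."""
    total_steps = steps_per_epoch * total_epochs
    # In-loop saves: every save_every_n_steps-th global step (global_step > 0).
    steps = [s for s in range(1, total_steps + 1) if s % save_every_n_steps == 0]
    # One more save always happens after the last epoch. Any save where
    # step == total_steps gets intermediate_checkpoint=False, because
    # is_intermediate checks step < steps_per_epoch * total_epochs.
    if total_steps not in steps:
        steps.append(total_steps)
    return steps

# Matches the test below: 2 steps/epoch * 2 epochs with save_every_n_steps=2
# -> checkpoints at steps 2 and 4.
assert checkpoint_steps(2, 2, 2) == [2, 4]

Note the intermediate checkpoints also carry CURR_STEP_KEY and the optimizer state, so a run can be resumed mid-training from any step_<N> directory.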

tests/recipes/test_full_finetune_single_device.py

Lines changed: 126 additions & 1 deletion
@@ -5,14 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
+import re
 
 import runpy
 
 import sys
 from pathlib import Path
 
 import pytest
-
 import torch
 from tests.common import TUNE_PATH
 
@@ -214,3 +214,128 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
         torch.testing.assert_close(
             loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
         )
+
+    @pytest.mark.integration_test
+    @pytest.mark.parametrize("keep_last_n_checkpoints", [1, 2])
+    @pytest.mark.parametrize("save_every_n_steps", [1, 2])
+    def test_checkpointing_with_steps(
+        self, tmpdir, monkeypatch, keep_last_n_checkpoints, save_every_n_steps
+    ):
+        ckpt = "llama2_hf"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for two epochs (anywhere from 2 -> 4 ckpts)
+        cmd_1 = f"""
+        tune run full_finetune_single_device \
+            --config llama2/7B_full_low_memory \
+            batch_size=8 \
+            output_dir={tmpdir} \
+            checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA2 \
+            checkpointer.keep_last_n_checkpoints={keep_last_n_checkpoints} \
+            save_every_n_steps={save_every_n_steps} \
+            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+            tokenizer.prompt_template=null \
+        """.split()
+        model_config = MODEL_TEST_CONFIGS["llama2"]
+        cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
+        monkeypatch.setattr(sys, "argv", cmd_1)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        regex_to_match = re.compile("step_([0-9]+)")
+        # Iterate over the directory contents, find all directories that match
+        # `regex_to_match`. Assert that the number of directories found is equal
+        # to the `keep_last_n_checkpoints` value. Also assert that each checkpoint
+        # number is a multiple of `save_every_n_steps`.
+        ckpt_dirs = [
+            d
+            for d in os.listdir(tmpdir)
+            if os.path.isdir(os.path.join(tmpdir, d)) and regex_to_match.match(d)
+        ]
+        assert len(ckpt_dirs) == keep_last_n_checkpoints
+        for ckpt_dir in ckpt_dirs:
+            step = int(regex_to_match.match(ckpt_dir).group(1))
+            assert step % save_every_n_steps == 0
+
+        # Also make sure that the last checkpoint has the correct number of steps
+        most_recent_checkpoint = get_largest_iter_folder(tmpdir, pattern=r"^step_(\d+)")
+        step = int(regex_to_match.match(most_recent_checkpoint).group(1))
+        assert step == 4  # 2 epochs * 2 steps per epoch
+
+    @pytest.mark.integration_test
+    def test_checkpointing_with_steps_and_resume(self, tmpdir, monkeypatch):
+        """We want to be sure that now that we use steps, we can resume correctly
+        from a checkpoint. Once we fully transition to steps, we can remove the test above."""
+        # 0. Set up variables
+        ckpt = "llama2_hf"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # 1. Train for two epochs, keep 2 checkpoints
+        cmd_1 = f"""
+        tune run full_finetune_single_device \
+            --config llama2/7B_full_low_memory \
+            batch_size=8 \
+            output_dir={tmpdir} \
+            checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA2 \
+            checkpointer.keep_last_n_checkpoints=2 \
+            save_every_n_steps=2 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+            tokenizer.prompt_template=null \
+        """.split()
+        model_config = MODEL_TEST_CONFIGS["llama2"]
+        cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
+        monkeypatch.setattr(sys, "argv", cmd_1)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # 2. Find the checkpoint at the end of the first epoch
+        step_folder = get_largest_iter_folder(tmpdir, pattern=r"^step_(\d+)")
+        step_folder_at_epoch_boundary = f"step_{int(step_folder.split('_')[-1]) - 2}"
+        suffix = ".safetensors"
+        model_ckpt_fname = (
+            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
+        )
+
+        # 3. Resume training w/ the checkpoint from epoch boundary
+        cmd_2 = f"""
+        tune run full_finetune_single_device \
+            --config llama2/7B_full_low_memory \
+            batch_size=8 \
+            output_dir={tmpdir} \
+            checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
+            checkpointer.checkpoint_dir={ckpt_dir} \
+            checkpointer.checkpoint_files=[{os.path.join(step_folder_at_epoch_boundary, model_ckpt_fname)}]\
+            checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, 'recipe_state.pt')}\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA2 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+            tokenizer.prompt_template=null \
+            resume_from_checkpoint=True \
+            metric_logger.filename={log_file} \
+        """.split()
+        cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
+        monkeypatch.setattr(sys, "argv", cmd_2)
+        with pytest.raises(SystemExit, match=""):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # 4. Make sure loss values match the expected values
+        expected_loss_values = self._fetch_expected_loss_values("llama2")[2:]
+        loss_values = get_loss_values_from_metric_logger(log_file)
+        torch.testing.assert_close(
+            loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
+        )
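Both tests lean on get_largest_iter_folder to locate the newest step_<N> directory. The real helper lives in torchtune's test utilities and is not shown in this commit; the sketch below is a hypothetical reimplementation, included only to make the assertions above self-explanatory, and may differ from the actual util.

import os
import re
from typing import Optional

def largest_iter_folder(output_dir: str, pattern: str = r"^step_(\d+)") -> Optional[str]:
    """Return the folder name under output_dir with the largest captured integer."""
    compiled = re.compile(pattern)
    best, best_iter = None, -1
    for name in os.listdir(output_dir):
        match = compiled.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            iter_num = int(match.group(1))
            if iter_num > best_iter:
                best, best_iter = name, iter_num
    return best  # e.g. "step_4" when step_2 and step_4 both exist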

torchtune/training/checkpointing/_checkpointer.py

Lines changed: 35 additions & 8 deletions
@@ -32,9 +32,11 @@
     check_outdir_not_in_ckptdir,
     copy_files,
     get_adapter_checkpoint_path,
+    get_all_checkpoints_in_dir,
     get_model_checkpoint_path,
     get_recipe_checkpoint_path,
     ModelType,
+    prune_surplus_checkpoints,
     RECIPE_STATE_DIRNAME,
     REPO_ID_FNAME,
     safe_torch_load,
@@ -399,6 +401,7 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
             Default is True.
         should_load_recipe_state (bool): If True, the checkpointer will load the additional checkpoint files corresponding to
             the recipe state from a previous run. Default is False
+        keep_last_n_checkpoints (Optional[int]): How many checkpoints to keep. If None, all checkpoints are kept.
     """
 
     def __init__(
@@ -412,6 +415,8 @@ def __init__(
         resume_from_checkpoint: bool = False,
         safe_serialization: bool = True,
         should_load_recipe_state: bool = False,
+        *,
+        keep_last_n_checkpoints: Optional[int] = None,
     ) -> None:
 
         self._should_load_recipe_state = should_load_recipe_state
@@ -420,6 +425,7 @@
             logger.warning(
                 "*resume_from_checkpoint is deprecated. Please use the 'should_load_recipe_state' instead"
             )
+        self._keep_last_n_checkpoints = keep_last_n_checkpoints
 
         self._safe_serialization = safe_serialization
         self._checkpoint_dir = Path(checkpoint_dir)
@@ -457,7 +463,7 @@
             output_dir=self._output_dir,
             adapter_checkpoint=adapter_checkpoint,
             should_load_recipe_state=self._should_load_recipe_state,
-            pattern=r"^epoch_(\d+)",
+            pattern=r"^step_(\d+)",
         )
 
         # resume recipe_state ckpt
@@ -629,6 +635,8 @@ def save_checkpoint(
         epoch: int,
         intermediate_checkpoint: bool = False,
         adapter_only: bool = False,
+        *,
+        step: Optional[int] = None,
     ) -> None:
         """
         Save HF checkpoint to file. If ``intermediate_checkpoint`` is True, an additional
@@ -644,10 +652,19 @@
            intermediate_checkpoint (bool): If True, additional checkpoint files for recipe state
                and (if applicable) adapter weights are created. Default is False
            adapter_only (bool): If True, only save the adapter weights. Default is False
+           step (Optional[int]): Step number. Used to create the checkpoint file name if provided.
 
        Raises:
            ValueError: if ``adapter_only`` is True and adapter checkpoint not found in state_dict.
        """
+        # Prefer to use step, not epoch
+        if step is not None:
+            ckpt_save_dirname = f"step_{step}"
+            ckpt_pattern = r"^step_(\d+)"
+        else:
+            ckpt_save_dirname = f"epoch_{epoch}"
+            ckpt_pattern = r"^epoch_(\d+)"
+
        # convert the state_dict back to hf format; do this inplace
        if not adapter_only:
            if self._model_type == ModelType.PHI3_MINI:
@@ -747,7 +764,7 @@ def save_checkpoint(
                    )
                    map_original_name_to_new_name[cpt_idx] = shard_name
                    output_path = Path.joinpath(
-                        self._output_dir, f"epoch_{epoch}", shard_name
+                        self._output_dir, ckpt_save_dirname, shard_name
                    )
                    output_path.parent.mkdir(parents=True, exist_ok=True)
                    if not self._safe_serialization:
@@ -779,7 +796,7 @@
                index_file_name = TORCH_INDEX_FNAME
 
                index_path = Path.joinpath(
-                    self._output_dir, f"epoch_{epoch}", index_file_name
+                    self._output_dir, ckpt_save_dirname, index_file_name
                )
 
                index_data = {
@@ -796,7 +813,7 @@
            # convert_weights.peft_to_tune. The .pt format is not needed, but
            # it is an easy way to distinguish the adapters. Ideally we should save only one.
            output_path = Path.joinpath(
-                self._output_dir, f"epoch_{epoch}", ADAPTER_MODEL_FNAME
+                self._output_dir, ckpt_save_dirname, ADAPTER_MODEL_FNAME
            ).with_suffix(".pt")
            output_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(state_dict[training.ADAPTER_KEY], output_path)
@@ -825,7 +842,7 @@
                head_dim=self._config.get("head_dim", None),
            )
            output_path = Path.joinpath(
-                self._output_dir, f"epoch_{epoch}", ADAPTER_MODEL_FNAME
+                self._output_dir, ckpt_save_dirname, ADAPTER_MODEL_FNAME
            )
            output_path.parent.mkdir(parents=True, exist_ok=True)
            if not self._safe_serialization:
@@ -866,7 +883,7 @@
            )
 
            output_path = Path.joinpath(
-                self._output_dir, f"epoch_{epoch}", ADAPTER_CONFIG_FNAME
+                self._output_dir, ckpt_save_dirname, ADAPTER_CONFIG_FNAME
            ).with_suffix(".json")
            with open(output_path, "w") as f:
                json.dump(state_dict[training.ADAPTER_CONFIG], f)
@@ -880,7 +897,7 @@
        # So it's easy to run inference with the model using this epoch's checkpoint
        copy_files(
            self._checkpoint_dir,
-            Path.joinpath(self._output_dir, f"epoch_{epoch}"),
+            Path.joinpath(self._output_dir, ckpt_save_dirname),
            ignore_suffixes=SUFFIXES_TO_NOT_COPY,
        )
 
@@ -901,7 +918,7 @@
                f"saved to {output_path}"
            )
        else:
-            logger.info("Saving final epoch checkpoint.")
+            logger.info("Saving final checkpoint.")
            if adapter_only:
                logger.info(
                    "Please note that you have set adapter_only=True, so only adapter weights will be saved."
@@ -914,6 +931,16 @@
                "You can now use this checkpoint for further training or inference."
            )
 
+        # If specified, prune the checkpoints in the output directory
+        if self._keep_last_n_checkpoints is not None:
+            all_current_checkpoints = get_all_checkpoints_in_dir(
+                self._output_dir, pattern=ckpt_pattern
+            )
+            prune_surplus_checkpoints(
+                all_current_checkpoints,
+                keep_last_n_checkpoints=self._keep_last_n_checkpoints,
+            )
+
 
 class FullModelMetaCheckpointer(_CheckpointerInterface):
     """
