
Conversation


@bogdansalyp bogdansalyp commented Jul 7, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Closes #2105. This is a widely requested feature that allows users to have greater control over checkpointing frequency in torchtune.

TODO: Add commentary on design decisions. Acknowledge spaghetti code. Beg forgiveness.

Changelog

  • Update `FullModelHFCheckpointer` to accept a `step` parameter when saving a checkpoint, and use that step to name the checkpoint folder. Keep `epoch_{}` as a fallback for backwards compatibility (BC).
  • Modify the `full_finetune_single_device.py` recipe to use step-based checkpointing.
  • Add tests for the `full_finetune_single_device.py` recipe with step-based checkpointing.
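The folder-naming behavior described above can be sketched roughly as follows. This is a minimal illustration; `checkpoint_dirname` is a hypothetical helper, not the actual torchtune API:

```python
from pathlib import Path
from typing import Optional


def checkpoint_dirname(output_dir: str, epoch: int, step: Optional[int] = None) -> Path:
    """Return the checkpoint folder path: step_{N} when a step is given,
    falling back to epoch_{N} for backwards compatibility."""
    name = f"step_{step}" if step is not None else f"epoch_{epoch}"
    return Path(output_dir) / name
```

Recipes that pass `step` get the new step-based layout; older callers that only pass `epoch` keep their existing folder names.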

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Evidence that the correct number of checkpoints is saved

(joe-torchtune) [[email protected] ~/projects/joe-torchtune (impl-step-based-ckpt)]$ ls /tmp/torchtune/llama3_2_1B/full_single_device/
step_100  step_125  step_150  step_175  step_200  step_25  step_50  step_75  torchtune_config.yaml
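With `save_every_n_steps: 25` and 200 total training steps, the expected folder set can be derived directly (note that `ls` sorts lexicographically, which is why `step_100` appears before `step_25` above):

```python
save_every_n_steps = 25
total_steps = 200

# A checkpoint lands at every multiple of save_every_n_steps up to total_steps.
expected = [
    f"step_{s}"
    for s in range(save_every_n_steps, total_steps + 1, save_every_n_steps)
]
print(len(expected), expected[0], expected[-1])  # 8 step_25 step_200
```

Eight step folders, matching the `ls` output.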

Evidence of correct resumption from checkpoint mid-epoch
[screenshot taken 2025-02-28 at 4:59:52 PM]

Evidence of correct resumption from checkpoint at an epoch boundary
[screenshot taken 2025-02-28 at 5:00:19 PM]

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

joecummings and others added 30 commits February 27, 2025 13:41
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: salman <[email protected]>
@bogdansalyp bogdansalyp marked this pull request as ready for review July 10, 2025 16:57
@codecov-commenter

Codecov Report

Attention: Patch coverage is 17.20648% with 409 lines in your changes missing coverage. Please review.

Project coverage is 59.24%. Comparing base (7d30d4c) to head (243a6ab).
Report is 4 commits behind head on main.

Files with missing lines                              | Patch % | Missing lines
torchtune/training/checkpointing/_checkpointer.py     | 49.49%  | 50
...htune/training/checkpointing/_checkpoint_client.py | 0.00%   | 47
recipes/full_finetune_single_device.py                | 0.00%   | 46
recipes/full_finetune_distributed.py                  | 0.00%   | 38
tests/recipes/test_full_finetune_single_device.py     | 21.62%  | 29
recipes/lora_finetune_single_device.py                | 0.00%   | 25
recipes/knowledge_distillation_single_device.py       | 0.00%   | 21
tests/recipes/test_qat_distributed.py                 | 19.23%  | 21
recipes/full_dpo_distributed.py                       | 0.00%   | 20
recipes/knowledge_distillation_distributed.py         | 0.00%   | 18
... and 16 more
Additional details and impacted files
@@            Coverage Diff             @@
##            main    #2869       +/-   ##
==========================================
+ Coverage   5.12%   59.24%   +54.12%     
==========================================
  Files        375      439       +64     
  Lines      22956    27407     +4451     
==========================================
+ Hits        1177    16238    +15061     
+ Misses     21779    11169    -10610     


@bogdansalyp bogdansalyp changed the title [DEBUG] Step-based checkpointing fixes Implement step based checkpointing Jul 15, 2025
@bogdansalyp bogdansalyp marked this pull request as draft July 15, 2025 18:12
@bogdansalyp bogdansalyp marked this pull request as ready for review July 15, 2025 18:40
@joecummings (Member) left a comment:
Mostly nits, just one question about "epoch" usage in checkpointer

model_type: LLAMA3_2
keep_last_n_checkpoints: 2
resume_from_checkpoint: False
save_every_n_steps: 25
Member:
These were just test values, can probably remove

loss:
_component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: null
max_steps_per_epoch: 100
Member:
Same here, just a test value


if adapter_only:
save_path = dcp_saver.output_dir
if dir_prefix == "step":
Member:
Might be worth a comment noting that this is fairly hacky because we need to maintain BC with epoch-based folder names, and that it could potentially be refactored.

adapter_only: bool = False,
single_device: bool = False,
*,
full_tensors: bool = True,
Member:
Maybe add a comment in the docstring explaining what full_tensors are and when you might want to use them.
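A docstring addition along the lines the reviewer suggests might read as follows. The wording is illustrative only, and the described semantics of `full_tensors` (gathering sharded weights before saving) are an assumption, not the merged text:

```python
def save_checkpoint(state_dict: dict, *, full_tensors: bool = True) -> None:
    """Save a model checkpoint.

    Args:
        full_tensors (bool): If True, gather sharded (DTensor) weights into
            full tensors before saving, so the checkpoint can be loaded
            outside the original distributed setup. If False, save the
            shards as-is, which is faster and lighter on memory but ties
            the checkpoint to the same parallelism layout. Default: True.
    """
    # Body elided; this sketch only illustrates the suggested docstring.
```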

_ = state_dict.pop(training.ADAPTER_CONFIG, None)
output_path = Path.joinpath(
self._output_dir, RECIPE_STATE_DIRNAME, "recipe_state.pt"
self._output_dir, f"epoch_{epoch}", "recipe_state.pt"
Member:
Why just epoch here?

@bogdansalyp (Collaborator, Author):
It's a debugger torchtune checkpointer, not a HF one

@bogdansalyp bogdansalyp merged commit e43b6e6 into meta-pytorch:main Jul 15, 2025
14 checks passed

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
