Implement step based checkpointing #2869
Conversation
…d get resume working w/ StatefulDataLoader
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: salman <[email protected]>
…chtune into fix/torchtune_ckpt_tests
Codecov Report
Attention: Patch coverage is

@@            Coverage Diff             @@
##             main    #2869       +/-   ##
==========================================
+ Coverage    5.12%   59.24%   +54.12%
==========================================
  Files         375      439       +64
  Lines       22956    27407     +4451
==========================================
+ Hits         1177    16238    +15061
+ Misses      21779    11169    -10610
joecummings left a comment
Mostly nits, just one question about "epoch" usage in checkpointer
model_type: LLAMA3_2
keep_last_n_checkpoints: 2
resume_from_checkpoint: False
save_every_n_steps: 25
These were just test values, can probably remove
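For context on how the two config knobs above interact: a minimal sketch of saving every `save_every_n_steps` steps and keeping only the newest `keep_last_n_checkpoints` directories. This is illustrative only; `maybe_checkpoint` and its pruning logic are hypothetical, not the recipe's actual implementation.

```python
import os
import shutil


def maybe_checkpoint(step, output_dir, save_every_n_steps=25, keep_last_n_checkpoints=2):
    """Save a step-based checkpoint dir and prune older ones.

    Sketch only: the real recipe delegates saving to the checkpointer;
    here we just create `step_{N}` directories to show the cadence.
    """
    if step % save_every_n_steps != 0:
        return None
    ckpt_dir = os.path.join(output_dir, f"step_{step}")
    os.makedirs(ckpt_dir, exist_ok=True)
    # Prune: keep only the newest `keep_last_n_checkpoints` step dirs
    step_dirs = sorted(
        (d for d in os.listdir(output_dir) if d.startswith("step_")),
        key=lambda d: int(d.split("_")[1]),
    )
    for old in step_dirs[:-keep_last_n_checkpoints]:
        shutil.rmtree(os.path.join(output_dir, old))
    return ckpt_dir
```

With the test values above (save every 25 steps, keep 2), a 100-step run ends with only `step_75` and `step_100` on disk.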
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
- max_steps_per_epoch: null
+ max_steps_per_epoch: 100
Same here, just a test value
if adapter_only:
    save_path = dcp_saver.output_dir
    if dir_prefix == "step":
Might be worth a comment that this is fairly hacky because we need to infer BC with epochs, and it could potentially be refactored
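To illustrate the BC inference being flagged here: a hypothetical helper (`parse_ckpt_dirname` is not part of the PR) that recovers whether a checkpoint directory came from the new `step_{N}` scheme or the legacy `epoch_{N}` scheme:

```python
import re


def parse_ckpt_dirname(name):
    """Return ("step" | "epoch", index) for a checkpoint directory name.

    Hypothetical helper sketching the backwards-compat inference:
    new checkpoints use `step_{N}`, older ones use `epoch_{N}`.
    """
    m = re.fullmatch(r"(step|epoch)_(\d+)", name)
    if m is None:
        raise ValueError(f"unrecognized checkpoint dir: {name}")
    return m.group(1), int(m.group(2))
```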
adapter_only: bool = False,
single_device: bool = False,
*,
full_tensors: bool = True,
Maybe add a note in the docstring explaining what full_tensors does and when you might want to use it.
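A possible shape for that docstring, assuming `full_tensors` controls whether sharded (DTensor) state is gathered into full tensors before writing; this is suggested wording under that assumption, not the merged text:

```python
def save_checkpoint(state_dict, *, full_tensors: bool = True):
    """Save a checkpoint to disk.

    Args:
        full_tensors (bool): if True, gather sharded (DTensor) weights into
            full tensors before saving, so the checkpoint can be loaded on
            any world size (e.g. for publishing or single-device inference).
            If False, each rank saves only its local shards, which is faster
            and uses less memory but ties the checkpoint to the saving
            topology. Default: True.
    """
    # Sketch: the docstring is the point here, not the body.
    pass
```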
_ = state_dict.pop(training.ADAPTER_CONFIG, None)
output_path = Path.joinpath(
-    self._output_dir, RECIPE_STATE_DIRNAME, "recipe_state.pt"
+    self._output_dir, f"epoch_{epoch}", "recipe_state.pt"
Why just epoch here?
It's a debugger torchtune checkpointer, not a HF one
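The epoch_{} path in the diff above is the legacy fallback for the naming rule this PR introduces. A tiny hypothetical illustration (`ckpt_dirname` is not code from the PR): use the new `step_{N}` scheme when a step is provided, else fall back to `epoch_{N}` for backwards compatibility.

```python
def ckpt_dirname(epoch, step=None):
    """Pick the checkpoint folder name.

    Sketch of the naming rule: prefer step-based names, keep
    epoch-based names as a fallback for older checkpoints.
    """
    return f"step_{step}" if step is not None else f"epoch_{epoch}"
```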
Context
What is the purpose of this PR?
Closes #2105. This is a widely requested feature that allows users to have greater control over checkpointing frequency in torchtune.
TODO: Add commentary on design decisions. Acknowledge spaghetti code. Beg forgiveness.
Changelog
- Update FullModelHFCheckpointer to accept a step parameter when saving a checkpoint. Use that step to designate the checkpoint folder name. Keep epoch_{} as a fall-back for BC.
- Update the full_finetune_single_device.py recipe to utilize step-based checkpointing.

Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first run pre-commit install)
- run pytest tests
- run pytest tests -m integration_test

Evidence of correct number of checkpoints being saved
Evidence of correct resuming from ckpt mid-epoch

Evidence of correct resuming from ckpt at epoch boundary

UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example