
Conversation

@joecummings
Member

@joecummings joecummings commented Feb 18, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

This PR adds support for the StatefulDataLoader class from the PyTorch library torchdata.

FAQs

  1. Why? This is necessary for resuming from checkpoints mid-epoch, a capability we will need for step-based checkpointing.
  2. Only full finetune single device? Yeah, this is more to see how we'll integrate this as a POC and make sure that the tests can pass. Once this is merged, I'll add support for the rest of the recipes.
  3. Hardcoding iterator_finished? Yeah, we'll change this when we actually move to step-based checkpointing. For now we expect to save on epoch boundaries, so if the epoch is cut short, the dataloader should restart its shuffling and data order as if the iterator had finished going through all the samples. Huge, huge thanks to @ramanishsingh for helping me debug this last issue.
  4. You removed the sampler?!! The StatefulDataLoader creates a batched random sampler of its own, so there's no need for us to create a new one in this context. Less code means my life is easier.
  5. WTH, you removed the check for max_steps is None. Won't that break? Nope, Python is dumb. Check it:
>>> 1 == None
False
>>> 0 == None
False
>>> 5 == None
False
>>> 1000000000000000 == None
False
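The REPL session above generalizes to the training loop: the step counter can be compared directly against max_steps, and the None case falls out for free because an int never equals None. A minimal hypothetical sketch (run_steps is illustrative, not the recipe's actual code):

```python
# Hypothetical sketch: comparing an int to None with == is always False,
# so a loop guarded by `step + 1 == max_steps` simply never stops early
# when max_steps is None -- no explicit `max_steps is None` check needed.
def run_steps(num_batches, max_steps=None):
    completed = 0
    for step in range(num_batches):
        completed += 1
        if step + 1 == max_steps:  # always False when max_steps is None
            break
    return completed

full = run_steps(10)                # runs all 10 batches
capped = run_steps(10, max_steps=3) # stops after 3
```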

Changelog

What are the changes made in this PR?

  • Import StatefulDataLoader and use it in place of DataLoader in the _setup_data method
  • Checkpoint the dataloader state dict
  • Load the dataloader state dict if we are trying to resume from the checkpoint
  • Update tests to match (yes the numbers changed b/c we are using a new random state - it's fine)
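The checkpoint/resume flow in the changelog can be sketched with a tiny stdlib stand-in that mimics the state_dict()/load_state_dict() interface of torchdata's StatefulDataLoader (the class, the "dataloader" key, and the batching logic here are all illustrative, not the recipe's code):

```python
# Minimal stand-in mimicking the StatefulDataLoader interface this PR relies
# on; the real class lives in torchdata.stateful_dataloader.
class TinyStatefulLoader:
    def __init__(self, data, batch_size):
        self.data = list(data)
        self.batch_size = batch_size
        self._pos = 0  # current position within the epoch

    def __iter__(self):
        while self._pos < len(self.data):
            batch = self.data[self._pos:self._pos + self.batch_size]
            self._pos += self.batch_size
            yield batch

    def state_dict(self):
        return {"pos": self._pos}

    def load_state_dict(self, sd):
        self._pos = sd["pos"]

loader = TinyStatefulLoader(range(6), batch_size=2)
it = iter(loader)
first = next(it)  # consume one batch mid-epoch

# Checkpoint: capture the loader's position alongside the rest of the state
# ("dataloader" stands in for training.DATALOADER_KEY)
ckpt = {"dataloader": loader.state_dict()}

# Resume: a fresh loader picks up exactly where the checkpoint left off
resumed = TinyStatefulLoader(range(6), batch_size=2)
resumed.load_state_dict(ckpt["dataloader"])
remaining = list(resumed)
```

The real StatefulDataLoader additionally restores sampler and RNG state, which is what makes mid-epoch resumption with shuffling work.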

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any of these, just ask and we will happily help. We also have a contributing page for guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

@pytorch-bot

pytorch-bot bot commented Feb 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2410

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c7decd7 with merge base 952078e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 18, 2025
@joecummings joecummings force-pushed the add-support-for-stateful-dl branch from 22e007c to fe0380d Compare February 18, 2025 20:07
batch_size=cfg.batch_size,
collate_fn=collate_name,
dataloader_state_dict=(
ckpt_dict[training.DATALOADER_KEY]
Collaborator
should check if this key even exists for BC

Member Author

This might just have to break BC b/c without this, the user will not be able to successfully resume training at any point.
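For reference, the backward-compatible lookup suggested above would be a small change; this is a hypothetical sketch (get_dataloader_state and the literal "dataloader" key are stand-ins for the recipe's ckpt_dict[training.DATALOADER_KEY] access):

```python
# Hypothetical BC-friendly variant of the key access discussed above.
DATALOADER_KEY = "dataloader"  # stand-in for training.DATALOADER_KEY

def get_dataloader_state(ckpt_dict):
    # dict.get returns None when a pre-PR checkpoint lacks the key, so old
    # checkpoints still load; the mid-epoch position is simply lost and
    # training restarts from the epoch boundary.
    return ckpt_dict.get(DATALOADER_KEY)

with_state = get_dataloader_state({"model": {}, "dataloader": {"pos": 40}})
without_state = get_dataloader_state({"model": {}})
```

Whether silently losing the mid-epoch position is acceptable, versus failing loudly as the merged code does, is exactly the trade-off debated in this exchange.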

num_replicas=1,
rank=0,
shuffle=shuffle,
seed=0,
Collaborator

so how is seed passed to StatefulDataLoader in this case?
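One plausible answer, assuming StatefulDataLoader mirrors torch.utils.data.DataLoader's constructor: DataLoader accepts a generator argument, so the seed can flow in as generator=torch.Generator().manual_seed(seed) rather than through an explicit sampler. A pure-stdlib stand-in of that seeded-shuffle pattern (make_shuffled_batches is hypothetical):

```python
# Stdlib stand-in for seeding a shuffled dataloader without a sampler;
# random.Random(seed) plays the role of torch.Generator().manual_seed(seed).
import random

def make_shuffled_batches(data, batch_size, seed):
    rng = random.Random(seed)  # seeded RNG drives the shuffle
    order = list(data)
    rng.shuffle(order)
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

run_a = make_shuffled_batches(range(8), 4, seed=0)
run_b = make_shuffled_batches(range(8), 4, seed=0)
# Same seed -> identical, reproducible batch order across runs
```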

@codecov-commenter

Codecov Report

Attention: Patch coverage is 5.88235% with 16 lines in your changes missing coverage. Please review.

Project coverage is 23.16%. Comparing base (e6cba25) to head (c7decd7).
Report is 5 commits behind head on main.

Files with missing lines Patch % Lines
recipes/full_finetune_single_device.py 0.00% 11 Missing ⚠️
tests/recipes/test_full_finetune_single_device.py 0.00% 5 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2410       +/-   ##
===========================================
- Coverage   63.87%   23.16%   -40.72%     
===========================================
  Files         368      379       +11     
  Lines       21873    22706      +833     
===========================================
- Hits        13971     5259     -8712     
- Misses       7902    17447     +9545     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@joecummings joecummings merged commit 7b654ea into meta-pytorch:main Feb 24, 2025
17 checks passed
joecummings added a commit to joecummings/torchtune that referenced this pull request Feb 27, 2025
joecummings added a commit to joecummings/torchtune that referenced this pull request Feb 27, 2025
pbontrager pushed a commit to pbontrager/torchtune that referenced this pull request Mar 17, 2025
pbontrager pushed a commit that referenced this pull request Mar 17, 2025