Skip to content

Conversation

@cdoern
Copy link
Contributor

@cdoern cdoern commented Jun 4, 2025

** PLEASE NOTE, THIS PR INCLUDES THE CHANGES IN #572, AND WILL BE REDUCED IN SIZE ONCE THAT MERGES **

Introduce a new design for key components of main_ds.py. Namely splitting Model initialization, Accelerator initialization, Optimizer initialization, and Checkpoint saving initialization into classes. This commit introduces the Accelerator class

The Accelerator class aims to both store commonly accessed variables associated with the accelerated model and abstract model/optimizer mutation away from the user who should only access our Model and Optimizer classes.

These classes are one of a few steps needed to "SDK-ify" the training library

Adding structure to code via classes can either be someone's favorite or least favorite thing. So I figured I'd explain myself before continuing. Here is my rationale:

Classes provide logical structuring to code, especially code meant to be a publicly consumable SDK and allows you to associate related objects and methods with one another.

Being able to group functionality under the Model, Accelerator, and Checkpointer classes inherently reduces code complexity and duplication. Being able to store things like , self.distributed_framework,self.lora_config, etc in a way such that within the class they are accessible within different methods allows the arguments per method to go down drastically, as well as complex return values. Simpler methods and argument/return values allows for simpler testing of code.

@mergify mergify bot added testing Relates to testing ci-failure labels Jun 4, 2025
@cdoern cdoern force-pushed the refactor-accelerator branch from 1de1d23 to 25230a7 Compare June 4, 2025 17:26
@mergify mergify bot removed the ci-failure label Jun 4, 2025
@cdoern cdoern force-pushed the refactor-accelerator branch 2 times, most recently from aa7af32 to 81f731e Compare June 4, 2025 18:15
@mergify mergify bot added ci-failure and removed ci-failure labels Jun 4, 2025
cdoern added 2 commits June 4, 2025 17:22
Accelerator works with Model to abstract common utilities behind a custom class that allows users to seamlessly setup their model for training

Signed-off-by: Charlie Doern <[email protected]>
Signed-off-by: Charlie Doern <[email protected]>
@cdoern cdoern force-pushed the refactor-accelerator branch from 81f731e to bf4be9f Compare June 4, 2025 21:22
@github-actions
Copy link

github-actions bot commented Jun 4, 2025

E2E (NVIDIA L40S x4) (python 3.11) workflow launched on this PR: View run

@github-actions
Copy link

github-actions bot commented Jun 5, 2025

e2e workflow succeeded on this PR: View run, congrats!

class Accelerator:
def __init__(
self,
model: Model,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Model acts as a "factory" class that creates nn.module once the from_pretrained method is called in the case of each; Liger, Dolomite and the normal transformer model. That is what we should pass into the Accelerator so that we can avoid the weird model.model references.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so, I think by not passing in model: Model we lose a lot of the seamless nature of these classes. things like self.model.lora_config are not possible within the Accelerator class if model is type hinted to nn.module, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, that's correct. I guess we need to refine the model class more for that to happen; approving rn in the spirit of getting the refactor in quickly.

Copy link
Collaborator

@thisisatharva-rh thisisatharva-rh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to have a larger conversation about what the Model class should look like, but for now, this looks good.

@mergify mergify bot added the one-approval label Jun 5, 2025
@mergify mergify bot merged commit 3a2dcc1 into instructlab:main Jun 5, 2025
18 checks passed
@mergify mergify bot removed the one-approval label Jun 5, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

testing Relates to testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants