feat: add post-training and custom-training support#31

Merged
pjannaty merged 10 commits into main from qianlim/posttrain
Apr 17, 2025

Conversation


@qianlim qianlim commented Apr 10, 2025

This PR adds support for training the cosmos-transfer models. It supports single- and multi-node training with Tensor Parallel and Sequence Parallel, and covers both training customized models from scratch and post-training / fine-tuning from the released checkpoints.

It addresses Issue #3.

Since a large number of files are added and the model classes are updated (to enable training support), careful review and testing are needed.

So far I've verified that the generated config YAML aligns with the one we used to train the released models. I'll launch training jobs to verify the correctness of the implementation.

[WIP]: test the training script, improve the README.

@qianlim qianlim added the enhancement New feature or request label Apr 10, 2025
@qianlim qianlim self-assigned this Apr 10, 2025
@qianlim qianlim requested a review from arieling April 10, 2025 17:59
@qianlim qianlim requested a review from tiffanycai6 April 15, 2025 03:33
if isinstance(control_weight, torch.Tensor):
    if control_weight.ndim == 0:  # Single scalar tensor
        control_weight = [float(control_weight)]
        control_weight = [float(control_weight)] * len(guided_hints)
Contributor Author

@tcwang0509 This line differs between the current repo and i4. Could you advise?

Contributor

Sure we can add it. Currently when ndim==0, len(guided_hints) is always 1, but it's good to make it general.
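The agreed-upon generalization can be sketched as follows. This is an illustrative version only: the helper name `normalize_control_weight` is hypothetical, and numpy stands in for torch here (`ndarray.ndim` behaves like `Tensor.ndim` for the 0-dim check).

```python
import numpy as np

def normalize_control_weight(control_weight, num_hints):
    """Hypothetical helper: expand a 0-dim (scalar) array into a
    per-hint list of floats, mirroring the generalized branch above."""
    if isinstance(control_weight, np.ndarray) and control_weight.ndim == 0:
        # One weight per guided hint, even when a single scalar is given.
        control_weight = [float(control_weight)] * num_hints
    return control_weight

# A 0-dim array broadcasts to one entry per hint:
print(normalize_control_weight(np.array(0.5), 3))  # [0.5, 0.5, 0.5]
```

As noted, with `len(guided_hints) == 1` both variants coincide; the multiplied form simply stays correct if multiple hints are ever passed.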

# Reshape to match THWBD format
weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1]
hint_val = control_feat * weight_map * gate
weight_map = weight_map.view(T * H * W, 1, 1, B, 1)
Contributor Author

@tcwang0509 Also this line: i4 has an additional step of reshaping the weight_map, but it was removed in this repo. Is it needed?

Contributor

This is because in i4 we use TP but here we use CP.
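For illustration, the shape changes in the snippet above can be traced with numpy. The starting layout `[B, 1, T, H, W]` is an assumption inferred from the permute indices and the inline comment; numpy's `transpose`/`reshape` stand in for torch's `permute`/`view`:

```python
import numpy as np

# Assumed input layout: [B, 1, T, H, W] (inferred, not confirmed by the PR).
B, T, H, W = 2, 4, 8, 8
weight_map = np.ones((B, 1, T, H, W), dtype=np.float32)

# permute(2, 3, 4, 0, 1): move T, H, W to the front -> [T, H, W, B, 1].
weight_map = weight_map.transpose(2, 3, 4, 0, 1)
print(weight_map.shape)  # (4, 8, 8, 2, 1)

# The extra i4 step flattens the spatial-temporal axes into one
# (numpy reshape plays the role of torch view here).
weight_map = weight_map.reshape(T * H * W, 1, 1, B, 1)
print(weight_map.shape)  # (256, 1, 1, 2, 1)
```

The element count is unchanged by both steps; only the axis layout differs, which is why the extra reshape matters under one parallelism scheme (TP in i4) but not the other (CP here).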

@pjannaty
Contributor

Thanks Qianli. Can we please trigger CI?

@pjannaty

pjannaty commented Apr 17, 2025

LGTM. Thank you, Qianli!

CI tests pass


@pjannaty pjannaty left a comment

Let's add training tests in a follow up. Merging for now.

@pjannaty pjannaty merged commit 68d5d67 into main Apr 17, 2025
2 checks passed
atmguille pushed a commit to atmguille/cosmos-transfer1 that referenced this pull request Jul 16, 2025
* feat: add post-training and custom-training support

* feat: add separate model definitions supporting tp/sp for training; update configs

* feat: add example Dataset class, add data augmentors, update config

* feat: add example data class, add misc improvements to data loading and config, add script to convert ckpt to tp

* fix: fix conflict in DiTEncoder

* cleanup

* feat: complete README in examples/ for post/pre-training; update the main README

* fix: multiple minor fixes on example dataset

* fix: multiple minor fixes + improve example dataset performance

* feat+fix: multiple fixes + refinements to README
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

3 participants