Support rope cache indexing using positions #2112
Conversation
tianyu-l
left a comment
Thanks! May I ask: for the dsv3 16B test on dp4tp2 before vs. after, did you explicitly pass positions into the model? We should try both (1, seq_len) and (batch_size, seq_len) inputs (they could be the trivial 0 -> seq_len - 1 ids).
Also had an inline comment.
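For reference, a rough sketch of what that explicit-positions check could look like -- the `model(tokens, positions=...)` call signature and the vocab size are assumptions for illustration, not the actual torchtitan entry point:

```python
import torch

batch_size, seq_len = 4, 2048

# Trivial 0 -> seq_len - 1 ids, broadcastable shape (1, seq_len).
positions_broadcast = torch.arange(seq_len).unsqueeze(0)

# Same ids materialized per sample, shape (batch_size, seq_len).
positions_per_sample = torch.arange(seq_len).expand(batch_size, seq_len)

tokens = torch.randint(0, 32000, (batch_size, seq_len))  # 32000 = placeholder vocab size

# Both calls should reproduce the default (positions=None) forward output:
# out_a = model(tokens, positions=positions_broadcast)
# out_b = model(tokens, positions=positions_per_sample)
```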
| "attention": prepare_module_input( | ||
| input_layouts=(Shard(1), Replicate(), None), | ||
| desired_input_layouts=(Replicate(), Replicate(), None), | ||
| input_layouts=(Shard(1), Replicate(), None, None), |
Note that when positions is not None, this makes the implicit assumption that positions already has the expected sharding when it's used, namely: sharded on the batch dim by DP, replicated on the TP mesh, and sharded on the seq dim by CP.
I don't have a good solution right now -- making it Replicate by default will fail here when positions is None: https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L521
But clearly this leaves a footgun. I'd suggest we add a comment for now.
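For context, a minimal sketch of the annotation being discussed, using torch.distributed.tensor.parallel.PrepareModuleInput directly. The placements mirror the diff above; extending desired_input_layouts with a matching None for the new positions slot is my assumption so the tuple lengths line up:

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

# A None entry means "leave this argument alone": it is not converted to a
# DTensor, which is what lets an optional argument (positions=None) pass
# through without error.
#
# The footgun: when positions *is* passed, nothing here verifies that it
# already carries the expected sharding (batch dim by DP, replicated on the
# TP mesh, seq dim by CP) -- the plan silently assumes it.
attention_input_plan = PrepareModuleInput(
    input_layouts=(Shard(1), Replicate(), None, None),
    desired_input_layouts=(Replicate(), Replicate(), None, None),
)
```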
Yeah, when testing positions of shape [1, seq_len] and [bz, seq_len] on dp4tp2, I need to manually change both layouts to Replicate(). But for the default case it should be None.
For CP, I think we need to manually change https://fburl.com/v2rn2s48.
For FSDP, I'm not sure how the sharding info is specified today, but it looks like it's already handled?
Will add a comment for now.
Updated the loss graph. I also tested both cases with vLLM inference, and the text output is the same.
tianyu-l
left a comment
LGTM! Thank you!
Add support for indexing the rope cache using position_ids; this might be needed when passing position_ids into the transformer forward (a minimal sketch is included after the test results below).

Test:

running dpskv3 16b base
Also tested in https://github.com/wwwjn/torchtitan/pull/1/files when passing position_ids:
