[simplefsdp] fix & enable DSV3 manual bucketing #2080
base: main
Conversation
Force-pushed from f931aa9 to 88b700b
| ),
| "16B": DeepSeekV3ModelArgs(
|     vocab_size=102400,
|     dim=2048,
@tianyu-l Should we have another config to allow users to turn on/off flex attention? Currently, flex attention doesn't work well with AC here. cc @soulitzer for the AC issue follow-up!
what was the symptom?
also if it doesn't work why do we add an entry for it -- is it for repro?
No, it's the 16B_flexatten config that doesn't work. But in the current DSV3 implementation, flex attention is turned on by default. I want to have a model config that turns off flex attention by default.
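For illustration, a minimal sketch of what such an entry could look like, assuming a `use_flex_attn` field on the model args (the field name, default, and stand-in class here are placeholders, not this PR's actual change):

```python
from dataclasses import dataclass

# Placeholder stand-in for the real DeepSeekV3ModelArgs; field names
# and defaults are assumptions for illustration only.
@dataclass
class DeepSeekV3ModelArgs:
    vocab_size: int = 102400
    dim: int = 2048
    use_flex_attn: bool = True  # assumed current default

# A 16B entry that falls back to SDPA so AC can be applied.
deepseekv3_configs = {
    "16B": DeepSeekV3ModelArgs(
        vocab_size=102400,
        dim=2048,
        use_flex_attn=False,
    ),
}
```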
| for m in modules:
|     if isinstance(m, list):
| -       result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
| +       if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
What does the syntax mean -- assigning to fqn_list and checking it's not empty? It feels a bit unusual to read.
Also please add a comment on why we need this check.
yes, added a comment for it.
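For context, the walrus operator (`:=`) assigns the recursive result and tests its truthiness in one expression. A self-contained sketch of how the helper could read with the added comment (the surrounding function body is an assumption for illustration, not the exact code in this PR):

```python
def convert_modules_to_fqns(modules, module_to_fqn_mapping):
    """Map (possibly nested) module objects to their FQN strings."""
    result = []
    for m in modules:
        if isinstance(m, list):
            # `:=` assigns the recursive result to fqn_list, and the `if`
            # checks it is non-empty, so nested lists whose modules have
            # no FQN mapping don't add empty sublists to the result.
            if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
                result.append(fqn_list)
        else:
            fqn = module_to_fqn_mapping.get(m)
            if fqn is not None:
                result.append(fqn)
    return result
```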
Force-pushed from 88b700b to 35ad842
Force-pushed from 35ad842 to 3b0fdda
| VIEW_OPS = {
|     torch.ops.aten.slice.Tensor,
|     torch.ops.aten.view.default,
|     torch.ops.aten.reshape.default,
|     torch.ops.aten.transpose.int,
| }
@bdhirsh Following today's discussion, I updated reshard_after_fwd to enforce that all VIEW_OPS after the wait are recomputed.
In the fx graph in tlparse (link), view_63-view_65 + transpose_8 are forced to be recomputed. Thus, I can successfully get the correct reshard_after_fwd semantics.
However, in the backward, the _grouped_mm is also recomputed, seemingly because we enforce this region to be MUST_RECOMPUTE. I feel like I'm in a rabbit hole: if I don't recompute transpose_8, I don't get correct FSDP semantics; but if I do recompute transpose_8, the follow-up _grouped_mm is recomputed as well.
Wonder if you think I should fix this from the simplefsdp side or the partitioner side? 🤔
_to_copy_32: "bf16[1, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten._to_copy.default(view_58, dtype = torch.bfloat16); view_58 = None
all_gather_into_tensor_19: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(_to_copy_32, 4, '1'); _to_copy_32 = None
wait_tensor_22: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19); all_gather_into_tensor_19 = None
view_63: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten.view.default(wait_tensor_22, [4, 256, 256]); wait_tensor_22 = None
view_64: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten.view.default(view_63, [4, 256, 256]); view_63 = None
view_65: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten.view.default(view_64, [4, 256, 256]); view_64 = None
transpose_8: "bf16[4, 256, 256][65536, 1, 256]cuda:0" = torch.ops.aten.transpose.int(view_65, -2, -1); view_65 = None
_grouped_mm: "bf16[8*(((u2 + u3 + 39)//8)), 256][256, 1]cuda:0" = torch.ops.aten._grouped_mm.default(index_1, transpose_8, cumsum_2); transpose_8 = None
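For reference, a minimal sketch of the kind of pass being described: walking forward from each wait_tensor and tagging the trailing view ops as must-recompute. The meta key and CheckpointPolicy enum follow PyTorch's selective-AC convention, but whether simplefsdp hooks in exactly this way is an assumption here, not this PR's exact implementation:

```python
import torch
from torch.utils.checkpoint import CheckpointPolicy

# Mirrors the VIEW_OPS set above.
VIEW_OPS = {
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.view.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.transpose.int,
}

def force_recompute_views_after_wait(graph: torch.fx.Graph) -> None:
    """Tag view ops that consume an all-gather's wait_tensor output as
    must-recompute, so the unsharded parameter views are not saved for
    backward (i.e. reshard_after_fwd semantics)."""
    for node in graph.nodes:
        if node.target != torch.ops._c10d_functional.wait_tensor.default:
            continue
        # Follow the chain of view-like users (view_63 -> ... -> transpose_8)
        # and mark each one; the partitioner then recomputes them in backward.
        stack = list(node.users)
        while stack:
            user = stack.pop()
            if user.op == "call_function" and user.target in VIEW_OPS:
                user.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
                stack.extend(user.users)
```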
Validate DSV3 manual bucketing when EP/TP are enabled. Tested on the DSV3-16B model. Dependent on a PyTorch PR.
(Single Node: BS = 1)