@ruisizhang123 (Member) commented on Nov 24, 2025

Validate DSV3 manual bucketing when EP/TP are enabled. Tested on the DSV3-16B model. Depends on a PyTorch PR.

(Single Node: BS = 1)

| Node | Method | Parallelism | Memory | TPS | Trace |
|---|---|---|---|---|---|
| 1-Node (8H100) | SimpleFSDP (aot_eager) | FSDP=4 EP=2 | 51.11 GiB (53.80%) | 5,136 | Link |
| 1-Node (8H100) | FSDP2-eager | FSDP=4 EP=2 | 59.54 GiB (62.68%) | 5,942 | Link |
| 1-Node (8H100) | SimpleFSDP (aot_eager) | FSDP=2 TP=2 EP=2 | 42.21 GiB (44.43%) | 2,285 | Link |
| 1-Node (8H100) | FSDP2-eager | FSDP=2 TP=2 EP=2 | 45.41 GiB (47.80%) | 2,349 | Link |
| 8-Node (64H100) | SimpleFSDP (aot_eager) | FSDP=4 EP=2 | | | Link |
| 8-Node (64H100) | FSDP2-eager | FSDP=4 EP=2 | | | Link |
| 8-Node (64H100) | SimpleFSDP (aot_eager) | FSDP=2 TP=2 EP=2 | | | Link |
| 8-Node (64H100) | FSDP2-eager | FSDP=2 TP=2 EP=2 | | | Link |
1. Example Trace

[Screenshot: example trace, 2025-12-10 7:51 PM]

@meta-cla bot added the CLA Signed label on Nov 24, 2025
@ruisizhang123 marked this pull request as draft on November 24, 2025 17:19
@ruisizhang123 force-pushed the ruisi/fix_manual_bucketing_dsv3 branch from f931aa9 to 88b700b on December 11, 2025 05:23
@ruisizhang123 marked this pull request as ready for review on December 11, 2025 05:24
),
"16B": DeepSeekV3ModelArgs(
    vocab_size=102400,
    dim=2048,

Member Author

@tianyu-l Should we have another config to allow users to turn FlexAttention on/off? Currently, FlexAttention doesn't work well with AC here. cc @soulitzer for the AC issue follow-up!

Contributor

what was the symptom?

also if it doesn't work why do we add an entry for it -- is it for repro?

Member Author

No, it's 16B_flexatten that doesn't work. But in the current DSV3 implementation, FlexAttention is turned on by default. I want to have a model config that turns FlexAttention off by default.
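
For context, a minimal sketch of what such a toggle could look like, assuming a boolean field on the model args; the field name use_flex_attn and the flavor contents here are illustrative, not the actual torchtitan definitions:

```python
# Illustrative sketch only -- assumes a `use_flex_attn` flag on the model args;
# the real DeepSeekV3ModelArgs in torchtitan has many more fields.
from dataclasses import dataclass


@dataclass
class DeepSeekV3ModelArgs:
    vocab_size: int = 102400
    dim: int = 2048
    # Assumed flag: lets a flavor opt out of FlexAttention when it interacts
    # badly with activation checkpointing (AC).
    use_flex_attn: bool = True


deepseekv3_configs = {
    # Default 16B flavor with FlexAttention disabled, so AC works.
    "16B": DeepSeekV3ModelArgs(vocab_size=102400, dim=2048, use_flex_attn=False),
    # Separate flavor kept around to reproduce the FlexAttention + AC issue.
    "16B_flexatten": DeepSeekV3ModelArgs(vocab_size=102400, dim=2048, use_flex_attn=True),
}
```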

for m in modules:
    if isinstance(m, list):
-       result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
+       if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):

Contributor

What does the syntax mean -- assigning to fqn_list and checking it's not None? It feels a bit unusual to read.

Also please add a comment on why we need this check

Member Author

yes, added a comment for it.
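
For readers following along, here is a hypothetical reconstruction of the helper with such a comment; the non-list branch and the stated rationale are assumptions based on the diff snippet and the discussion above, not the PR's exact code:

```python
# Illustrative reconstruction; the actual implementation in the PR may differ.
def convert_modules_to_fqns(modules, module_to_fqn_mapping):
    """Map (possibly nested lists of) modules to their FQN strings."""
    result = []
    for m in modules:
        if isinstance(m, list):
            # Recurse into nested module groups. Only keep the converted group
            # if it is non-empty: modules missing from module_to_fqn_mapping
            # would otherwise leave empty lists in the bucketing plan. The
            # walrus operator `:=` assigns and truth-tests in one step.
            if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
                result.append(fqn_list)
        elif m in module_to_fqn_mapping:
            result.append(module_to_fqn_mapping[m])
    return result
```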

@ruisizhang123 force-pushed the ruisi/fix_manual_bucketing_dsv3 branch from 88b700b to 35ad842 on December 12, 2025 01:08
@ruisizhang123 force-pushed the ruisi/fix_manual_bucketing_dsv3 branch from 35ad842 to 3b0fdda on December 12, 2025 01:12
Comment on lines +69 to +75
VIEW_OPS = {
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.view.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.transpose.int,
}

Member Author

@bdhirsh Following today's discussion, I updated reshard_after_fwd to enforce that all VIEW_OPS after the wait are recomputed.

In this fx graph in tlparse (link), view_63-view_65 + transpose_8 are forced to be recomputed, so I get the correct reshard_after_fwd semantics.

However, in the backward, the _grouped_mm is also recomputed, seemingly because we enforce this region to be MUST_RECOMPUTE. I feel like I'm in a rabbit hole: if I don't recompute transpose_8, I don't get correct FSDP semantics; but if I do recompute transpose_8, the follow-up _grouped_mm gets recomputed as well.

Do you think I should fix this on the SimpleFSDP side or on the partitioner side? 🤔

 _to_copy_32: "bf16[1, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten._to_copy.default(view_58, dtype = torch.bfloat16);  view_58 = None
        
all_gather_into_tensor_19: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(_to_copy_32, 4, '1');  _to_copy_32 = None

wait_tensor_22: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19);  all_gather_into_tensor_19 = None
        
view_63: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten.view.default(wait_tensor_22, [4, 256, 256]);  wait_tensor_22 = None
        
view_64: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten.view.default(view_63, [4, 256, 256]);  view_63 = None
        
view_65: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten.view.default(view_64, [4, 256, 256]);  view_64 = None

transpose_8: "bf16[4, 256, 256][65536, 1, 256]cuda:0" = torch.ops.aten.transpose.int(view_65, -2, -1);  view_65 = None
        
_grouped_mm: "bf16[8*(((u2 + u3 + 39)//8)), 256][256, 1]cuda:0" = torch.ops.aten._grouped_mm.default(index_1, transpose_8, cumsum_2);  transpose_8 = None
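
As a reference for the discussion above, a minimal sketch of the kind of graph walk involved: collect the view-op chain downstream of each wait_tensor so those nodes can be marked for recompute. The function name and how the result would be consumed (i.e. how MUST_RECOMPUTE is actually recorded by SimpleFSDP or the partitioner) are assumptions, and this does not resolve the _grouped_mm over-recompute question.

```python
# Sketch only: shows the graph walk, not SimpleFSDP's actual implementation.
# Assumes a PyTorch build with distributed (for the _c10d_functional ops).
import torch

VIEW_OPS = {
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.view.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.transpose.int,
}

WAIT_OP = torch.ops._c10d_functional.wait_tensor.default


def collect_views_after_wait(graph: "torch.fx.Graph"):
    """Return view-op nodes reachable from a wait_tensor through view ops only.

    Recomputing these nodes keeps the unsharded (all-gathered) parameter from
    being saved for backward, which is what reshard_after_forward needs.
    """
    to_recompute, seen = [], set()
    for node in graph.nodes:
        if node.op == "call_function" and node.target == WAIT_OP:
            stack = list(node.users)
            while stack:
                user = stack.pop()
                if user in seen:
                    continue
                seen.add(user)
                if user.op == "call_function" and user.target in VIEW_OPS:
                    to_recompute.append(user)
                    # Keep following the chain (e.g. view_63 -> view_64 ->
                    # view_65 -> transpose_8 in the excerpt above).
                    stack.extend(user.users)
    return to_recompute
```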
