From 3e23ce506f62941924126dbabb7ce7988253b0a5 Mon Sep 17 00:00:00 2001
From: Malay Nagda
Date: Mon, 24 Nov 2025 20:32:56 +0530
Subject: [PATCH 1/3] nvfp4 and qwen3 80b_a3b fixes

Signed-off-by: Malay Nagda
---
 .../configs/llama3/workload_base_configs.py |  4 ++--
 .../configs/qwen3/qwen3_llm_pretrain.py     | 20 +++++++++++++++++++
 .../configs/qwen3/workload_base_configs.py  | 14 +++++++++++++
 3 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/scripts/performance/configs/llama3/workload_base_configs.py b/scripts/performance/configs/llama3/workload_base_configs.py
index 2603f7384..fdf5fe9b9 100644
--- a/scripts/performance/configs/llama3/workload_base_configs.py
+++ b/scripts/performance/configs/llama3/workload_base_configs.py
@@ -58,7 +58,7 @@
     BASE_LLAMA3_70B_CONFIG,
     pipeline_model_parallel_size=4,
     virtual_pipeline_model_parallel_size=5,
-    cuda_graph_impl="local",
+    cuda_graph_impl="none",
     cuda_graph_scope="full_iteration",
 )
 
@@ -191,7 +191,7 @@
 LLAMA3_8B_GB200_NVFP4_BASE_CONFIG = replace(
     BASE_LLAMA3_8B_CONFIG,
     micro_batch_size=4,
-    cuda_graph_impl="local",
+    cuda_graph_impl="none",
     cuda_graph_scope="full_iteration",
 )
 
diff --git a/scripts/performance/configs/qwen3/qwen3_llm_pretrain.py b/scripts/performance/configs/qwen3/qwen3_llm_pretrain.py
index 88ea6bd1c..fee00002d 100644
--- a/scripts/performance/configs/qwen3/qwen3_llm_pretrain.py
+++ b/scripts/performance/configs/qwen3/qwen3_llm_pretrain.py
@@ -269,6 +269,26 @@ def qwen3_next_80b_a3b_gb200_config(precision: str = "bf16") -> ConfigContainer:
     return cfg
 
 
+def qwen3_next_80b_a3b_b200_config(precision: str = "bf16") -> ConfigContainer:
+    """B200, baseline config."""
+    if precision == "bf16":
+        base_cfg = base_cfgs.QWEN3_NEXT_80B_A3B_B200_BF16_BASE_CONFIG
+        precision_config = get_precision_config(precision)
+    else:
+        base_cfg = base_cfgs.QWEN3_NEXT_80B_A3B_B200_FP8_MX_BASE_CONFIG
+        precision_config = get_precision_config(precision)
+
+    cfg = qwen3_next_80b_a3b_pretrain_config(
+        mock=True,
+        precision_config=precision_config,
+        comm_overlap_config=CommOverlapConfig(tp_comm_overlap=True),
+    )
+    set_qwen3_next_common_configs(cfg)
+    set_workload_base_configs(cfg, base_cfg)
+
+    return cfg
+
+
 def qwen3_next_80b_a3b_gb300_config(precision: str = "bf16") -> ConfigContainer:
     """GB300, baseline config."""
     if precision == "bf16":
diff --git a/scripts/performance/configs/qwen3/workload_base_configs.py b/scripts/performance/configs/qwen3/workload_base_configs.py
index e1065e727..5aaa51927 100644
--- a/scripts/performance/configs/qwen3/workload_base_configs.py
+++ b/scripts/performance/configs/qwen3/workload_base_configs.py
@@ -253,6 +253,18 @@
     micro_batch_size=1,
 )
 
+QWEN3_NEXT_80B_A3B_B200_FP8_MX_BASE_CONFIG = replace(
+    BASE_QWEN3_NEXT_80B_A3B_CONFIG,
+    num_gpus=64,
+    micro_batch_size=1,
+)
+
+QWEN3_NEXT_80B_A3B_B200_BF16_BASE_CONFIG = replace(
+    BASE_QWEN3_NEXT_80B_A3B_CONFIG,
+    num_gpus=64,
+    micro_batch_size=1,
+)
+
 QWEN3_NEXT_80B_A3B_GB300_FP8_MX_BASE_CONFIG = replace(
     BASE_QWEN3_NEXT_80B_A3B_CONFIG,
     num_gpus=64,
@@ -309,6 +321,8 @@
     "QWEN3_30B_A3B_H100_FP8_CS_BASE_CONFIG",
     "QWEN3_NEXT_80B_A3B_GB200_BF16_BASE_CONFIG",
     "QWEN3_NEXT_80B_A3B_GB200_FP8_MX_BASE_CONFIG",
+    "QWEN3_NEXT_80B_A3B_B200_FP8_MX_BASE_CONFIG",
+    "QWEN3_NEXT_80B_A3B_B200_BF16_BASE_CONFIG",
     "QWEN3_NEXT_80B_A3B_GB300_FP8_MX_BASE_CONFIG",
     "QWEN3_NEXT_80B_A3B_GB300_BF16_BASE_CONFIG",
     "QWEN3_NEXT_80B_A3B_H100_FP8_CS_BASE_CONFIG",

From c84d76736f89c9ae6578fd753cad3e54caf8c6e5 Mon Sep 17 00:00:00 2001
From: Malay Nagda
Date: Tue, 25 Nov 2025 14:56:36 +0530
Subject: [PATCH 2/3]
 peft gb300, nvfp4 tp1->tp2

Signed-off-by: Malay Nagda
---
 .../configs/llama3/llama3_llm_finetune.py   | 63 +++++++++++++++++++
 .../configs/llama3/workload_base_configs.py | 49 ++++++++++++++-
 2 files changed, 109 insertions(+), 3 deletions(-)

diff --git a/scripts/performance/configs/llama3/llama3_llm_finetune.py b/scripts/performance/configs/llama3/llama3_llm_finetune.py
index 30e69c164..81454dcd6 100644
--- a/scripts/performance/configs/llama3/llama3_llm_finetune.py
+++ b/scripts/performance/configs/llama3/llama3_llm_finetune.py
@@ -98,6 +98,40 @@ def llama3_8b_h100_sft_config(precision: str = "bf16") -> ConfigContainer:
     return cfg
 
 
+def llama3_70b_gb300_sft_config(precision: str = "bf16") -> ConfigContainer:
+    """GB300, SFT config."""
+    if precision == "bf16":
+        base_cfg = base_cfgs.LLAMA3_70B_GB300_SFT_BF16_BASE_CONFIG
+        precision_config = get_precision_config(precision)
+    else:
+        base_cfg = base_cfgs.LLAMA3_70B_GB300_SFT_FP8_CS_BASE_CONFIG
+        if precision == "fp8_mx":
+            base_cfg = base_cfgs.LLAMA3_70B_GB300_SFT_FP8_MX_BASE_CONFIG
+        precision_config = get_precision_config(precision)
+
+    cfg = llama3_70b_finetune_config(
+        peft="none",
+        precision_config=precision_config,
+        packed_sequence=True,
+        seq_length=4096,
+    )
+    set_llama3_common_peft_configs(cfg)
+    set_workload_base_configs(cfg, base_cfg)
+
+    cfg.comm_overlap = CommOverlapConfig(
+        tp_comm_overlap=bool(cfg.model.tensor_model_parallel_size > 1),
+        defer_embedding_wgrad_compute=True,
+        wgrad_deferral_limit=22,
+    )
+
+    if precision == "fp8_mx":  # keeping this enabled causes NaN grad norm
+        cfg.comm_overlap.overlap_param_gather = False
+        cfg.ddp.overlap_param_gather = False
+        cfg.optimizer.overlap_param_gather = False
+
+    return cfg
+
+
 def llama3_70b_gb200_sft_config(precision: str = "bf16") -> ConfigContainer:
     """GB200, SFT config."""
     if precision == "bf16":
@@ -161,6 +195,35 @@ def llama3_70b_h100_sft_config(precision: str = "bf16") -> ConfigContainer:
     return cfg
 
 
+def llama3_70b_gb300_lora_config(precision: str = "bf16") -> ConfigContainer:
+    """GB300, LORA config."""
+    if precision == "bf16":
+        base_cfg = base_cfgs.LLAMA3_70B_GB300_LORA_BF16_BASE_CONFIG
+        precision_config = get_precision_config(precision)
+    else:
+        base_cfg = base_cfgs.LLAMA3_70B_GB300_LORA_FP8_CS_BASE_CONFIG
+        if precision == "fp8_mx":
+            base_cfg = base_cfgs.LLAMA3_70B_GB300_LORA_FP8_MX_BASE_CONFIG
+        precision_config = get_precision_config(precision)
+
+    cfg = llama3_70b_finetune_config(
+        peft="lora",
+        precision_config=precision_config,
+        packed_sequence=True,
+        seq_length=2048,
+    )
+    set_llama3_common_peft_configs(cfg)
+    set_workload_base_configs(cfg, base_cfg)
+
+    if precision == "fp8_mx":  # keeping this enabled causes NaN grad norm
+        if cfg.comm_overlap is not None and isinstance(cfg.comm_overlap, CommOverlapConfig):
+            cfg.comm_overlap.overlap_param_gather = False
+        cfg.ddp.overlap_param_gather = False
+        cfg.optimizer.overlap_param_gather = False
+
+    return cfg
+
+
 def llama3_70b_gb200_lora_config(precision: str = "bf16") -> ConfigContainer:
     """GB200, LORA config."""
     if precision == "bf16":
diff --git a/scripts/performance/configs/llama3/workload_base_configs.py b/scripts/performance/configs/llama3/workload_base_configs.py
index fdf5fe9b9..cc85b8568 100644
--- a/scripts/performance/configs/llama3/workload_base_configs.py
+++ b/scripts/performance/configs/llama3/workload_base_configs.py
@@ -87,10 +87,11 @@
 
 LLAMA3_70B_GB200_NVFP4_BASE_CONFIG = replace(
     BASE_LLAMA3_70B_CONFIG,
+    tensor_model_parallel_size=2,
     pipeline_model_parallel_size=4,
     virtual_pipeline_model_parallel_size=5,
-    context_parallel_size=2,
-    cuda_graph_impl="local",
+    context_parallel_size=1,
+    cuda_graph_impl="none",
     cuda_graph_scope="full_iteration",
 )
 
@@ -122,7 +123,8 @@
 
 LLAMA3_70B_B200_NVFP4_BASE_CONFIG = replace(
     BASE_LLAMA3_70B_CONFIG,
-    context_parallel_size=2,
+    tensor_model_parallel_size=2,
+    context_parallel_size=1,
     pipeline_model_parallel_size=4,
     virtual_pipeline_model_parallel_size=5,
 )
@@ -264,6 +266,24 @@
 )
 
 
+LLAMA3_70B_GB300_SFT_BASE_CONFIG = replace(
+    BASE_LLAMA3_70B_CONFIG,
+    num_gpus=32,
+    peft="none",
+    tensor_model_parallel_size=2,
+    pipeline_model_parallel_size=4,
+    virtual_pipeline_model_parallel_size=5,
+    micro_batch_size=1,
+    global_batch_size=32,
+    cuda_graph_impl="transformer_engine",
+    cuda_graph_scope="mlp",
+)
+
+LLAMA3_70B_GB300_SFT_BF16_BASE_CONFIG = LLAMA3_70B_GB300_SFT_BASE_CONFIG
+LLAMA3_70B_GB300_SFT_FP8_CS_BASE_CONFIG = LLAMA3_70B_GB300_SFT_BASE_CONFIG
+LLAMA3_70B_GB300_SFT_FP8_MX_BASE_CONFIG = LLAMA3_70B_GB300_SFT_BASE_CONFIG
+
+
 LLAMA3_70B_GB200_SFT_BASE_CONFIG = replace(
     BASE_LLAMA3_70B_CONFIG,
     num_gpus=32,
@@ -297,6 +317,23 @@
 LLAMA3_70B_H100_SFT_FP8_CS_BASE_CONFIG = LLAMA3_70B_H100_SFT_BASE_CONFIG
 
 
+LLAMA3_70B_GB300_LORA_BASE_CONFIG = replace(
+    BASE_LLAMA3_70B_CONFIG,
+    num_gpus=8,
+    peft="lora",
+    pipeline_model_parallel_size=4,
+    virtual_pipeline_model_parallel_size=20,
+    micro_batch_size=1,
+    global_batch_size=64,
+    cuda_graph_impl="transformer_engine",
+    cuda_graph_scope="mlp",
+)
+
+LLAMA3_70B_GB300_LORA_BF16_BASE_CONFIG = LLAMA3_70B_GB300_LORA_BASE_CONFIG
+LLAMA3_70B_GB300_LORA_FP8_CS_BASE_CONFIG = LLAMA3_70B_GB300_LORA_BASE_CONFIG
+LLAMA3_70B_GB300_LORA_FP8_MX_BASE_CONFIG = LLAMA3_70B_GB300_LORA_FP8_CS_BASE_CONFIG
+
+
 LLAMA3_70B_GB200_LORA_BASE_CONFIG = replace(
     BASE_LLAMA3_70B_CONFIG,
     num_gpus=8,
@@ -367,11 +404,17 @@
     "LLAMA3_8B_H100_SFT_BF16_BASE_CONFIG",
     "LLAMA3_8B_H100_SFT_FP8_CS_BASE_CONFIG",
     "LLAMA3_8B_H100_SFT_FP8_MX_BASE_CONFIG",
+    "LLAMA3_70B_GB300_SFT_BF16_BASE_CONFIG",
+    "LLAMA3_70B_GB300_SFT_FP8_CS_BASE_CONFIG",
+    "LLAMA3_70B_GB300_SFT_FP8_MX_BASE_CONFIG",
     "LLAMA3_70B_GB200_SFT_BF16_BASE_CONFIG",
     "LLAMA3_70B_GB200_SFT_FP8_CS_BASE_CONFIG",
     "LLAMA3_70B_GB200_SFT_FP8_MX_BASE_CONFIG",
     "LLAMA3_70B_H100_SFT_BF16_BASE_CONFIG",
     "LLAMA3_70B_H100_SFT_FP8_CS_BASE_CONFIG",
+    "LLAMA3_70B_GB300_LORA_BF16_BASE_CONFIG",
+    "LLAMA3_70B_GB300_LORA_FP8_CS_BASE_CONFIG",
+    "LLAMA3_70B_GB300_LORA_FP8_MX_BASE_CONFIG",
     "LLAMA3_70B_GB200_LORA_BF16_BASE_CONFIG",
     "LLAMA3_70B_GB200_LORA_FP8_CS_BASE_CONFIG",
     "LLAMA3_70B_GB200_LORA_FP8_MX_BASE_CONFIG",

From d088e9ea4248d1c8f72c5b691a455b89919b5aa3 Mon Sep 17 00:00:00 2001
From: Malay Nagda
Date: Fri, 28 Nov 2025 15:21:35 +0530
Subject: [PATCH 3/3] try fix nan grad

Signed-off-by: Malay Nagda
---
 .../configs/llama3/llama3_llm_finetune.py   | 12 +++++++++++
 .../configs/llama3/workload_base_configs.py |  6 +++---
 src/megatron/bridge/data/datasets/sft.py    | 21 ++++++++++++++-----
 3 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/scripts/performance/configs/llama3/llama3_llm_finetune.py b/scripts/performance/configs/llama3/llama3_llm_finetune.py
index 81454dcd6..d82821653 100644
--- a/scripts/performance/configs/llama3/llama3_llm_finetune.py
+++ b/scripts/performance/configs/llama3/llama3_llm_finetune.py
@@ -124,6 +124,12 @@ def llama3_70b_gb300_sft_config(precision: str = "bf16") -> ConfigContainer:
         wgrad_deferral_limit=22,
     )
 
+    # Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
+    # This ensures consistent cu_seqlens tensor shapes across batches, which is required
+    # for CUDA graphs and avoids NaN issues in attention kernels.
+    cfg.dataset.packed_sequence_specs.pad_cu_seqlens = True
+    cfg.dataset.dataset_kwargs["pad_to_max_length"] = True
+
     if precision == "fp8_mx":  # keeping this enabled causes NaN grad norm
         cfg.comm_overlap.overlap_param_gather = False
         cfg.ddp.overlap_param_gather = False
@@ -215,6 +221,12 @@ def llama3_70b_gb300_lora_config(precision: str = "bf16") -> ConfigContainer:
     set_llama3_common_peft_configs(cfg)
     set_workload_base_configs(cfg, base_cfg)
 
+    # Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
+    # This ensures consistent cu_seqlens tensor shapes across batches, which is required
+    # for CUDA graphs and avoids NaN issues in attention kernels.
+    cfg.dataset.packed_sequence_specs.pad_cu_seqlens = True
+    cfg.dataset.dataset_kwargs["pad_to_max_length"] = True
+
     if precision == "fp8_mx":  # keeping this enabled causes NaN grad norm
         if cfg.comm_overlap is not None and isinstance(cfg.comm_overlap, CommOverlapConfig):
             cfg.comm_overlap.overlap_param_gather = False
diff --git a/scripts/performance/configs/llama3/workload_base_configs.py b/scripts/performance/configs/llama3/workload_base_configs.py
index cc85b8568..8d6197602 100644
--- a/scripts/performance/configs/llama3/workload_base_configs.py
+++ b/scripts/performance/configs/llama3/workload_base_configs.py
@@ -270,7 +270,7 @@
     BASE_LLAMA3_70B_CONFIG,
     num_gpus=32,
     peft="none",
-    tensor_model_parallel_size=2,
+    tensor_model_parallel_size=1,
     pipeline_model_parallel_size=4,
     virtual_pipeline_model_parallel_size=5,
     micro_batch_size=1,
@@ -321,8 +321,8 @@
     BASE_LLAMA3_70B_CONFIG,
     num_gpus=8,
     peft="lora",
-    pipeline_model_parallel_size=4,
-    virtual_pipeline_model_parallel_size=20,
+    # pipeline_model_parallel_size=4,
+    # virtual_pipeline_model_parallel_size=20,
     micro_batch_size=1,
     global_batch_size=64,
     cuda_graph_impl="transformer_engine",
diff --git a/src/megatron/bridge/data/datasets/sft.py b/src/megatron/bridge/data/datasets/sft.py
index fa94cdec8..ebad1720b 100644
--- a/src/megatron/bridge/data/datasets/sft.py
+++ b/src/megatron/bridge/data/datasets/sft.py
@@ -958,11 +958,9 @@ def collate_fn(self, batch):
             cu_seqlens_unpadded = self._collate_item(
                 cu_seqlens_unpadded, max_length=max(len(length) for length in cu_seqlens_unpadded) + 1, pad_id=-1
             )
-            # Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
+            # Pre-generate `cu_seqlens_argmin` as CPU tensor to avoid device-to-host copies.
             cu_seqlens = torch.IntTensor(cu_seqlens)
             cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
-            seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
-            max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
             cu_seqlens_unpadded = torch.IntTensor(cu_seqlens_unpadded)
             cu_seqlens_unpadded_argmin = torch.argmin(cu_seqlens_unpadded, dim=1, keepdim=True)
 
@@ -978,8 +976,21 @@ def collate_fn(self, batch):
                 safe_max_seqlen = max(dataset_max_seqlen, padding_gap)
                 max_seqlen = torch.IntTensor([safe_max_seqlen] * len(cu_seqlens))
             else:
-                seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
-                max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
+                # Compute max_seqlen excluding the padding segment (marked by -1 padding in cu_seqlens).
+                # The padding segment is added when the packed sequence doesn't fill max_length,
+                # and including it in max_seqlen can cause NaN in attention kernels.
+                max_seqlen_list = []
+                for i in range(len(cu_seqlens)):
+                    # Find the valid entries (before -1 padding)
+                    valid_idx = cu_seqlens_argmin[i].item()
+                    valid_cu_seqlens = cu_seqlens[i, :valid_idx]
+                    if len(valid_cu_seqlens) > 1:
+                        seqlens_i = valid_cu_seqlens[1:] - valid_cu_seqlens[:-1]
+                        max_seqlen_list.append(seqlens_i.max().item())
+                    else:
+                        # Fallback: use max_length if no valid sequences
+                        max_seqlen_list.append(max_length)
+                max_seqlen = torch.IntTensor(max_seqlen_list).unsqueeze(1)
             processed_batch.update(
                 {
                     "attention_mask": torch.LongTensor(