75 changes: 75 additions & 0 deletions scripts/performance/configs/llama3/llama3_llm_finetune.py
@@ -98,6 +98,46 @@ def llama3_8b_h100_sft_config(precision: str = "bf16") -> ConfigContainer:
return cfg


def llama3_70b_gb300_sft_config(precision: str = "bf16") -> ConfigContainer:
"""GB300, SFT config."""
if precision == "bf16":
base_cfg = base_cfgs.LLAMA3_70B_GB300_SFT_BF16_BASE_CONFIG
precision_config = get_precision_config(precision)
else:
base_cfg = base_cfgs.LLAMA3_70B_GB300_SFT_FP8_CS_BASE_CONFIG
if precision == "fp8_mx":
base_cfg = base_cfgs.LLAMA3_70B_GB300_SFT_FP8_MX_BASE_CONFIG
precision_config = get_precision_config(precision)

cfg = llama3_70b_finetune_config(
peft="none",
precision_config=precision_config,
packed_sequence=True,
seq_length=4096,
)
set_llama3_common_peft_configs(cfg)
set_workload_base_configs(cfg, base_cfg)

cfg.comm_overlap = CommOverlapConfig(
tp_comm_overlap=bool(cfg.model.tensor_model_parallel_size > 1),
defer_embedding_wgrad_compute=True,
wgrad_deferral_limit=22,
)

# Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
# This ensures consistent cu_seqlens tensor shapes across batches, which is required
# for CUDA graphs and avoids NaN issues in attention kernels.
cfg.dataset.packed_sequence_specs.pad_cu_seqlens = True
cfg.dataset.dataset_kwargs["pad_to_max_length"] = True

if precision == "fp8_mx": # keeping this eanbled causes NaN grad norm
cfg.comm_overlap.overlap_param_gather = False
cfg.ddp.overlap_param_gather = False
cfg.optimizer.overlap_param_gather = False

return cfg


def llama3_70b_gb200_sft_config(precision: str = "bf16") -> ConfigContainer:
"""GB200, SFT config."""
if precision == "bf16":
@@ -161,6 +201,41 @@ def llama3_70b_h100_sft_config(precision: str = "bf16") -> ConfigContainer:
return cfg


def llama3_70b_gb300_lora_config(precision: str = "bf16") -> ConfigContainer:
"""GB300, LORA config."""
if precision == "bf16":
base_cfg = base_cfgs.LLAMA3_70B_GB300_LORA_BF16_BASE_CONFIG
precision_config = get_precision_config(precision)
else:
base_cfg = base_cfgs.LLAMA3_70B_GB300_LORA_FP8_CS_BASE_CONFIG
if precision == "fp8_mx":
base_cfg = base_cfgs.LLAMA3_70B_GB300_LORA_FP8_MX_BASE_CONFIG
precision_config = get_precision_config(precision)

cfg = llama3_70b_finetune_config(
peft="lora",
precision_config=precision_config,
packed_sequence=True,
seq_length=2048,
)
set_llama3_common_peft_configs(cfg)
set_workload_base_configs(cfg, base_cfg)

# Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
# This ensures consistent cu_seqlens tensor shapes across batches, which is required
# for CUDA graphs and avoids NaN issues in attention kernels.
cfg.dataset.packed_sequence_specs.pad_cu_seqlens = True
cfg.dataset.dataset_kwargs["pad_to_max_length"] = True

if precision == "fp8_mx": # keeping this eanbled causes NaN grad norm
if cfg.comm_overlap is not None and isinstance(cfg.comm_overlap, CommOverlapConfig):
cfg.comm_overlap.overlap_param_gather = False
cfg.ddp.overlap_param_gather = False
cfg.optimizer.overlap_param_gather = False

return cfg


def llama3_70b_gb200_lora_config(precision: str = "bf16") -> ConfigContainer:
"""GB200, LORA config."""
if precision == "bf16":
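The `pad_cu_seqlens` comments in the new GB300 configs above hinge on the sequence-boundary tensor having the same shape in every batch. A standalone toy sketch of that idea (the boundary values and fixed width are made up; only `torch` is assumed, not this PR's collate code):

import torch

# Two packed samples with different numbers of sequences. Without padding, the
# cu_seqlens row for each sample would have a different length, so a captured
# CUDA graph would see varying shapes from batch to batch.
packs = [[0, 5, 9, 12], [0, 7, 9]]
fixed_width = 6  # e.g. (max sequences per pack) + 1, chosen up front

# Right-pad every row with -1 to the same width; downstream code can recover the
# real boundaries later because -1 is never a valid cumulative length.
padded = torch.IntTensor([row + [-1] * (fixed_width - len(row)) for row in packs])
print(padded.shape)  # torch.Size([2, 6]) for every batch, regardless of pack contents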
53 changes: 48 additions & 5 deletions scripts/performance/configs/llama3/workload_base_configs.py
@@ -58,7 +58,7 @@
BASE_LLAMA3_70B_CONFIG,
pipeline_model_parallel_size=4,
virtual_pipeline_model_parallel_size=5,
cuda_graph_impl="local",
cuda_graph_impl="none",
cuda_graph_scope="full_iteration",
)

@@ -87,10 +87,11 @@

LLAMA3_70B_GB200_NVFP4_BASE_CONFIG = replace(
BASE_LLAMA3_70B_CONFIG,
tensor_model_parallel_size=2,
pipeline_model_parallel_size=4,
virtual_pipeline_model_parallel_size=5,
context_parallel_size=2,
cuda_graph_impl="local",
context_parallel_size=1,
cuda_graph_impl="none",
cuda_graph_scope="full_iteration",
)

@@ -122,7 +123,8 @@

LLAMA3_70B_B200_NVFP4_BASE_CONFIG = replace(
BASE_LLAMA3_70B_CONFIG,
context_parallel_size=2,
tensor_model_parallel_size=2,
context_parallel_size=1,
pipeline_model_parallel_size=4,
virtual_pipeline_model_parallel_size=5,
)
@@ -191,7 +193,7 @@
LLAMA3_8B_GB200_NVFP4_BASE_CONFIG = replace(
BASE_LLAMA3_8B_CONFIG,
micro_batch_size=4,
cuda_graph_impl="local",
cuda_graph_impl="none",
cuda_graph_scope="full_iteration",
)

@@ -264,6 +266,24 @@
)


LLAMA3_70B_GB300_SFT_BASE_CONFIG = replace(
BASE_LLAMA3_70B_CONFIG,
num_gpus=32,
peft="none",
tensor_model_parallel_size=1,
pipeline_model_parallel_size=4,
virtual_pipeline_model_parallel_size=5,
micro_batch_size=1,
global_batch_size=32,
cuda_graph_impl="transformer_engine",
cuda_graph_scope="mlp",
)

LLAMA3_70B_GB300_SFT_BF16_BASE_CONFIG = LLAMA3_70B_GB300_SFT_BASE_CONFIG
LLAMA3_70B_GB300_SFT_FP8_CS_BASE_CONFIG = LLAMA3_70B_GB300_SFT_BASE_CONFIG
LLAMA3_70B_GB300_SFT_FP8_MX_BASE_CONFIG = LLAMA3_70B_GB300_SFT_BASE_CONFIG


LLAMA3_70B_GB200_SFT_BASE_CONFIG = replace(
BASE_LLAMA3_70B_CONFIG,
num_gpus=32,
@@ -297,6 +317,23 @@
LLAMA3_70B_H100_SFT_FP8_CS_BASE_CONFIG = LLAMA3_70B_H100_SFT_BASE_CONFIG


LLAMA3_70B_GB300_LORA_BASE_CONFIG = replace(
BASE_LLAMA3_70B_CONFIG,
num_gpus=8,
peft="lora",
# pipeline_model_parallel_size=4,
# virtual_pipeline_model_parallel_size=20,
micro_batch_size=1,
global_batch_size=64,
cuda_graph_impl="transformer_engine",
cuda_graph_scope="mlp",
)

LLAMA3_70B_GB300_LORA_BF16_BASE_CONFIG = LLAMA3_70B_GB300_LORA_BASE_CONFIG
LLAMA3_70B_GB300_LORA_FP8_CS_BASE_CONFIG = LLAMA3_70B_GB300_LORA_BASE_CONFIG
LLAMA3_70B_GB300_LORA_FP8_MX_BASE_CONFIG = LLAMA3_70B_GB300_LORA_FP8_CS_BASE_CONFIG


LLAMA3_70B_GB200_LORA_BASE_CONFIG = replace(
BASE_LLAMA3_70B_CONFIG,
num_gpus=8,
@@ -367,11 +404,17 @@
"LLAMA3_8B_H100_SFT_BF16_BASE_CONFIG",
"LLAMA3_8B_H100_SFT_FP8_CS_BASE_CONFIG",
"LLAMA3_8B_H100_SFT_FP8_MX_BASE_CONFIG",
"LLAMA3_70B_GB300_SFT_BF16_BASE_CONFIG",
"LLAMA3_70B_GB300_SFT_FP8_CS_BASE_CONFIG",
"LLAMA3_70B_GB300_SFT_FP8_MX_BASE_CONFIG",
"LLAMA3_70B_GB200_SFT_BF16_BASE_CONFIG",
"LLAMA3_70B_GB200_SFT_FP8_CS_BASE_CONFIG",
"LLAMA3_70B_GB200_SFT_FP8_MX_BASE_CONFIG",
"LLAMA3_70B_H100_SFT_BF16_BASE_CONFIG",
"LLAMA3_70B_H100_SFT_FP8_CS_BASE_CONFIG",
"LLAMA3_70B_GB300_LORA_BF16_BASE_CONFIG",
"LLAMA3_70B_GB300_LORA_FP8_CS_BASE_CONFIG",
"LLAMA3_70B_GB300_LORA_FP8_MX_BASE_CONFIG",
"LLAMA3_70B_GB200_LORA_BF16_BASE_CONFIG",
"LLAMA3_70B_GB200_LORA_FP8_CS_BASE_CONFIG",
"LLAMA3_70B_GB200_LORA_FP8_MX_BASE_CONFIG",
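The base-config additions above all follow one pattern: derive a variant with `replace(...)`, then alias it under per-precision names. A minimal sketch of that pattern using a hypothetical stand-in dataclass (not the real config class from this repo):

from dataclasses import dataclass, replace

# Hypothetical, trimmed-down stand-in for the real workload base-config dataclass.
@dataclass(frozen=True)
class WorkloadConfig:
    num_gpus: int = 8
    micro_batch_size: int = 1
    cuda_graph_impl: str = "none"

BASE = WorkloadConfig()

# replace() returns a new instance with only the named fields overridden.
GB300_SFT_BASE = replace(BASE, num_gpus=32, cuda_graph_impl="transformer_engine")

# Plain assignment just adds more names for the same object, which is how the
# BF16 / FP8_CS / FP8_MX variants above can share a single base config.
GB300_SFT_BF16 = GB300_SFT_BASE
GB300_SFT_FP8_CS = GB300_SFT_BASE

assert GB300_SFT_BF16 is GB300_SFT_FP8_CS
assert BASE.num_gpus == 8 and GB300_SFT_BASE.num_gpus == 32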
20 changes: 20 additions & 0 deletions scripts/performance/configs/qwen3/qwen3_llm_pretrain.py
@@ -269,6 +269,26 @@ def qwen3_next_80b_a3b_gb200_config(precision: str = "bf16") -> ConfigContainer:
return cfg


def qwen3_next_80b_a3b_b200_config(precision: str = "bf16") -> ConfigContainer:
"""GB200, baseline config."""
if precision == "bf16":
base_cfg = base_cfgs.QWEN3_NEXT_80B_A3B_B200_BF16_BASE_CONFIG
precision_config = get_precision_config(precision)
else:
base_cfg = base_cfgs.QWEN3_NEXT_80B_A3B_B200_FP8_MX_BASE_CONFIG
precision_config = get_precision_config(precision)

cfg = qwen3_next_80b_a3b_pretrain_config(
mock=True,
precision_config=precision_config,
comm_overlap_config=CommOverlapConfig(tp_comm_overlap=True),
)
set_qwen3_next_common_configs(cfg)
set_workload_base_configs(cfg, base_cfg)

return cfg


def qwen3_next_80b_a3b_gb300_config(precision: str = "bf16") -> ConfigContainer:
"""GB300, baseline config."""
if precision == "bf16":
14 changes: 14 additions & 0 deletions scripts/performance/configs/qwen3/workload_base_configs.py
@@ -253,6 +253,18 @@
micro_batch_size=1,
)

QWEN3_NEXT_80B_A3B_B200_FP8_MX_BASE_CONFIG = replace(
BASE_QWEN3_NEXT_80B_A3B_CONFIG,
num_gpus=64,
micro_batch_size=1,
)

QWEN3_NEXT_80B_A3B_B200_BF16_BASE_CONFIG = replace(
BASE_QWEN3_NEXT_80B_A3B_CONFIG,
num_gpus=64,
micro_batch_size=1,
)

QWEN3_NEXT_80B_A3B_GB300_FP8_MX_BASE_CONFIG = replace(
BASE_QWEN3_NEXT_80B_A3B_CONFIG,
num_gpus=64,
@@ -309,6 +321,8 @@
"QWEN3_30B_A3B_H100_FP8_CS_BASE_CONFIG",
"QWEN3_NEXT_80B_A3B_GB200_BF16_BASE_CONFIG",
"QWEN3_NEXT_80B_A3B_GB200_FP8_MX_BASE_CONFIG",
"QWEN3_NEXT_80B_A3B_B200_FP8_MX_BASE_CONFIG",
"QWEN3_NEXT_80B_A3B_B200_BF16_BASE_CONFIG",
"QWEN3_NEXT_80B_A3B_GB300_FP8_MX_BASE_CONFIG",
"QWEN3_NEXT_80B_A3B_GB300_BF16_BASE_CONFIG",
"QWEN3_NEXT_80B_A3B_H100_FP8_CS_BASE_CONFIG",
21 changes: 16 additions & 5 deletions src/megatron/bridge/data/datasets/sft.py
@@ -958,11 +958,9 @@ def collate_fn(self, batch):
cu_seqlens_unpadded = self._collate_item(
cu_seqlens_unpadded, max_length=max(len(length) for length in cu_seqlens_unpadded) + 1, pad_id=-1
)
# Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
# Pre-generate `cu_seqlens_argmin` as CPU tensor to avoid device-to-host copies.
cu_seqlens = torch.IntTensor(cu_seqlens)
cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
cu_seqlens_unpadded = torch.IntTensor(cu_seqlens_unpadded)
cu_seqlens_unpadded_argmin = torch.argmin(cu_seqlens_unpadded, dim=1, keepdim=True)

@@ -978,8 +976,21 @@ def collate_fn(self, batch):
safe_max_seqlen = max(dataset_max_seqlen, padding_gap)
max_seqlen = torch.IntTensor([safe_max_seqlen] * len(cu_seqlens))
else:
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
# Compute max_seqlen excluding the padding segment (marked by -1 padding in cu_seqlens).
# The padding segment is added when the packed sequence doesn't fill max_length,
# and including it in max_seqlen can cause NaN in attention kernels.
max_seqlen_list = []
for i in range(len(cu_seqlens)):
# Find the valid entries (before -1 padding)
valid_idx = cu_seqlens_argmin[i].item()
valid_cu_seqlens = cu_seqlens[i, :valid_idx]
if len(valid_cu_seqlens) > 1:
seqlens_i = valid_cu_seqlens[1:] - valid_cu_seqlens[:-1]
max_seqlen_list.append(seqlens_i.max().item())
else:
# Fallback: use max_length if no valid sequences
max_seqlen_list.append(max_length)
max_seqlen = torch.IntTensor(max_seqlen_list).unsqueeze(1)
processed_batch.update(
{
"attention_mask": torch.LongTensor(
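A standalone toy run of the boundary-slicing logic added in collate_fn above, for readers who want to see the shapes (the boundary values are made up; only `torch` is assumed):

import torch

# Two packed samples, right-padded with -1: row 0 packs lengths 5, 4, 3; row 1 packs 7, 2.
cu_seqlens = torch.IntTensor([
    [0, 5, 9, 12, -1, -1],
    [0, 7, 9, -1, -1, -1],
])

# argmin returns the first occurrence of the minimum (-1), i.e. where the padding begins.
cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)

max_seqlen_list = []
for i in range(len(cu_seqlens)):
    valid_idx = cu_seqlens_argmin[i].item()
    valid_cu_seqlens = cu_seqlens[i, :valid_idx]          # entries before the -1 padding
    seqlens_i = valid_cu_seqlens[1:] - valid_cu_seqlens[:-1]
    max_seqlen_list.append(seqlens_i.max().item())

max_seqlen = torch.IntTensor(max_seqlen_list).unsqueeze(1)
print(max_seqlen.tolist())  # [[5], [7]] -- one max sequence length per packed sample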