Merged
Changes from 3 commits
src/megatron/bridge/models/model_provider.py (2 changes: 1 addition & 1 deletion)
@@ -524,7 +524,7 @@ def get_model(
# GPU allocation.
# For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
# in the fully_shard function of FSDP2 instead.
if not (use_torch_fsdp2 and model_config.use_cpu_initialization) and not model_config.init_model_with_meta_device:
if not use_torch_fsdp2 and not model_config.use_cpu_initialization and not model_config.init_model_with_meta_device:
for model_module in model:
model_module.cuda(torch.cuda.current_device())
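Note on the condition change above (illustration, not part of the diff): the original test skipped the eager .cuda() allocation only when use_torch_fsdp2 and use_cpu_initialization were both set, while the rewritten test skips it when either one is set, matching the comment that FSDP2 allocates GPU memory in fully_shard instead. A minimal sketch of the difference, using plain booleans named after the flags in the diff:

use_torch_fsdp2, use_cpu_initialization = True, False  # FSDP2 without CPU initialization

old_condition = not (use_torch_fsdp2 and use_cpu_initialization)    # True  -> .cuda() still ran here
new_condition = not use_torch_fsdp2 and not use_cpu_initialization  # False -> .cuda() is skipped

print(old_condition, new_condition)  # True False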

src/megatron/bridge/training/checkpointing.py (36 changes: 35 additions & 1 deletion)
@@ -137,6 +137,35 @@ def get_checkpoint_version() -> Optional[float]:
return _CHECKPOINT_VERSION


def delete_extra_state(state_dict):
"""Delete all extra state keys from the model state dictionary.

This function removes all keys containing '_extra_state' from the model
portion of the state dictionary. This is useful for cleaning up corrupted
or problematic extra state that can cause issues during model loading.

Args:
state_dict: The state dictionary. Can be either:
- A full checkpoint dict with a "model" key, or
- A model state dict directly

Returns:
The modified state dictionary with extra state keys removed.
"""
# Handle both cases: full checkpoint dict with "model" key or direct model state dict
if "model" in state_dict:
# Full checkpoint dict case
target_dict = state_dict["model"]
else:
# Direct model state dict case
target_dict = state_dict

for key in list(target_dict.keys()):
if "_extra_state" in key:
del target_dict[key]
return state_dict
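
For illustration only (not part of the diff), delete_extra_state accepts either input shape described in its docstring; a toy sketch with hypothetical keys:

full_ckpt = {"model": {"layer.weight": 1.0, "layer._extra_state": b"te"}, "iteration": 10}
flat_model = {"layer.weight": 1.0, "layer._extra_state": b"te"}

delete_extra_state(full_ckpt)   # -> {"model": {"layer.weight": 1.0}, "iteration": 10}
delete_extra_state(flat_model)  # -> {"layer.weight": 1.0}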


def _get_checkpoint_format(checkpoint_path: str) -> str:
"""Determine the checkpoint format by examining the checkpoint directory.

@@ -834,6 +863,7 @@ def _generate_model_state_dict(
else: # fsdp_dtensor and other formats
state_dict["model%d" % i] = model[i].state_dict_for_save_checkpoint()

delete_extra_state(state_dict)
return state_dict


@@ -1048,11 +1078,15 @@ def _load_model_state_dict(module: torch.nn.Module, state_dict: dict[str, Any],
"""Helper function to load state dict with fallback for missing extra states."""
try:
module.load_state_dict(state_dict, strict=strict)
except Exception:
except Exception as e:
if strict:
# Fallback support for backward compatibility breaking changes in TransformerEngine
print(f"Warning: Exception during strict loading: {e}")
load_return = module.load_state_dict(state_dict, strict=False)
print(f"load_return: {load_return}")
else:
# Re-raise if we were already in non-strict mode
raise
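
For reference on the fallback above (an illustration, not part of the diff): when the strict load fails, the retry with strict=False returns an object whose missing_keys and unexpected_keys fields are what the added load_return print surfaces. A small self-contained example against a plain torch module:

import torch

module = torch.nn.Linear(4, 4)
state = module.state_dict()
state.pop("bias")                        # simulate a key absent from the checkpoint
state["_extra_state"] = torch.tensor(0)  # simulate a stale extra-state entry

load_return = module.load_state_dict(state, strict=False)
print(load_return.missing_keys)     # ['bias']
print(load_return.unexpected_keys)  # ['_extra_state']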


def _load_checkpoint_from_path(
src/megatron/bridge/training/model_load_save.py (12 changes: 12 additions & 0 deletions)
@@ -330,6 +330,18 @@ def load_megatron_model(
"""

model_cfg, mlm_args = load_model_config(checkpoint_path)
# If in single GPU environment, reset additional parallel settings
if use_cpu_init or not skip_temp_dist_context:
model_cfg.tensor_model_parallel_size = 1
model_cfg.pipeline_model_parallel_size = 1
model_cfg.context_parallel_size = 1
model_cfg.expert_model_parallel_size = 1
model_cfg.expert_tensor_parallel_size = 1
model_cfg.moe_extended_tp = False
model_cfg.sequence_parallel = False
model_cfg.virtual_pipeline_model_parallel_size = None
model_cfg.hierarchical_context_parallel_sizes = None

return build_and_load_model(
checkpoint_path, model_cfg, model_type, mlm_args, return_state_dict, use_cpu_init, skip_temp_dist_context
)
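
Usage sketch for the single-GPU path added above (a hypothetical call, assuming the remaining parameters keep their defaults and using a placeholder path): with use_cpu_init=True, the parallel sizes recorded in the checkpoint's config are reset before build_and_load_model runs, so a checkpoint saved with model parallelism can be loaded in a single process.

from megatron.bridge.training.model_load_save import load_megatron_model

# Placeholder path; use_cpu_init=True triggers the reset of the parallel settings above.
model = load_megatron_model(checkpoint_path="/path/to/checkpoint", use_cpu_init=True)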