Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions examples/conversion/compare_hf_and_megatron/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,46 +20,46 @@

Run Script Examples:
# Regular LLM comparison between HF and Megatron models:
python examples/models/compare_hf_and_megatron/compare.py \
python examples/conversion/compare_hf_and_megatron/compare.py \
--hf_model_path "Qwen/Qwen3-1.7B" \
--prompt "Hello, how are you?"


# Vision-language comparison with image from URL:
python examples/models/compare_hf_and_megatron/compare.py \
python examples/conversion/compare_hf_and_megatron/compare.py \
--hf_model_path "Qwen/Qwen2.5-VL-3B-Instruct" \
--model_class "Qwen2_5_VLForConditionalGeneration" \
--image_path "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" \
--prompt "Describe this image."

# Vision-language comparison with local image:
python examples/models/compare_hf_and_megatron/compare.py \
python examples/conversion/compare_hf_and_megatron/compare.py \
--hf_model_path "Qwen/Qwen2.5-VL-3B-Instruct" \
--model_class "Qwen2_5_VLForConditionalGeneration" \
--image_path "/path/to/local/image.jpg" \
--prompt "What do you see in this image?"

# Multi-GPU comparison with tensor parallelism (regular LLM):
torchrun --nproc_per_node=2 examples/models/compare_hf_and_megatron/compare.py \
torchrun --nproc_per_node=2 examples/conversion/compare_hf_and_megatron/compare.py \
--hf_model_path "Qwen/Qwen3-1.7B" \
--prompt "Hello world" \
--tp 2

# Pipeline parallel comparison (VL model):
torchrun --nproc_per_node=2 examples/models/compare_hf_and_megatron/compare.py \
torchrun --nproc_per_node=2 examples/conversion/compare_hf_and_megatron/compare.py \
--hf_model_path "Qwen/Qwen2.5-VL-3B-Instruct" \
--model_class "Qwen2_5_VLForConditionalGeneration" \
--prompt "Hello world" \
--pp 2

# Compare with pre-converted Megatron checkpoint:
python examples/models/compare_hf_and_megatron/compare.py \
python examples/conversion/compare_hf_and_megatron/compare.py \
--hf_model_path "Qwen/Qwen3-1.7B" \
--megatron_model_path "/path/to/megatron/checkpoint" \
--prompt "Hello world"

# Enable debug hooks to inspect forward pass intermediate results:
python examples/models/compare_hf_and_megatron/compare.py \
python examples/conversion/compare_hf_and_megatron/compare.py \
--hf_model_path "Qwen/Qwen3-1.7B" \
--prompt "Hello world" \
--enable_debug_hooks
Expand Down Expand Up @@ -491,7 +491,16 @@ def _load_megatron_model(args):
model_provider.expert_tensor_parallel_size = etp
model_provider.pipeline_dtype = torch.bfloat16
model_provider.initialize_model_parallel(seed=0)
megatron_model = bridge.load_megatron_model(args.megatron_model_path, wrap_with_ddp=False)
megatron_model = bridge.load_megatron_model(
args.megatron_model_path,
mp_overrides={
"tensor_model_parallel_size": tp,
"pipeline_model_parallel_size": pp,
"expert_model_parallel_size": ep,
"expert_tensor_parallel_size": etp,
},
wrap_with_ddp=False,
)
else:
# Convert from HF to Megatron
bridge = AutoBridge.from_hf_pretrained(args.hf_model_path)
Expand Down
15 changes: 12 additions & 3 deletions examples/conversion/hf_megatron_roundtrip_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
in Megatron's native checkpoint format by specifying the `--megatron-save-path` argument.

Usage:
torchrun --nproc_per_node 1 examples/models/hf_megatron_roundtrip_multi_gpu.py
torchrun --nproc_per_node 1 examples/models/hf_megatron_roundtrip_multi_gpu.py --megatron-save-path ./megatron_checkpoint
torchrun --nproc_per_node 1 examples/conversion/hf_megatron_roundtrip_multi_gpu.py
torchrun --nproc_per_node 1 examples/conversion/hf_megatron_roundtrip_multi_gpu.py --megatron-save-path ./megatron_checkpoint
"""

import argparse
Expand Down Expand Up @@ -89,7 +89,16 @@ def main(
# Once all overrides are set, finalize the model provider to ensure the post initialization logic is run
model_provider.finalize()
model_provider.initialize_model_parallel(seed=0)
megatron_model = bridge.load_megatron_model(megatron_load_path, wrap_with_ddp=False)
megatron_model = bridge.load_megatron_model(
megatron_load_path,
mp_overrides={
"tensor_model_parallel_size": tp,
"pipeline_model_parallel_size": pp,
"expert_model_parallel_size": ep,
"expert_tensor_parallel_size": etp,
},
wrap_with_ddp=False,
)
megatron_model = [m.cuda() for m in megatron_model]

else:
Expand Down
15 changes: 12 additions & 3 deletions examples/conversion/hf_to_megatron_generate_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
"""
Example:
# Load from HuggingFace model:
python examples/models/hf_to_megatron_generate_text.py --hf_model_path="meta-llama/Llama-3.2-1B" --prompt="Hello, how are you?"
python examples/conversion/hf_to_megatron_generate_text.py --hf_model_path="meta-llama/Llama-3.2-1B" --prompt="Hello, how are you?"

# Load from Megatron checkpoint:
python examples/models/hf_to_megatron_generate_text.py --hf_model_path="meta-llama/Llama-3.2-1B" --megatron_model_path="/path/to/megatron/checkpoint" --prompt="Hello, how are you?"
python examples/conversion/hf_to_megatron_generate_text.py --hf_model_path="meta-llama/Llama-3.2-1B" --megatron_model_path="/path/to/megatron/checkpoint" --prompt="Hello, how are you?"
"""

import argparse
Expand Down Expand Up @@ -127,7 +127,16 @@ def main(args) -> None:
model_provider.initialize_model_parallel(seed=0)

# Load the Megatron model directly
model = bridge.load_megatron_model(args.megatron_model_path, wrap_with_ddp=False)
model = bridge.load_megatron_model(
args.megatron_model_path,
mp_overrides={
"tensor_model_parallel_size": tp,
"pipeline_model_parallel_size": pp,
"expert_model_parallel_size": ep,
"expert_tensor_parallel_size": etp,
},
wrap_with_ddp=False,
)

else:
# Load from HuggingFace and convert to Megatron
Expand Down
11 changes: 10 additions & 1 deletion examples/conversion/hf_to_megatron_generate_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,16 @@ def main(args) -> None:
model_provider.initialize_model_parallel(seed=0)

# Load the Megatron model directly
model = bridge.load_megatron_model(args.megatron_model_path, wrap_with_ddp=False)
model = bridge.load_megatron_model(
args.megatron_model_path,
mp_overrides={
"tensor_model_parallel_size": tp,
"pipeline_model_parallel_size": pp,
"expert_model_parallel_size": ep,
"expert_tensor_parallel_size": etp,
},
wrap_with_ddp=False,
)

else:
# Load from HuggingFace and convert to Megatron
Expand Down
26 changes: 16 additions & 10 deletions src/megatron/bridge/models/conversion/auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pathlib import Path
from typing import Any, Generic, Iterable, List, Optional, Type, TypeVar, Union

import torch.distributed
import torch.distributed as dist
import transformers
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import MLATransformerConfig, TransformerConfig
Expand All @@ -35,7 +35,7 @@
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
from megatron.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry
from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource
from megatron.bridge.models.model_provider import GetModelKwargs, ModelProviderMixin
from megatron.bridge.models.model_provider import GetModelKwargs, ModelParallelKwargs, ModelProviderMixin


MegatronModelT = TypeVar("MegatronModelT", bound=MegatronModule)
Expand Down Expand Up @@ -373,9 +373,9 @@ def save_hf_pretrained(self, model: list[MegatronModelT], path: str | Path, show
saves the configuration files, while weight saving is coordinated
across all ranks.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
if dist.is_available() and dist.is_initialized():
# Distributed training, only rank 0 saves artifacts
if torch.distributed.get_rank() == 0:
if dist.get_rank() == 0:
self.hf_pretrained.save_artifacts(path)
else:
# No distributed training, save artifacts
Expand Down Expand Up @@ -416,8 +416,8 @@ def save_hf_weights(self, model: list[MegatronModelT], path: str | Path, show_pr
- Automatically handles model sharding for large models
- The saved weights can be loaded with HuggingFace's from_pretrained
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.barrier()
if dist.is_available() and dist.is_initialized():
dist.barrier()
dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model))
generator = model_bridge.stream_weights_megatron_to_hf(
dispatch_instance, model, self.hf_pretrained, cpu=True, show_progress=show_progress
Expand All @@ -433,8 +433,8 @@ def save_hf_weights(self, model: list[MegatronModelT], path: str | Path, show_pr
else:
raise ValueError("The state source is not a SafeTensorsStateSource, cannot save in streaming mode.")

if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.barrier()
if dist.is_available() and dist.is_initialized():
dist.barrier()

def save_megatron_model(
self, model: list[MegatronModule], path: str | Path, hf_tokenizer_path: Optional[str | Path] = None
Expand Down Expand Up @@ -476,7 +476,9 @@ def save_megatron_model(
raise ImportError("megatron.bridge.training is not available.")
save_megatron_model(model, path, hf_tokenizer_path=hf_tokenizer_path)

def load_megatron_model(self, path: str | Path, **kwargs: Unpack[GetModelKwargs]) -> list[MegatronModelT]:
def load_megatron_model(
self, path: str | Path, *, mp_overrides: ModelParallelKwargs | None = None, **kwargs: Unpack[GetModelKwargs]
) -> list[MegatronModelT]:
"""
Load a Megatron model from a native Megatron checkpoint.

Expand All @@ -486,6 +488,7 @@ def load_megatron_model(self, path: str | Path, **kwargs: Unpack[GetModelKwargs]

Args:
path: Directory path where the Megatron checkpoint is stored
mp_overrides: Optional model-parallel overrides to apply to the loaded config.
**kwargs: Additional arguments passed to the model provider

Returns:
Expand Down Expand Up @@ -529,10 +532,13 @@ def get_iter_number(folder_name):
checkpoint_path = checkpoint_path / latest_iter.name
# else: checkpoint_path remains as the input path (no iter folders found)

skip_temp_dist_context = dist.is_available() and dist.is_initialized()
# Load the state dict
model = load_megatron_model(
str(checkpoint_path),
use_cpu_init=True,
use_cpu_init=(skip_temp_dist_context and dist.get_backend() == "gloo"),
skip_temp_dist_context=skip_temp_dist_context,
mp_overrides=mp_overrides,
)
return model if isinstance(model, list) else [model]

Expand Down
18 changes: 18 additions & 0 deletions src/megatron/bridge/models/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,24 @@ class GetModelKwargs(TypedDict, total=False):
post_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None


class ModelParallelKwargs(TypedDict, total=False):
"""Model-parallel override kwargs.

Attributes map to `TransformerConfig`/provider fields that control parallelism.
Only provided values are applied as overrides.
"""

tensor_model_parallel_size: int
pipeline_model_parallel_size: int
context_parallel_size: int
expert_model_parallel_size: int
expert_tensor_parallel_size: int
moe_extended_tp: bool
sequence_parallel: bool
virtual_pipeline_model_parallel_size: int | None
hierarchical_context_parallel_sizes: list[int] | None


def get_model(
model_provider: ModelProviderMixin,
ddp_config: DistributedDataParallelConfig,
Expand Down
56 changes: 51 additions & 5 deletions src/megatron/bridge/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,45 @@ def get_checkpoint_version() -> Optional[float]:
return _CHECKPOINT_VERSION


def delete_extra_state(state_dict):
"""Delete all extra state keys from the model state dictionary.

This function removes all keys containing '_extra_state' from the model
portion of the state dictionary. This is useful for cleaning up corrupted
or problematic extra state that can cause issues during model loading.

Args:
state_dict: The state dictionary. Can be either:
- A full checkpoint dict with a "model" key, or
- A model state dict directly

Returns:
The modified state dictionary with extra state keys removed.
"""
# Handle both cases: full checkpoint dict with "model" key or direct model state dict
if isinstance(state_dict, dict) and "model" in state_dict:
# Full checkpoint dict case
target_dict = state_dict["model"]
else:
# Direct model state dict case
target_dict = state_dict

# If target is not a mapping-like object, nothing to clean
if not hasattr(target_dict, "keys"):
return state_dict

# Some objects may implement keys() but not be directly iterable into a list (e.g., mocks)
try:
keys = list(target_dict.keys())
except Exception:
return state_dict

for key in keys:
if isinstance(key, str) and "_extra_state" in key:
del target_dict[key]
return state_dict


def _get_checkpoint_format(checkpoint_path: str) -> str:
"""Determine the checkpoint format by examining the checkpoint directory.

Expand Down Expand Up @@ -226,7 +265,7 @@ def read_metadata(tracker_filename: str) -> tuple[int, bool]:
# iteration across all ranks.
if iteration != max_iter:
rank = torch.distributed.get_rank()
print(
print_rank_0(
"WARNING: on rank {} found iteration {} in the "
"metadata while max iteration across the ranks "
"is {}, replacing it with max iteration.".format(rank, iteration, max_iter),
Expand Down Expand Up @@ -784,7 +823,7 @@ def maybe_save_dataloader_state(train_iterator: Any, iteration: int, dataloader_
return

dp_rank = mpu.get_data_parallel_rank()
print(f"saving dataloader checkpoint at iteration {iteration} to {dataloader_save_path}")
print_rank_0(f"saving dataloader checkpoint at iteration {iteration} to {dataloader_save_path}")
train_dataloader_state_dict = train_iterator.iterable.save_state()
# Get the base directory for the current iteration
iter_dir = get_checkpoint_name(dataloader_save_path, iteration)
Expand Down Expand Up @@ -976,6 +1015,9 @@ def _load_model_weights_from_checkpoint(
state_dict = dist_checkpointing.load(
sharded_state_dict, checkpoint_path, load_strategy, strict=dist_ckpt_strictness
)
# we keep weights only for bridge use, remove extra state
# because they are not needed and could cause unexpected issues.
delete_extra_state(state_dict)
if return_state_dict:
return state_dict

Expand Down Expand Up @@ -1048,11 +1090,15 @@ def _load_model_state_dict(module: torch.nn.Module, state_dict: dict[str, Any],
"""Helper function to load state dict with fallback for missing extra states."""
try:
module.load_state_dict(state_dict, strict=strict)
except Exception:
except Exception as e:
if strict:
# Fallback support for backward compatibility breaking changes in TransformerEngine
print_rank_0(f"Warning: Exception during strict loading: {e}")
load_return = module.load_state_dict(state_dict, strict=False)
print(f"load_return: {load_return}")
print_rank_0(f"load_return: {load_return}")
else:
# Re-raise if we were already in non-strict mode
raise


def _load_checkpoint_from_path(
Expand Down Expand Up @@ -1376,7 +1422,7 @@ def _load_checkpoint_from_path(
if "rerun_state_machine" in state_dict:
get_rerun_state_machine().load_state_dict(state_dict["rerun_state_machine"])
except Exception as e:
print(f"Unable to restore RerunMachine from checkpoint: {e}. Skipping.")
print_rank_0(f"Unable to restore RerunMachine from checkpoint: {e}. Skipping.")
sys.exit()

# Load RNG states
Expand Down
Loading
Loading