Skip to content
7 changes: 7 additions & 0 deletions flagscale/train/megatron/nemo_bridge/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) 2025, BAAI. All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The NeMo megatron-bridge package supports installation via pip (see https://pypi.org/project/megatron-bridge/);
please remove the vendored source code.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename flagscale/train/megatron/nemo_bridge to flagscale/train/megatron/bridge so that it matches the import pattern from megatron.bridge


"""Megatron Bridge - A component of the Megatron ecosystem."""

from megatron.nemo_bridge.models.conversion.auto_bridge import AutoBridge

__all__ = ["AutoBridge"]
21 changes: 21 additions & 0 deletions flagscale/train/megatron/nemo_bridge/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2025, BAAI. All rights reserved.

from megatron.nemo_bridge.models.conversion.auto_bridge import AutoBridge
from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.nemo_bridge.models.conversion.param_mapping import (
AutoMapping,
QKVMapping,
)
from megatron.nemo_bridge.models.deepseek.deepseek_v3_bridge import DeepSeekV3Bridge
from megatron.nemo_bridge.models.qwen.qwen3_bridge import Qwen3Bridge
from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM

__all__ = [
"AutoBridge",
"MegatronModelBridge",
"QKVMapping",
"AutoMapping",
"DeepSeekV3Bridge",
"Qwen3Bridge",
"PreTrainedCausalLM",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) 2025, BAAI. All rights reserved.

from megatron.nemo_bridge.models.conversion.auto_bridge import AutoBridge

__all__ = [
"AutoBridge",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
# Copyright (c) 2025, BAAI. All rights reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks to me that this file was largely adapted from flagscale/train/megatron/nemo_bridge/models/conversion/auto_bridge.py. We copy-pasted the source and we are claiming copyright for this code. This is not acceptable.

We can borrow code from other projects, provided that the license terms grant us this right. In that case, we still have to pay credit to the original authors. We are obliged to mention their copyrights.

There are some weird characters in this file which was obviously a character conversion problem during copy/paste. Please fix them as well.


from megatron.bridge import AutoBridge as OriginalAutoBridge
import transformers
import torch.distributed as dist
from transformers import AutoModelForCausalLM
from transformers.configuration_utils import PretrainedConfig

from megatron.core.transformer.module import MegatronModule
from megatron.nemo_bridge.models.conversion import model_bridge
from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge

from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource
from megatron.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry
from megatron.bridge.models.conversion.utils import get_causal_lm_class_via_auto_map

from typing import TypeVar, Union
from pathlib import Path
Comment on lines +3 to +19
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import order in this module is likely to fail Ruff's isort rules (I001): standard-library imports (typing, pathlib) should be grouped at the top, followed by third-party (transformers, torch), then local (megatron...). Please run isort/ruff on this file to normalize the import blocks.

Suggested change
from megatron.bridge import AutoBridge as OriginalAutoBridge
import transformers
import torch.distributed as dist
from transformers import AutoModelForCausalLM
from transformers.configuration_utils import PretrainedConfig
from megatron.core.transformer.module import MegatronModule
from megatron.nemo_bridge.models.conversion import model_bridge
from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource
from megatron.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry
from megatron.bridge.models.conversion.utils import get_causal_lm_class_via_auto_map
from typing import TypeVar, Union
from pathlib import Path
from pathlib import Path
from typing import TypeVar, Union
import torch.distributed as dist
import transformers
from transformers import AutoModelForCausalLM
from transformers.configuration_utils import PretrainedConfig
from megatron.bridge import AutoBridge as OriginalAutoBridge
from megatron.bridge.models.conversion.utils import get_causal_lm_class_via_auto_map
from megatron.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry
from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource
from megatron.core.transformer.module import MegatronModule
from megatron.nemo_bridge.models.conversion import model_bridge
from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.nemo_bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM

Copilot uses AI. Check for mistakes.
MegatronModelT = TypeVar("MegatronModelT", bound=MegatronModule)


class AutoBridge(OriginalAutoBridge):
    """AutoBridge variant that accepts FlagScale's ``PreTrainedCausalLM`` wrapper."""

    def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig):
        # Accept only a loaded HF causal-LM wrapper or a bare HF config.
        acceptable = (PreTrainedCausalLM, PretrainedConfig)
        if not isinstance(hf_pretrained, acceptable):
            raise ValueError("hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance")
        self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained
        super().__init__(hf_pretrained)

@classmethod
def from_hf_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoBridge":
"""
Load an AutoBridge from a pretrained model, automatically detecting the model type.
"""
# First load just the config to check architecture support
# Use thread-safe config loading to prevent race conditions
config = safe_load_config_with_retry(path, trust_remote_code=kwargs.get("trust_remote_code", False))

cls._validate_config(config, str(path))

try:
return cls(PreTrainedCausalLM.from_pretrained(path, **kwargs))
except Exception as e:
raise ValueError(f"Failed to load model with AutoBridge: {e}") from e

@classmethod
def _validate_config(cls, config: PretrainedConfig, path: str | None = None) -> None:
# Check if this is a causal LM model
if not cls.supports(config):
architectures = getattr(config, "architectures", [])
raise ValueError(
f"\n�~\~W Model architecture not supported by AutoBridge\n\n"
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message strings contain garbled replacement characters (e.g., �~\~W). This will render poorly for users and makes logs hard to read; replace with the intended symbol/text (e.g., plain "Error:" / "Unsupported:" or a proper Unicode character) and ensure the file encoding is UTF-8 clean.

Suggested change
f"\n�~\~W Model architecture not supported by AutoBridge\n\n"
f"\nError: Model architecture not supported by AutoBridge\n\n"

Copilot uses AI. Check for mistakes.
f"Model: {path}\n"
f"Architectures: {architectures}\n\n"
f"AutoBridge only supports models with architectures ending in 'ForCausalLM' or"
f"'ForConditionalGeneration' or 'NemotronH_Nano_VL_V2'.\n"
f"Found architectures that don't match this pattern.\n\n"
f"If this is a different model type (e.g., Vision, Sequence-to-Sequence),\n"
f"you may need to use a different bridge class."
)

# Check if we have an implementation for this specific architecture
architecture = None
for arch_name in config.architectures:
if arch_name.endswith(
("ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2")
):
architecture = arch_name
break

if architecture:
# Try auto_map first
arch_class = (
get_causal_lm_class_via_auto_map(model_name_or_path=path, config=config)
if path
else None
)
if arch_class is not None:
# For auto_map models, use class-name string
arch_key = getattr(arch_class, "__name__", str(arch_class))
else:
try:
arch_class = getattr(transformers, architecture)
arch_key = arch_class
except AttributeError:
# Fall back to name-based registration
arch_key = architecture

# Test if we have a registered implementation (type or class-name string)
has_implementation = False
if hasattr(model_bridge.get_model_bridge, "_exact_types"):
registry = model_bridge.get_model_bridge._exact_types
if isinstance(arch_key, str):
has_implementation = arch_key in registry
else:
has_implementation = (arch_key in registry) or (
getattr(arch_key, "__name__", None) in registry
)

if not has_implementation:
raise ValueError(
f"\n�~\~W Model architecture '{architecture}' is not yet supported\n\n"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are these weird characters?
There are some other similar cases in this string.

f"Model: {path}\n"
f"Architecture: {architecture}\n\n"
+ f"\n\nTo add support for {architecture}, you need to:\n"
f"1. Create a new bridge class that inherits from MegatronModelBridge\n"
f"2. Implement the required methods (provider_bridge, mapping_registry)\n"
f"3. Register it with @MegatronModelBridge.register_bridge decorator\n\n"
f"Example implementation:\n"
f" from megatron.nemo_bridge.models.conversion.model_bridge import MegatronModelBridge\n"
f" from transformers import {architecture}\n"
f" from megatron.core.models.gpt import GPTModel\n\n"
f" @MegatronModelBridge.register_bridge(source={architecture}, target=GPTModel)\n"
f" class Megatron{architecture.replace('ForCausalLM', '')}Bridge(MegatronModelBridge):\n"
f" def provider_bridge(self, hf_pretrained):\n"
f" # Return a ModelProvider instance\n"
f" ...\n\n"
f" def mapping_registry(self):\n"
f" # Return a MegatronMappingRegistry with weight mappings\n"
f" ...\n\n"
f"For reference implementations, see:\n"
f" �~@� src/megatron/bridge/models/llama/llama_bridge.py\n"
f" �~@� src/megatron/bridge/models/qwen/qwen_2_causal_bridge.py"
) from None


@classmethod
def from_hf_config(cls, config: PretrainedConfig, modeling_path=None, model_name= None) -> "AutoBridge":

cls._validate_config(config)
model = PreTrainedCausalLM()
model.config = config
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

hf_model = None
if modeling_path :
import os,sys
import importlib.util
abs_model_path = os.path.abspath(modeling_path)
Comment on lines +138 to +142
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are multiple Ruff/pycodestyle violations in this block (e.g., if modeling_path : / elif ... and config : whitespace before :, and import os,sys multiple imports on one line + missing whitespace after comma). These will fail lint; please normalize to standard formatting (if ...: and one import per line).

Copilot uses AI. Check for mistakes.
model_dir = os.path.dirname(abs_model_path)
if model_dir not in sys.path:
sys.path.insert(0, model_dir)
module_name = "dynamic_model_module"
try:
spec = importlib.util.spec_from_file_location(module_name, abs_model_path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
model_class = getattr(module, model_name)
with init_empty_weights():
hf_model = model_class._from_config(model.config)
for name, param in hf_model.named_parameters():
set_module_tensor_to_device(
hf_model, name, "cpu", torch.empty(*param.size(), dtype=model.config.torch_dtype)
)
except Exception as e:
print(f"import module error: {e}")

elif not modeling_path and config :
with init_empty_weights():
hf_model = AutoModelForCausalLM.from_config(model.config)

for name, param in hf_model.named_parameters():
set_module_tensor_to_device(
hf_model, name, "cpu", torch.empty(*param.size(), dtype=model.config.torch_dtype)
)
else:
raise ValueError("Need one args, model_path or config, to build HF model.")
model.model = hf_model
return cls(model)
Comment on lines +139 to +174
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from_hf_config swallows exceptions when importing/building the dynamic HF model (print(f"import module error: {e}")) and then continues, potentially returning an AutoBridge with hf_model still None. This will cause failures later in a much less debuggable place. Raise an exception (or re-raise with context) if the model class can't be loaded, and ensure hf_model is non-None before returning.

Copilot uses AI. Check for mistakes.

def load_hf_weights(
self, model: list[MegatronModelT], hf_path: str | Path | None = None
) -> None:
if hf_path is None:
if not isinstance(self.hf_pretrained, PreTrainedCausalLM):
raise ValueError(
"hf_path is required when hf_pretrained is not a PreTrainedCausalLM instance"
)
pre_trained = self.hf_pretrained
else:
pre_trained = PreTrainedCausalLM.from_pretrained(hf_path)
# Preserve trust_remote_code setting from the original bridge instance
trust_remote_code = getattr(self.hf_pretrained, 'trust_remote_code', False)
pre_trained = PreTrainedCausalLM.from_pretrained(
hf_path, trust_remote_code=trust_remote_code
)
Comment on lines +185 to +191
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_hf_weights calls PreTrainedCausalLM.from_pretrained(hf_path) twice (the first result is immediately overwritten). This adds unnecessary IO/memory overhead; remove the redundant first call and load once with the correct trust_remote_code value.

Copilot uses AI. Check for mistakes.
# self._model_bridge.load_weights_hf_to_megatron(model, pre_trained)
self._model_bridge.load_weights_hf_to_megatron(pre_trained, model)

return model

def save_hf_weights(
self,
model: list[MegatronModelT],
path: str | Path,
show_progress: bool = True,
strict: bool = True,
) -> None:
if dist.is_available() and dist.is_initialized():
dist.barrier()
dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model))
generator = model_bridge.stream_weights_megatron_to_hf(
dispatch_instance, model, self.hf_pretrained, cpu=True, show_progress=show_progress
)
source = SafeTensorsStateSource(path)
# Check if the state source is SafeTensorsStateSource for streaming save.
if (
hasattr(self.hf_pretrained, "state")
and hasattr(self.hf_pretrained.state, "source")
# and isinstance(self.hf_pretrained.state.source, SafeTensorsStateSource)
):
# self.hf_pretrained.state.source.save_generator(generator, path, strict=strict)
source.save_generator(generator, path, strict=strict)
else:
raise ValueError(
"The state source is not a SafeTensorsStateSource, cannot save in streaming mode."
)

if dist.is_available() and dist.is_initialized():
dist.barrier()

@property
def _model_bridge(self) -> "MegatronModelBridge":
return model_bridge.get_model_bridge(self._causal_lm_architecture)

def convert_mg2hf_config(margs, save_path, model_type):
assert model_type is not None, "model_type is None"
if hasattr(model_bridge.get_model_bridge, "_exact_types"):
registry = model_bridge.get_model_bridge._exact_types
if isinstance(model_type, str):
has_implementation = model_type in registry
else:
has_implementation = (model_type in registry) or (getattr(model_type, "__name__", None) in registry)

if not has_implementation:
raise ValueError(f"\n✗ Model architecture '{model_type}' is not yet supported\n\n")
model_bridge_class = model_bridge.get_model_bridge(model_type)
return model_bridge_class.save_args_mg2hf(margs, save_path)
Loading