Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions verl/models/mcore/config_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@


import warnings
from typing import TypeVar

import torch
import torch.nn.functional as F
from megatron.core import parallel_state as mpu
from megatron.core.transformer import MLATransformerConfig, TransformerConfig
from transformers import PretrainedConfig

T = TypeVar("T", bound=TransformerConfig)


def _get_base_transformer_config(
hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
Expand Down Expand Up @@ -131,7 +134,7 @@ def _get_mla_transformer_config(
return base_config


def check_and_disable_incompatible_configs(original_config: dict) -> dict:
def check_and_construct_configs(original_config: dict, cls: type[T]) -> T:
"""
Check and disable incompatible configurations for older Megatron version.

Expand All @@ -143,7 +146,7 @@ def check_and_disable_incompatible_configs(original_config: dict) -> dict:
"""
removed_keys = []
for key in original_config.keys():
if not hasattr(TransformerConfig, key):
if not hasattr(cls, key):
removed_keys.append(key)
if removed_keys:
warnings.warn(
Expand All @@ -152,7 +155,9 @@ def check_and_disable_incompatible_configs(original_config: dict) -> dict:
)
for key in removed_keys:
original_config.pop(key)
return original_config

print(f"Overridden {cls.__name__} init config: {original_config}")
return cls(**original_config)


def hf_to_mcore_config_dense(
Expand All @@ -172,9 +177,7 @@ def hf_to_mcore_config_dense(
)
# override_transformer_config_kwargs as kwargs shall never be none
args.update(override_transformer_config_kwargs)
args = check_and_disable_incompatible_configs(args)
print(f"Overridden TF init config: {args}")
return TransformerConfig(**args)
return check_and_construct_configs(args, TransformerConfig)


def hf_to_mcore_config_qwen2moe(
Expand Down Expand Up @@ -208,9 +211,7 @@ def hf_to_mcore_config_qwen2moe(
)
# override_transformer_config_kwargs as kwargs shall never be none
args.update(override_transformer_config_kwargs)
args = check_and_disable_incompatible_configs(args)
print(f"Overridden TF init config: {args}")
return TransformerConfig(**args)
return check_and_construct_configs(args, TransformerConfig)


def hf_to_mcore_config_mixtral(
Expand Down Expand Up @@ -243,9 +244,7 @@ def hf_to_mcore_config_mixtral(
)
# override_transformer_config_kwargs as kwargs shall never be none
args.update(override_transformer_config_kwargs)
args = check_and_disable_incompatible_configs(args)
print(f"Overridden TF init config: {args}")
return TransformerConfig(**args)
return check_and_construct_configs(args, TransformerConfig)


def hf_to_mcore_config_qwen3moe(
Expand Down Expand Up @@ -277,9 +276,7 @@ def hf_to_mcore_config_qwen3moe(
)
# override_transformer_config_kwargs as kwargs shall never be none
args.update(override_transformer_config_kwargs)
args = check_and_disable_incompatible_configs(args)
print(f"Overridden TF init config: {args}")
return TransformerConfig(**args)
return check_and_construct_configs(args, TransformerConfig)


def hf_to_mcore_config_dpskv3(
Expand Down Expand Up @@ -354,9 +351,7 @@ def hf_to_mcore_config_dpskv3(
)
# override_transformer_config_kwargs as kwargs shall never be none
args.update(override_transformer_config_kwargs)
args = check_and_disable_incompatible_configs(args)
transformer_config: MLATransformerConfig = MLATransformerConfig(**args)
print(f"Overridden MLA TF init config: {transformer_config}")
transformer_config = check_and_construct_configs(args, MLATransformerConfig)
# MTP
if "num_nextn_predict_layers" in hf_config:
transformer_config.mtp_num_layers = hf_config.num_nextn_predict_layers
Expand All @@ -380,9 +375,7 @@ def hf_to_mcore_config_qwen2_5_vl(
)
# override_transformer_config_kwargs as kwargs shall never be none
args.update(override_transformer_config_kwargs)
args = check_and_disable_incompatible_configs(args)
print(f"Overridden TF init config: {args}")
return TransformerConfig(**args)
return check_and_construct_configs(args, TransformerConfig)


def hf_to_mcore_config_llama4(
Expand Down
Loading