diff --git a/.github/workflows/model.yml b/.github/workflows/model.yml
index c99312fb3ec..ed30e8c5310 100644
--- a/.github/workflows/model.yml
+++ b/.github/workflows/model.yml
@@ -48,6 +48,8 @@ on:
       # Entrypoints
       - ".github/workflows/model.yml"
       - "tests/special_distributed/test_fsdp_ckpt.py"
+      - "tests/special_distributed/test_mcore_config_converter.py"
+      - "tests/special_distributed/test_tensor_dict.py"
       - "tests/models/**"
       - "tests/special_distributed/run_all.sh"
@@ -142,3 +144,35 @@ jobs:
       - name: Running FSDP2 rmpad model tests on 8 L20 GPUs + latest flash_attn
         run: |
           STRATEGY=fsdp2 torchrun --nproc_per_node=8 tests/special_distributed/test_fsdp_ckpt.py
+
+  mcore_config_converter:
+    runs-on: [L20x8]
+    timeout-minutes: 20 # Increase this timeout value as needed
+    env:
+      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
+      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
+      NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
+      HF_ENDPOINT: "https://hf-mirror.com"
+      HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
+    container:
+      image: verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2
+      options: --gpus all --shm-size=10g
+    steps:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          fetch-depth: 0
+      - name: Install the current repository
+        run: |
+          pip3 install --no-deps -e .[test]
+          pip install --upgrade "huggingface_hub[cli]"
+      - name: Download model config files
+        run: |
+          hf download Qwen/Qwen2.5-7B config.json --local-dir $HOME/configs/Qwen/Qwen2.5-7B
+          hf download Qwen/Qwen3-8B config.json --local-dir $HOME/configs/Qwen/Qwen3-8B
+          hf download deepseek-ai/deepseek-coder-1.3b-instruct config.json --local-dir $HOME/configs/deepseek-ai/deepseek-coder-1.3b-instruct
+          hf download Qwen/Qwen2-57B-A14B config.json --local-dir $HOME/configs/Qwen/Qwen2-57B-A14B
+          hf download Qwen/Qwen3-30B-A3B config.json --local-dir $HOME/configs/Qwen/Qwen3-30B-A3B
+          hf download deepseek-ai/DeepSeek-V3-Base config.json --local-dir $HOME/configs/deepseek-ai/DeepSeek-V3-Base
+      - name: Running mcore config converter tests on 8 L20 GPUs
+        run: |
+          torchrun --nproc_per_node=8 tests/special_distributed/test_mcore_config_converter.py
\ No newline at end of file
diff --git a/tests/special_distributed/test_mcore_config_converter.py b/tests/special_distributed/test_mcore_config_converter.py
new file mode 100644
index 00000000000..2eaea9bdbcf
--- /dev/null
+++ b/tests/special_distributed/test_mcore_config_converter.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import megatron.core.parallel_state as mpu
+import torch
+from megatron.core.transformer import MLATransformerConfig, TransformerConfig
+from transformers import AutoConfig, PretrainedConfig
+
+from verl.models.mcore import hf_to_mcore_config
+from verl.utils.distributed import destroy_global_process_group, initialize_global_process_group
+
+TEST_MODELS = [
+    "Qwen/Qwen2.5-7B",  # Qwen2 dense
+    "Qwen/Qwen3-8B",  # Qwen3 dense
+    "deepseek-ai/deepseek-coder-1.3b-instruct",  # deepseek dense
+    "Qwen/Qwen2-57B-A14B",  # Qwen2 moe
+    "Qwen/Qwen3-30B-A3B",  # Qwen3 moe
+    # "mistralai/Mixtral-8x7B-v0.1",  # Mixtral # require authentication
+    "deepseek-ai/DeepSeek-V3-Base",  # Deepseek V3
+]
+
+
+def check_config_converter_results(tf_config: TransformerConfig | MLATransformerConfig, hf_config: PretrainedConfig):
+    assert tf_config.num_layers == hf_config.num_hidden_layers, (
+        f"Number of layers mismatch: {tf_config.num_layers} != {hf_config.num_hidden_layers}"
+    )
+    assert tf_config.hidden_size == hf_config.hidden_size, (
+        f"Hidden size mismatch: {tf_config.hidden_size} != {hf_config.hidden_size}"
+    )
+    assert tf_config.num_attention_heads == hf_config.num_attention_heads, (
+        f"Number of attention heads mismatch: {tf_config.num_attention_heads} != {hf_config.num_attention_heads}"
+    )
+    assert tf_config.num_query_groups == hf_config.num_key_value_heads, (
+        f"Number of query groups mismatch: {tf_config.num_query_groups} != {hf_config.num_key_value_heads}"
+    )
+    assert tf_config.ffn_hidden_size == hf_config.intermediate_size, (
+        f"FFN hidden size mismatch: {tf_config.ffn_hidden_size} != {hf_config.intermediate_size}"
+    )
+    assert tf_config.attention_dropout == hf_config.attention_dropout, (
+        f"Attention dropout mismatch: {tf_config.attention_dropout} != {hf_config.attention_dropout}"
+    )
+    assert tf_config.hidden_dropout == getattr(hf_config, "hidden_dropout", 0.0), (
+        f"Hidden dropout mismatch: {tf_config.hidden_dropout} != {getattr(hf_config, 'hidden_dropout', 0.0)}"
+    )
+    if getattr(hf_config, "head_dim", None) is not None:
+        assert tf_config.kv_channels == getattr(hf_config, "head_dim", None), (
+            f"Head dim mismatch: {tf_config.kv_channels} != {getattr(hf_config, 'head_dim', None)}"
+        )
+    assert tf_config.layernorm_epsilon == hf_config.rms_norm_eps, (
+        f"Layernorm epsilon mismatch: {tf_config.layernorm_epsilon} != {hf_config.rms_norm_eps}"
+    )
+
+
+def modify_hf_config(name: str, hf_config: PretrainedConfig):
+    if name == "deepseek-ai/DeepSeek-V3-Base":
+        hf_config.num_nextn_predict_layers = 0
+        hf_config.quantization_config = None
+    return hf_config
+
+
+def test_mcore_config_converter():
+    """
+    Test the conversion of Hugging Face model configurations to MCore configurations.
+ """ + local_rank, rank, world_size = initialize_global_process_group() + mpu.initialize_model_parallel( + tensor_model_parallel_size=2, + pipeline_model_parallel_size=2, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=None, + use_sharp=False, + context_parallel_size=2, + expert_model_parallel_size=1, + expert_tensor_parallel_size=None, + nccl_communicator_config_path=None, + ) + for model_name in TEST_MODELS: + print(f"testing {model_name}") + hf_config = AutoConfig.from_pretrained(os.path.expanduser(f"~/configs/{model_name}/config.json")) + hf_config = modify_hf_config(model_name, hf_config) + tf_config = hf_to_mcore_config(hf_config, torch.bfloat16) + check_config_converter_results(tf_config, hf_config) + + destroy_global_process_group() + + +if __name__ == "__main__": + test_mcore_config_converter() diff --git a/verl/models/mcore/config_converter.py b/verl/models/mcore/config_converter.py index 95fe7e41660..9daf550cdb8 100644 --- a/verl/models/mcore/config_converter.py +++ b/verl/models/mcore/config_converter.py @@ -18,6 +18,7 @@ import warnings +from typing import TypeVar import torch import torch.nn.functional as F @@ -25,6 +26,8 @@ from megatron.core.transformer import MLATransformerConfig, TransformerConfig from transformers import PretrainedConfig +T = TypeVar("T", bound=TransformerConfig) + def _get_base_transformer_config( hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs @@ -131,7 +134,7 @@ def _get_mla_transformer_config( return base_config -def check_and_disable_incompatible_configs(original_config: dict) -> dict: +def check_and_construct_configs(original_config: dict, cls: type[T]) -> T: """ Check and disable incompatible configurations for older Megatron version. 
@@ -143,7 +146,7 @@ def check_and_disable_incompatible_configs(original_config: dict) -> dict:
     """
     removed_keys = []
     for key in original_config.keys():
-        if not hasattr(TransformerConfig, key):
+        if not hasattr(cls, key):
             removed_keys.append(key)
     if removed_keys:
         warnings.warn(
@@ -152,7 +155,9 @@ def check_and_disable_incompatible_configs(original_config: dict) -> dict:
         )
     for key in removed_keys:
         original_config.pop(key)
-    return original_config
+
+    print(f"Overridden {cls.__name__} init config: {original_config}")
+    return cls(**original_config)
 
 
 def hf_to_mcore_config_dense(
@@ -172,9 +177,7 @@ def hf_to_mcore_config_dense(
     )
     # override_transformer_config_kwargs as kwargs shall never be none
     args.update(override_transformer_config_kwargs)
-    args = check_and_disable_incompatible_configs(args)
-    print(f"Overridden TF init config: {args}")
-    return TransformerConfig(**args)
+    return check_and_construct_configs(args, TransformerConfig)
 
 
 def hf_to_mcore_config_qwen2moe(
@@ -208,9 +211,7 @@ def hf_to_mcore_config_qwen2moe(
     )
     # override_transformer_config_kwargs as kwargs shall never be none
     args.update(override_transformer_config_kwargs)
-    args = check_and_disable_incompatible_configs(args)
-    print(f"Overridden TF init config: {args}")
-    return TransformerConfig(**args)
+    return check_and_construct_configs(args, TransformerConfig)
 
 
 def hf_to_mcore_config_mixtral(
@@ -243,9 +244,7 @@ def hf_to_mcore_config_mixtral(
     )
     # override_transformer_config_kwargs as kwargs shall never be none
     args.update(override_transformer_config_kwargs)
-    args = check_and_disable_incompatible_configs(args)
-    print(f"Overridden TF init config: {args}")
-    return TransformerConfig(**args)
+    return check_and_construct_configs(args, TransformerConfig)
 
 
 def hf_to_mcore_config_qwen3moe(
@@ -277,9 +276,7 @@ def hf_to_mcore_config_qwen3moe(
     )
     # override_transformer_config_kwargs as kwargs shall never be none
     args.update(override_transformer_config_kwargs)
-    args = check_and_disable_incompatible_configs(args)
-    print(f"Overridden TF init config: {args}")
-    return TransformerConfig(**args)
+    return check_and_construct_configs(args, TransformerConfig)
 
 
 def hf_to_mcore_config_dpskv3(
@@ -354,9 +351,7 @@ def hf_to_mcore_config_dpskv3(
     )
     # override_transformer_config_kwargs as kwargs shall never be none
     args.update(override_transformer_config_kwargs)
-    args = check_and_disable_incompatible_configs(args)
-    transformer_config: MLATransformerConfig = MLATransformerConfig(**args)
-    print(f"Overridden MLA TF init config: {transformer_config}")
+    transformer_config = check_and_construct_configs(args, MLATransformerConfig)
     # MTP
     if "num_nextn_predict_layers" in hf_config:
         transformer_config.mtp_num_layers = hf_config.num_nextn_predict_layers
@@ -380,8 +375,6 @@ def hf_to_mcore_config_qwen2_5_vl(
     )
     # override_transformer_config_kwargs as kwargs shall never be none
     args.update(override_transformer_config_kwargs)
-    args = check_and_disable_incompatible_configs(args)
-    print(f"Overridden TF init config: {args}")
     return TransformerConfig(**args)
 
 
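For reference, a minimal sketch (not part of the change set) of how the refactored helper is expected to behave, assuming it is imported directly from verl.models.mcore.config_converter; the config values below are made up for illustration only:

# Usage sketch for the new check_and_construct_configs helper.
# Hypothetical field values; unknown keys are dropped with a warning
# before the requested Megatron config class is instantiated.
from megatron.core.transformer import TransformerConfig

from verl.models.mcore.config_converter import check_and_construct_configs

args = {
    "num_layers": 32,                # illustrative dense-model settings
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "not_a_transformer_field": 123,  # removed by the hasattr(cls, key) check
}
tf_config = check_and_construct_configs(args, TransformerConfig)
assert isinstance(tf_config, TransformerConfig)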