Commit d3b829d ("clean")
Parent: 4ffe705

2 files changed (+6, -42 lines)

verl/models/mcore/config_converter.py

Lines changed: 5 additions & 41 deletions
@@ -21,7 +21,6 @@
 import torch.nn.functional as F
 from megatron.core import parallel_state as mpu
 from megatron.core.transformer import MLATransformerConfig, TransformerConfig
-from megatron.core.transformer.enums import AttnBackend
 from transformers import PretrainedConfig
 
 
@@ -38,7 +37,6 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
     Returns:
         TransformerConfig with common parameters
     """
-    from megatron.core import parallel_state as mpu
 
     # Common parallel state parameters
     overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1
@@ -56,6 +54,7 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
         "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0),
         "kv_channels": getattr(hf_config, "head_dim", None),
         "layernorm_epsilon": hf_config.rms_norm_eps,
+        "add_bias_linear": False,
         # Activation and normalization
         "activation_func": F.silu,
         "normalization": "RMSNorm",
@@ -190,51 +189,16 @@ def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, *
 def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     # Qwen2_5_VLForConditionalGeneration
 
-    overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1
-    batch_p2p_comm = False
-    transformer_config = TransformerConfig(
-        num_layers=hf_config.num_hidden_layers,
-        hidden_size=hf_config.hidden_size,
-        ffn_hidden_size=hf_config.intermediate_size,
-        num_attention_heads=hf_config.num_attention_heads,
-        num_query_groups=hf_config.num_key_value_heads,
-        attention_dropout=hf_config.attention_dropout,
-        hidden_dropout=getattr(hf_config, "hidden_dropout", 0.0),
-        activation_func=F.silu,
-        normalization="RMSNorm",
-        gated_linear_unit=True,
-        use_cpu_initialization=False,
+    return _get_base_transformer_config(
+        hf_config=hf_config,
+        dtype=dtype,
         add_bias_linear=False,
-        pipeline_dtype=dtype,
-        params_dtype=dtype,
-        variable_seq_lengths=True,
-        masked_softmax_fusion=True,
-        bf16=dtype is torch.bfloat16,
-        layernorm_epsilon=hf_config.rms_norm_eps,
-        # parallel config
-        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),
-        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),
-        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),
-        context_parallel_size=mpu.get_context_parallel_world_size(),
-        overlap_p2p_comm=overlap_p2p_comm,
-        batch_p2p_comm=batch_p2p_comm,
-        sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1,
-        attention_backend=AttnBackend.flash,
-        # ?
-        attention_softmax_in_fp32=False,
-        persist_layer_norm=True,
-        bias_dropout_fusion=True,
-        distribute_saved_activations=False,
-        cp_comm_type="p2p",
-        # moe specific
-        moe_token_dispatcher_type="alltoall",
         # qwen specific
         add_qkv_bias=True,
         mrope_section=hf_config.rope_scaling["mrope_section"],
+        **override_transformer_config_kwargs,
     )
 
-    return transformer_config
-
 
 def hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     # Llama4ForConditionalGeneration
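
In effect, the config_converter.py change collapses the hand-rolled TransformerConfig constructor in hf_to_mcore_config_qwen2_5_vl into the shared _get_base_transformer_config helper, keeping only the Qwen-specific fields (add_qkv_bias, mrope_section) and forwarding caller overrides. A minimal sketch of that base-defaults-plus-overrides pattern, using a hypothetical trimmed-down helper (the real one also derives parallel sizes from megatron.core.parallel_state and returns a TransformerConfig):

import torch
import torch.nn.functional as F
from transformers import PretrainedConfig


def build_base_config(hf_config: PretrainedConfig, dtype: torch.dtype, **overrides) -> dict:
    # Hypothetical stand-in for _get_base_transformer_config: shared defaults first,
    # then model-specific keyword arguments override them.
    base = {
        "num_layers": hf_config.num_hidden_layers,
        "hidden_size": hf_config.hidden_size,
        "activation_func": F.silu,
        "normalization": "RMSNorm",
        "add_bias_linear": False,  # shared default added by this commit
        "params_dtype": dtype,
        "bf16": dtype is torch.bfloat16,
    }
    base.update(overrides)  # caller kwargs win on conflicts
    return base  # the real helper feeds this into TransformerConfig(**base)


def qwen2_5_vl_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_kwargs) -> dict:
    # Per-model converters shrink to thin wrappers that pass only their deltas.
    return build_base_config(
        hf_config=hf_config,
        dtype=dtype,
        add_qkv_bias=True,  # qwen specific
        mrope_section=hf_config.rope_scaling["mrope_section"],
        **override_kwargs,
    )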

verl/models/mcore/model_initializer.py

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ def initialize(
         from megatron.core.models.gpt.moe_module_specs import MLPSubmodules
         from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec
 
-        from .qwen2_5_vl.model import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config
+        from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config
 
         vision_transformer_config = get_vision_model_config(deepcopy(tfconfig))
         vision_transformer_config.pipeline_model_parallel_size = 1
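
The model_initializer.py side only shortens the import path: Qwen2_5VLModel and the two vision config helpers are now imported from the qwen2_5_vl package instead of its model submodule, which presumably relies on package-level re-exports along these lines (hypothetical qwen2_5_vl/__init__.py, not part of this commit):

# qwen2_5_vl/__init__.py (hypothetical): re-export the public names so callers
# can write "from .qwen2_5_vl import ..." instead of reaching into .model.
from .model import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config

__all__ = ["Qwen2_5VLModel", "get_vision_model_config", "get_vision_projection_config"]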
