 import torch.nn.functional as F
 from megatron.core import parallel_state as mpu
 from megatron.core.transformer import MLATransformerConfig, TransformerConfig
-from megatron.core.transformer.enums import AttnBackend
 from transformers import PretrainedConfig
 
 
@@ -38,7 +37,6 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
     Returns:
         TransformerConfig with common parameters
     """
-    from megatron.core import parallel_state as mpu
 
     # Common parallel state parameters
     overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1
@@ -56,6 +54,7 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
         "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0),
         "kv_channels": getattr(hf_config, "head_dim", None),
         "layernorm_epsilon": hf_config.rms_norm_eps,
+        "add_bias_linear": False,
         # Activation and normalization
         "activation_func": F.silu,
         "normalization": "RMSNorm",
@@ -190,51 +189,16 @@ def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, *
 def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     # Qwen2_5_VLForConditionalGeneration
 
-    overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1
-    batch_p2p_comm = False
-    transformer_config = TransformerConfig(
-        num_layers=hf_config.num_hidden_layers,
-        hidden_size=hf_config.hidden_size,
-        ffn_hidden_size=hf_config.intermediate_size,
-        num_attention_heads=hf_config.num_attention_heads,
-        num_query_groups=hf_config.num_key_value_heads,
-        attention_dropout=hf_config.attention_dropout,
-        hidden_dropout=getattr(hf_config, "hidden_dropout", 0.0),
-        activation_func=F.silu,
-        normalization="RMSNorm",
-        gated_linear_unit=True,
-        use_cpu_initialization=False,
+    return _get_base_transformer_config(
+        hf_config=hf_config,
+        dtype=dtype,
         add_bias_linear=False,
-        pipeline_dtype=dtype,
-        params_dtype=dtype,
-        variable_seq_lengths=True,
-        masked_softmax_fusion=True,
-        bf16=dtype is torch.bfloat16,
-        layernorm_epsilon=hf_config.rms_norm_eps,
-        # parallel config
-        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),
-        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),
-        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),
-        context_parallel_size=mpu.get_context_parallel_world_size(),
-        overlap_p2p_comm=overlap_p2p_comm,
-        batch_p2p_comm=batch_p2p_comm,
-        sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1,
-        attention_backend=AttnBackend.flash,
-        # ?
-        attention_softmax_in_fp32=False,
-        persist_layer_norm=True,
-        bias_dropout_fusion=True,
-        distribute_saved_activations=False,
-        cp_comm_type="p2p",
-        # moe specific
-        moe_token_dispatcher_type="alltoall",
         # qwen specific
         add_qkv_bias=True,
         mrope_section=hf_config.rope_scaling["mrope_section"],
+        **override_transformer_config_kwargs,
     )
 
-    return transformer_config
-
 
 def hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     # Llama4ForConditionalGeneration
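
For context on the refactor above: `hf_to_mcore_config_qwen2_5_vl` now delegates to the shared `_get_base_transformer_config` helper and passes only its model-specific kwargs. The sketch below is a minimal, hypothetical rendering of that delegation pattern (the real helper in this file builds a much larger base dict, including the parallel-state and fusion settings visible in the earlier hunks); it only assumes the base kwargs are merged with `**override_transformer_config_kwargs` before `TransformerConfig` is constructed, so per-model overrides win.

```python
import torch
import torch.nn.functional as F
from megatron.core.transformer import TransformerConfig
from transformers import PretrainedConfig


def _get_base_transformer_config_sketch(hf_config: PretrainedConfig, dtype: torch.dtype, **overrides) -> TransformerConfig:
    """Illustrative only: merge shared defaults with per-model overrides."""
    base_config = {
        # Values read straight from the HF config (subset shown).
        "num_layers": hf_config.num_hidden_layers,
        "hidden_size": hf_config.hidden_size,
        "ffn_hidden_size": hf_config.intermediate_size,
        "num_attention_heads": hf_config.num_attention_heads,
        "layernorm_epsilon": hf_config.rms_norm_eps,
        # Shared defaults, including the add_bias_linear=False added in this change.
        "activation_func": F.silu,
        "normalization": "RMSNorm",
        "gated_linear_unit": True,
        "add_bias_linear": False,
        "params_dtype": dtype,
        "bf16": dtype is torch.bfloat16,
        # ... parallel-state sizes, fusion flags, etc. elided ...
    }
    # Later kwargs win, so a caller can add model-specific settings
    # without touching the shared base dict.
    base_config.update(overrides)
    return TransformerConfig(**base_config)
```

In the diff, the Qwen2.5-VL converter forwards `add_qkv_bias=True`, `mrope_section=...`, and the caller's `**override_transformer_config_kwargs` through exactly this kind of path.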