diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml index cc124c694df..af0e74562e3 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron.yml @@ -26,6 +26,7 @@ on: # Entrypoints - ".github/workflows/e2e_ppo_trainer_megatron.yml" - "examples/data_preprocess/gsm8k.py" + - "examples/data_preprocess/geo3k.py" - "tests/e2e/run_ppo_trainer_megatron.sh" - "verl/trainer/main_ppo.py" - "verl/trainer/config/ppo_megatron_trainer.yaml" @@ -261,4 +262,35 @@ jobs: - name: clean up run: | rm -rf checkpoints - + e2e_ppo_trainer_megatron-qwen2_5vl-3b: + runs-on: [L20x8] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + container: + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare Geo3k dataset + run: | + python3 examples/data_preprocess/geo3k.py + - name: Prepare dist_ckpt of Qwen2.5-VL-3B, only supports dist_ckpt + run: | + python3 scripts/converter_hf_to_mcore.py --hf_model_path ${HOME}/models/Qwen/Qwen2.5-VL-3B-Instruct --output_path checkpoints/verl-test/qwen2.5-vl-3b-megatron + - name: Running Geo3k E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) + run: | + ray stop --force + TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/geo3k/test.parquet MAX_PROMPT_LENGTH=1024 MAX_RESPONSE_LENGTH=2048 MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False SKIP_SAVE_HF_MODEL=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 COMMON_TP=2 USE_DIST_CKPT=true DIST_CKPT_PATH=checkpoints/verl-test/qwen2.5-vl-3b-megatron bash tests/e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh new file mode 100644 index 00000000000..63791481657 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh @@ -0,0 +1,56 @@ +set -x +ENGINE=${1:-vllm} +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +HF_MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct +DIST_CKPT_PATH=${DIST_CKPT_PATH} + +# convert HF model to mcore (dist_ckpt) format +# python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH + +train_path=/data/geo3k/train.parquet +test_path=/data/geo3k/test.parquet + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_path" \ + data.val_files="$test_path" \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$HF_MODEL_PATH \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \ +
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ + actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/scripts/converter_hf_to_mcore.py b/scripts/converter_hf_to_mcore.py index 897b0f8a748..d363de26f3a 100644 --- a/scripts/converter_hf_to_mcore.py +++ b/scripts/converter_hf_to_mcore.py @@ -23,7 +23,7 @@ from megatron.core.dist_checkpointing.serialization import StrictHandling from megatron.core.models.gpt.gpt_model import ModelType from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig from verl.models.mcore import hf_to_mcore_config from verl.utils.megatron_utils import get_model @@ -146,22 +146,108 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config) model.output_layer.weight.copy_(hf_model.lm_head.weight) +def safe_copy( + src_tensor: torch.Tensor, + dst_tensor: torch.Tensor, + skip_dtype_assert: bool = False, +): + if not skip_dtype_assert: + if src_tensor.dtype != dst_tensor.dtype: + raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}") + assert src_tensor.shape == dst_tensor.shape + dst_tensor.data.copy_(src_tensor.data) + return src_tensor.numel() + + +@torch.inference_mode() +def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel, hf_config): + mgmodel = mgmodel.bfloat16() + hfmodel = hfmodel.bfloat16() + num_attention_heads = hf_config.num_attention_heads + num_query_groups = hf_config.num_key_value_heads + hidden_size = hf_config.hidden_size + head_dim = hidden_size // num_attention_heads + + # 1. 
vision model + hfvision = hfmodel.visual + mgvision = mgmodel.vision_model + vision_hidden_size = mgvision.config.hidden_size + vision_num_query_groups = mgvision.config.num_query_groups + vision_head_dim = vision_hidden_size // mgvision.config.num_attention_heads + copied_numel = 0 + safe_copy(hfvision.rotary_pos_emb.inv_freq, mgvision.rotary_pos_emb.inv_freq) + copied_numel += safe_copy(hfvision.patch_embed.proj.weight, mgvision.patch_embed.proj.weight) + for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers): + # norm1 --> linear_qkv.norm + copied_numel += safe_copy(hfblock.norm1.weight, mgblock.self_attention.linear_qkv.layer_norm_weight) + # norm2 --> mlp.linear_fc1.norm + copied_numel += safe_copy(hfblock.norm2.weight, mgblock.mlp.linear_fc1.layer_norm_weight) + # qkv --> self_attention.linear_qkv + converted_weight = hfblock.attn.qkv.weight.view(3, vision_num_query_groups, -1, vision_head_dim, vision_hidden_size).transpose(0, 1).flatten(1, 2).reshape(-1, vision_hidden_size).contiguous() + copied_numel += safe_copy(converted_weight, mgblock.self_attention.linear_qkv.weight) + converted_bias = hfblock.attn.qkv.bias.view(3, vision_num_query_groups, -1).transpose(0, 1).flatten(1, 2).view(-1).contiguous() + copied_numel += safe_copy(converted_bias, mgblock.self_attention.linear_qkv.bias) + # proj --> self_attention.linear_proj + copied_numel += safe_copy(hfblock.attn.proj.weight, mgblock.self_attention.linear_proj.weight) + copied_numel += safe_copy(hfblock.attn.proj.bias, mgblock.self_attention.linear_proj.bias) + # mlp --> mlp: gate + fc1_weight = torch.cat([hfblock.mlp.gate_proj.weight, hfblock.mlp.up_proj.weight]) + fc1_bias = torch.cat([hfblock.mlp.gate_proj.bias, hfblock.mlp.up_proj.bias]) + copied_numel += safe_copy(fc1_weight, mgblock.mlp.linear_fc1.weight) + copied_numel += safe_copy(fc1_bias, mgblock.mlp.linear_fc1.bias) + copied_numel += safe_copy(hfblock.mlp.down_proj.weight, mgblock.mlp.linear_fc2.weight) + copied_numel += safe_copy(hfblock.mlp.down_proj.bias, mgblock.mlp.linear_fc2.bias) + + # 2. vision projector + hfprojector = hfvision.merger + mgprojector = mgvision.projection + copied_numel += safe_copy(hfprojector.ln_q.weight, mgvision.decoder.final_layernorm.weight) + + copied_numel += safe_copy(hfprojector.mlp[0].weight, mgprojector.encoder.linear_fc1.weight) + copied_numel += safe_copy(hfprojector.mlp[0].bias, mgprojector.encoder.linear_fc1.bias) + copied_numel += safe_copy(hfprojector.mlp[2].weight, mgprojector.encoder.linear_fc2.weight) + copied_numel += safe_copy(hfprojector.mlp[2].bias, mgprojector.encoder.linear_fc2.bias) + n_params = sum([t.numel() for t in hfvision.state_dict().values()]) + assert n_params == copied_numel + # 3. 
llm [just Qwen2] + hfllm = hfmodel.model + mgllm = mgmodel.language_model + copied_numel = 0 + copied_numel += safe_copy(hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight) + for mglayer, hflayer in zip(mgllm.decoder.layers, hfllm.layers): + copied_numel += safe_copy(hflayer.input_layernorm.weight, mglayer.self_attention.linear_qkv.layer_norm_weight) + + q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) + k_proj_weight = hflayer.self_attn.k_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) + v_proj_weight = hflayer.self_attn.v_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) + qkv_proj = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=1).view(-1, hidden_size).contiguous() + copied_numel += safe_copy(qkv_proj, mglayer.self_attention.linear_qkv.weight) + + q_proj_bias = hflayer.self_attn.q_proj.bias.view(num_query_groups, -1) + k_proj_bias = hflayer.self_attn.k_proj.bias.view(num_query_groups, -1) + v_proj_bias = hflayer.self_attn.v_proj.bias.view(num_query_groups, -1) + qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=1).view(-1).contiguous() + copied_numel += safe_copy(qkv_bias, mglayer.self_attention.linear_qkv.bias) + copied_numel += safe_copy(hflayer.self_attn.o_proj.weight, mglayer.self_attention.linear_proj.weight) + + fc1_weight = torch.cat([hflayer.mlp.gate_proj.weight, hflayer.mlp.up_proj.weight]) + copied_numel += safe_copy(fc1_weight, mglayer.mlp.linear_fc1.weight) + + copied_numel += safe_copy(hflayer.mlp.down_proj.weight, mglayer.mlp.linear_fc2.weight) + copied_numel += safe_copy(hflayer.post_attention_layernorm.weight, mglayer.mlp.linear_fc1.layer_norm_weight) + + copied_numel += safe_copy(hfllm.norm.weight, mgllm.decoder.final_layernorm.weight) + if not hf_config.tie_word_embeddings: + safe_copy(hfmodel.lm_head.weight, mgllm.output_layer.weight) + + n_params = sum([t.numel() for t in hfllm.state_dict().values()]) + + assert n_params == copied_numel + + @torch.no_grad() def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_config, tfconfig): warnings.warn("MTP model is not supported yet", stacklevel=2) - - def safe_copy( - src_tensor: torch.Tensor, - dst_tensor: torch.Tensor, - skip_dtype_assert: bool = False, - ): - if not skip_dtype_assert: - if src_tensor.dtype != dst_tensor.dtype: - raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}") - assert src_tensor.shape == dst_tensor.shape - dst_tensor.data.copy_(src_tensor.data) - return src_tensor.numel() - model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight) for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers)): print(layer_idx) @@ -264,14 +350,20 @@ def megatron_model_provider(pre_process, post_process): with warnings.catch_warnings(): warnings.simplefilter("ignore") + from transformers import AutoModelForCausalLM, AutoModelForImageTextToText # init hf model - hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code) + if "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures: + hf_model = AutoModelForImageTextToText.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code) + else: + hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code) hf_state_dict = hf_model.state_dict() # load hf state 
dict to megatron model if "Qwen2MoeForCausalLM" in hf_config.architectures: convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config) + elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures: + convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hf_model, model[0].module, hf_config) elif "DeepseekV3ForCausalLM" in hf_config.architectures: convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig) elif "Qwen3MoeForCausalLM" in hf_config.architectures: diff --git a/verl/models/mcore/config_converter.py b/verl/models/mcore/config_converter.py index 2e21f3ed5e1..b7208c59848 100644 --- a/verl/models/mcore/config_converter.py +++ b/verl/models/mcore/config_converter.py @@ -19,6 +19,7 @@ import torch import torch.nn.functional as F +from megatron.core import parallel_state as mpu from megatron.core.transformer import MLATransformerConfig, TransformerConfig from transformers import PretrainedConfig @@ -36,7 +37,6 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype Returns: TransformerConfig with common parameters """ - from megatron.core import parallel_state as mpu # Common parallel state parameters overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 @@ -54,6 +54,7 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0), "kv_channels": getattr(hf_config, "head_dim", None), "layernorm_epsilon": hf_config.rms_norm_eps, + "add_bias_linear": True, # Activation and normalization "activation_func": F.silu, "normalization": "RMSNorm", @@ -276,7 +277,17 @@ def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, * def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig: # Qwen2_5_VLForConditionalGeneration - raise NotImplementedError("Qwen2_5_VLForConditionalGeneration is not supported yet") + + args = _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + add_bias_linear=False, + # qwen specific + add_qkv_bias=True, + mrope_section=hf_config.rope_scaling["mrope_section"], + **override_transformer_config_kwargs, + ) + return TransformerConfig(**args) def hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig: diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index 94b86462b33..9cd8bb6a7c8 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -29,6 +29,7 @@ def gptmodel_forward( pack_seqs=True, logits_processor=None, logits_processor_args: dict = None, + **kwargs, ): """Default forward pass for GPT models with optional sequence packing.""" pre_process = unwrap_model(model).pre_process @@ -60,6 +61,54 @@ def gptmodel_forward( return output -def gptmodel_forward_qwen2_5_vl(*args, **kwargs): - """Forward pass for Qwen2.5 VL model (not implemented).""" - raise NotImplementedError("VLM is not supported yet") +def gptmodel_forward_qwen2_5_vl( + model, + input_ids, + attention_mask, + position_ids, + sequence_parallel, + value_model=False, + pack_seqs=True, + multi_modal_inputs=None, + logits_processor=None, + logits_processor_args: dict = None, + **kwargs, +): + from megatron.core import parallel_state as mpu + + assert 
mpu.get_context_parallel_world_size() == 1, "qwen2_5_vl's context parallel is not accurate yet" + pre_process = unwrap_model(model).pre_process + post_process = unwrap_model(model).post_process + if pack_seqs: + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) + input_ids_rmpad = input_ids_rmpad.contiguous() + output_orig = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + pixel_values=multi_modal_inputs["pixel_values"].to(input_ids.device), + image_grid_thw=multi_modal_inputs["image_grid_thw"].to(input_ids.device), + ) + + if post_process and logits_processor is not None: + args = {k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] for k, v in logits_processor_args.items()} + output_dict = logits_processor(output_orig, **args) + output = {k: postprocess_packed_seqs(v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process) for k, v in output_dict.items()} + else: + output = postprocess_packed_seqs(output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process) + else: + batch_size, sequence_length = attention_mask.shape + new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process) + output = model( + input_ids=new_input_ids, + position_ids=new_position_ids, + attention_mask=new_attention_mask, + pixel_values=multi_modal_inputs["pixel_values"].to(input_ids.device), + image_grid_thw=multi_modal_inputs["image_grid_thw"].to(input_ids.device), + ) + output = recover_left_padding(output, new_attention_mask, attention_mask, sequence_length, post_process=post_process) + if value_model and post_process: + output = output[..., 0] + return output diff --git a/verl/models/mcore/model_initializer.py b/verl/models/mcore/model_initializer.py index 7edbfb482a2..e9afc05105d 100644 --- a/verl/models/mcore/model_initializer.py +++ b/verl/models/mcore/model_initializer.py @@ -193,4 +193,63 @@ class Qwen25VLModel(BaseModelInitializer): """Initializer for Qwen2.5 VL models.""" def get_transformer_layer_spec(self): - raise NotImplementedError("VLM is not supported yet") + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + return transformer_layer_spec + + def initialize( + self, + pre_process=None, + post_process=None, + share_embeddings_and_output_weights=False, + value=False, + **extra_kwargs, + ): + tfconfig = self.tfconfig + hf_config = self.hf_config + # Qwen2_5_VLForConditionalGeneration + from copy import deepcopy + + transformer_layer_spec = self.get_transformer_layer_spec() + + from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear + from megatron.core.models.gpt.moe_module_specs import MLPSubmodules + from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec + + from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config + + vision_transformer_config = get_vision_model_config(deepcopy(tfconfig)) + vision_transformer_config.pipeline_model_parallel_size = 1 + vision_transformer_config.first_pipeline_num_layers = None + + vision_projection_config = get_vision_projection_config(deepcopy(tfconfig), vision_transformer_config.hidden_size, 
spatial_merge_size=hf_config.vision_config.spatial_merge_size) + vision_projection_layer_spec = MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ) + vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() + + qwen25_vl_model = Qwen2_5VLModel( + language_transformer_config=tfconfig, + language_transformer_layer_spec=transformer_layer_spec, + language_vocab_size=hf_config.vocab_size, + language_max_sequence_length=hf_config.max_position_embeddings, + vision_transformer_config=vision_transformer_config, + vision_transformer_layer_spec=vision_transformer_layer_spec, + vision_projection_config=vision_projection_config, + vision_projection_layer_spec=vision_projection_layer_spec, + vision_projection_type="mlp", + language_rotary_base=hf_config.rope_theta, + pre_process=pre_process, + post_process=post_process, + add_decoder=True, + add_encoder=True, + parallel_output=True, + language_share_embeddings_and_output_weights=share_embeddings_and_output_weights, + ) + + if post_process and value: + from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer + + qwen25_vl_model.language_model.output_layer = LinearForLastLayer(input_size=tfconfig.hidden_size, output_size=1, config=tfconfig) + + return qwen25_vl_model diff --git a/verl/models/mcore/qwen2_5_vl/__init__.py b/verl/models/mcore/qwen2_5_vl/__init__.py new file mode 100644 index 00000000000..8842d0249e1 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .model import Qwen2_5VLModel +from .vision_config import get_vision_model_config, get_vision_projection_config + +__all__ = ["Qwen2_5VLModel", "get_vision_model_config", "get_vision_projection_config"] diff --git a/verl/models/mcore/qwen2_5_vl/attention.py b/verl/models/mcore/qwen2_5_vl/attention.py new file mode 100644 index 00000000000..4b9685f9bb4 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/attention.py @@ -0,0 +1,210 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from megatron.core.transformer.attention import * + +from .rope_utils import apply_rotary_pos_emb_absolute + + +class Qwen2_5VLSelfAttention(SelfAttention): + """ + Overrides the SelfAttention class, the difference is that qwen2_5_vl uses apply_rotary_pos_emb_absolute instead of apply_rotary_pos_emb + """ + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + key_value_states: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Perform a forward pass through the attention module. + + Args: + hidden_states (Tensor): Hidden states. + attention_mask (Tensor): Attention mask. + key_value_states (Optional[Tensor]): Key/value states (for cross attention). + inference_context (Optional[BaseInferenceContext]): Inference context that manages + KV cache. + rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary + embedding tensor(s). + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + attention_bias (Optional[Tensor]): Attention bias. + packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format. + sequence_len_offset (Optional[int]): Sequence length offset used for + inference CUDA graphs. + + Return: + (Tuple[Tensor, Tensor]) Attention output and bias. + + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + if inference_context and inference_context.is_dynamic_batching(): + assert flash_decode_and_prefill_kernel is not None, "Internal use only: install package `nvidia_chunked_flash_attn`." + + # hidden_states: [sq, b, h] + if self.config.flash_decode and not self.training and inference_context is not None: + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + + # This branch only runs in the decode phase of flash decoding and returns after the linear + # projection. This conditional is not used in the prefill phase or non-flash-decoding cases. 
+ if self.config.flash_decode and inference_context is not None and inference_context.is_decode_only() and not self.training and rotary_pos_cos is not None: + assert self.layer_number in inference_context.key_value_memory_dict + assert inference_context.sequence_len_offset is not None + inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] + output = self.flash_decode( + sequence_len_offset=sequence_len_offset, + query_layer=query, + key_layer=key, + value_layer=value, + inference_key_memory=inference_key_memory, + inference_value_memory=inference_value_memory, + rotary_cos=rotary_pos_cos, + rotary_sin=rotary_pos_sin, + ) + out = output.transpose(0, 1).contiguous() + context_layer = out.view(out.size(0), out.size(1), -1) + output, bias = self.linear_proj(context_layer) + return output, bias + + query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None and not self.config.flash_decode: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + if packed_seq_params.cu_seqlens_kv_padded is not None: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded + else: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + + if q_pos_emb is not None: + # TODO VIJAY: simplify + if inference_context is None or inference_context.is_static_batching(): + query = apply_rotary_pos_emb_absolute(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q) + else: + query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q) + if k_pos_emb is not None: + key = apply_rotary_pos_emb_absolute(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv) + + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + if inference_context is None or inference_context.is_static_batching(): + # Static batching attention kernel. + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + + else: + # Dynamic batching attention kernel. 
+ q, k, v = (query, key, value) + cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() + cu_kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths() + + core_attn_out = self.flash_decode_and_prefill(q, k, v, max_seqlen_q, max_seqlen_k, cu_query_lengths, cu_kv_lengths) + core_attn_out = core_attn_out.squeeze(0).unsqueeze(1) + core_attn_out = rearrange(core_attn_out, "s b h d -> s b (h d)") + + if packed_seq_params is not None and packed_seq_params.qkv_format == "thd": + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.linear_proj(core_attn_out) + + return output, bias diff --git a/verl/models/mcore/qwen2_5_vl/model.py b/verl/models/mcore/qwen2_5_vl/model.py new file mode 100644 index 00000000000..a9ed4886464 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/model.py @@ -0,0 +1,304 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +import torch +from megatron.core import InferenceParams, tensor_parallel +from megatron.core.models.gpt.gpt_model import GPTModel + +# from .transformer_config import Qwen2VLTransformerConfig +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + +from .attention import Qwen2_5VLSelfAttention +from .vision_model import Qwen2_5VisionModel + + +# Note: This is under development and may be missing features. +class Qwen2_5VLModel(MegatronModule): + """Qwen2.5VL multi-modal model. + + Args: + language_transformer_config (TransformerConfig): Transformer config for the language model. + language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the language model. + language_vocab_size (int): Language model vocabulary size. + language_max_sequence_length (int): Language model maximum sequence length. This is used for positional embedding. + vision_transformer_config (TransformerConfig): Transformer config for the vision model. + vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the vision model. + vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to language model inputs. + vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision projection. + vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP. + parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. 
This is typically True for training and False for inference. + language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings in the language model. Defaults to 1.0. + pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). Defaults to True. + post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline parallelism). Defaults to True. + add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the encoder + will live on only a subset of the pipeline stages (specifically, only the first stage). + add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the decoder + will live on only a subset of the pipeline stages (specifically, every stage after the first one). + img_h (int): The height of each image that the ViT will see. + img_w (int): The width of each image that the ViT will see. + patch_dim (int): The size of each patch side. + img_embedding_idx (int): Index in the language_embeddings tensor where image_embeddings should be inserted. Defaults to 0. + """ + + def __init__( + self, + language_transformer_config: TransformerConfig, + language_transformer_layer_spec: ModuleSpec, + language_vocab_size: int, + language_max_sequence_length: int, + vision_transformer_config: TransformerConfig, + vision_transformer_layer_spec: ModuleSpec, + vision_projection_config: TransformerConfig, + vision_projection_layer_spec: ModuleSpec, + vision_projection_type: str = "mlp", + parallel_output: bool = True, + language_rotary_percent: float = 1.0, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + language_rotary_base: int = 10000, + fp16_lm_cross_entropy: bool = False, + language_share_embeddings_and_output_weights: bool = False, + image_token_id: int = 151655, + video_token_id: int = 151656, + ) -> None: + super().__init__(config=language_transformer_config) + + # patch self_attention to use qwen2_5_vl attention + vision_transformer_layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention + for layer_spec in language_transformer_layer_spec.layer_specs: + layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention + + logging.getLogger(__name__).warning("Qwen2VL model is under development and may be missing features.") + + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + + self.encoder_hidden_state = None + self.vision_model = None + self.vision_projection = None + self.language_model = None + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size + + # This attribute is needed to check if an all-reduce is required + # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. 
+ self.share_embeddings_and_output_weights = False + if self.pre_process: + self.vision_model = Qwen2_5VisionModel(vision_transformer_config, vision_transformer_layer_spec, vision_projection_config, vision_projection_layer_spec, projection_type=vision_projection_type, pre_process=True, post_process=True) + + self.language_model = GPTModel( + config=language_transformer_config, + transformer_layer_spec=language_transformer_layer_spec, + vocab_size=language_vocab_size, + max_sequence_length=language_max_sequence_length, + parallel_output=parallel_output, + position_embedding_type="mrope", + rotary_percent=language_rotary_percent, + pre_process=self.pre_process, + post_process=self.post_process, + rotary_base=language_rotary_base, + fp16_lm_cross_entropy=fp16_lm_cross_entropy, + share_embeddings_and_output_weights=language_share_embeddings_and_output_weights, + scatter_embedding_sequence_parallel=False, + ) + + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + + def shared_embedding_or_output_weight(self): + """This is a convenience method to surface the language model's word embeddings, which is + necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" + if self.add_decoder: + return self.language_model.shared_embedding_or_output_weight() + return None + + def set_input_tensor(self, input_tensor) -> None: + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen2VL" + + if self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + + def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool): + """Freeze model modules. + + Make specific modules non-trainable by setting requires_grad to False for the module's parameters. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module. + freeze_vision_projection (bool): Freeze the vision projection module. + """ + modules = [] + if freeze_language_model and self.language_model is not None: + modules.append(self.language_model) + if freeze_vision_model and self.vision_model is not None: + modules.append(self.vision_model) + if freeze_vision_projection and self.vision_projection is not None: + modules.append(self.vision_projection) + + for module in modules: + for param in module.parameters(): + param.requires_grad = False + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + pixel_values: torch.Tensor = None, + pixel_values_videos: torch.Tensor = None, + image_grid_thw: torch.Tensor = None, + video_grid_thw: torch.Tensor = None, + ) -> torch.Tensor: + """Forward function of the Qwen2VL model. + + Args: + image_data (torch.Tensor): input image of shape [total_thw_size, n_features]. + input_ids (torch.Tensor): input text ids [batch, text_seq_len]. + position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. + attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len, combined_seq_len]. 
+ labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. + inference_params (InferenceParams): Inference-time parameters including KV cache. + + video_start_index: + 0 -- all video + len(video_seq) -- all image + others -- mixture + *_input_mask: should not be None in the first PP stage + Returns: + output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size]. + """ + video_start_index = 0 + vision_grid_thw = None + vision_data = None + if image_grid_thw is not None: + image_mask = input_ids == self.image_token_id + vision_grid_thw = image_grid_thw + vision_data = pixel_values + video_start_index = image_mask.sum().item() + if video_grid_thw is not None: + video_mask = input_ids == self.video_token_id + vision_grid_thw = torch.cat([vision_grid_thw, video_grid_thw], dim=0) + vision_data = torch.cat([vision_data, pixel_values_videos], dim=0) + video_start_index = image_mask.sum().item() + video_mask.sum().item() + use_inference_kv_cache = inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict + if use_inference_kv_cache: + raise NotImplementedError() + + if self.pre_process: + vision_embeds = None + if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: + vision_embeds = self.vision_model( + vision_data=vision_data, # If None, vision model should use intermediate outputs (EPP > 1) + grid_thw=vision_grid_thw, # should be provided in each EPP stage + ) + + # If running inference, the language model KV cache will be updated for image token positions. + # Here we store the image tokens sequence length, which can be used as an offset to the KV cache later. + if inference_params is not None: + raise NotImplementedError() + # inference_params.key_value_memory_dict["image_tokens_count"] = ( + # vision_embeddings.shape[0] + # ) + + # If running inference, we can skip image token computation if they were computed already earlier for this sample. + if use_inference_kv_cache: + language_embeddings: torch.Tensor = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + # NOTE: why not cat here? Are the combined embeddings unnecessary? 
+ combined_embeddings = language_embeddings + elif vision_embeds is not None: + if video_start_index == 0: + image_embeds = None + video_embeds = vision_embeds + elif video_start_index == vision_embeds.shape[0]: + image_embeds = vision_embeds + video_embeds = None + elif 0 < video_start_index < vision_embeds.shape[0]: + image_embeds = vision_embeds[:video_start_index] + video_embeds = vision_embeds[video_start_index:] + else: + raise ValueError(f"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got {video_start_index}") + + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + + if image_embeds is not None or video_embeds is not None: + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + if image_embeds is not None: + image_mask = (input_ids == self.image_token_id).contiguous() + if image_mask.sum() > 0: + combined_embeddings = combined_embeddings.clone() + combined_embeddings[image_mask] = image_embeds.to(dtype=combined_embeddings.dtype, device=combined_embeddings.device) + if video_embeds is not None: + video_mask = (input_ids == self.video_token_id).contiguous() + if video_mask.sum() > 0: + combined_embeddings = combined_embeddings.clone() + combined_embeddings[video_mask] = video_embeds.to(dtype=combined_embeddings.dtype, device=combined_embeddings.device) + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + else: + combined_embeddings = self.language_model.embedding( + input_ids=input_ids, + position_ids=None, # NOTE: disable + ) # [text_seq_len, b, h_language] + if self.config.sequence_parallel: + combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) + combined_embeddings = combined_embeddings.contiguous() + else: + combined_embeddings = None + from .rope_utils import get_rope_index + + position_ids, _ = get_rope_index(input_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, attention_mask=attention_mask) + + output = self.language_model( + input_ids=None, + position_ids=position_ids, # None in encoder + attention_mask=attention_mask, # None in encoder + decoder_input=combined_embeddings, # only not None in the first decoder PP stage + labels=labels, # only not None in the last decoder PP stage + # inference_params=inference_params, # currently always None + packed_seq_params=packed_seq_params, # currently always None + **(extra_block_kwargs or {}), + ) + + return output diff --git a/verl/models/mcore/qwen2_5_vl/rope_utils.py b/verl/models/mcore/qwen2_5_vl/rope_utils.py new file mode 100644 index 00000000000..1f74ae27fb3 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/rope_utils.py @@ -0,0 +1,249 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations + +import logging +from typing import Optional + +import torch +from megatron.core.models.common.embeddings.rope_utils import * +from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd +from torch import Tensor + +logger = logging.getLogger(__name__) + + +# Slightly modified from Qwen2VLForConditionalGeneration.get_rope_index +def get_rope_index( + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +): + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. + In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = 2 + tokens_per_second = 2 + image_token_id = 151655 + video_token_id = 151656 + vision_start_token_id = 151652 + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + 
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, 1, -1).expand(3, input_ids.shape[0], -1) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + +def apply_rotary_pos_emb_thd_absolute(t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False) -> Tensor: + """A baseline implementation of applying RoPE for `thd` format. + + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + + Returns: + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. + """ + return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1) + + +def apply_rotary_pos_emb_absolute( + t: Tensor, + freqs: Tensor, + config: TransformerConfig, + cu_seqlens: Optional[Tensor] = None, +): + """ + Reroute to the appropriate apply_rotary_pos_emb function depending on + bshd (conventional) / thd (packed seq) format + + In Qwen2-VL, the shape of freqs is (seq_length, bs, 1, 2 * dim) instead of [max_seqlen, 1, 1, 2 * dim] + """ + + if config.apply_rope_fusion: + if cu_seqlens is None: + # NOTE: TE backends do not support mRoPE in bshd format when bs > 1 + if freqs.shape[1] > 1: + return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + else: + return fused_apply_rotary_pos_emb(t, freqs) + else: + # NOTE: as expected, thd format can use bshd + return fused_apply_rotary_pos_emb(t[:, None], freqs).squeeze(1) + else: + if cu_seqlens is None: + return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + else: + return apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved) diff --git a/verl/models/mcore/qwen2_5_vl/vision_config.py b/verl/models/mcore/qwen2_5_vl/vision_config.py new file mode 100644 index 00000000000..bb36f10d088 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/vision_config.py @@ -0,0 +1,83 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import parallel_state +from megatron.core.transformer import TransformerConfig + + +def get_vision_model_config(config: TransformerConfig) -> TransformerConfig: + # Given a Transformer Config from decoder, build vision encoder config + # diff: out_hidden_size & intermediate_size + + # mlp: hidden_size -> intermediate_size -> embed_dim, silu + # NOTE: here we provide a workaround to solve the wrong layer amount when VPP of decoder is on + if config.num_layers in [28, 36]: + config.ffn_hidden_size = 3420 + else: + config.ffn_hidden_size = 3456 + + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + config.num_layers = 32 * parallel_state.get_virtual_pipeline_model_parallel_world_size() # depth + else: + config.num_layers = 32 # depth + config.num_attention_heads = 16 # num_heads + config.add_bias_linear = True # all nn.Linear has bias (MLP, attn) + config.add_qkv_bias = True # qkv_proj in attn has bias + config.hidden_size = 1280 # hidden_size + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + + # config.gated_linear_unit = False # no gated + # config.activation_func = quick_gelu # hidden_act + config.kv_channels = config.hidden_size // config.num_attention_heads + config.num_query_groups = config.num_attention_heads # no GQA + config.layernorm_zero_centered_gamma = False # False + config.apply_query_key_layer_scaling = False # factor=math.sqrt(head_dim) + config.bias_activation_fusion = False # no swiglu, set false + config.bias_dropout_fusion = False # no dropout, set false + config.attention_softmax_in_fp32 = True # use True + # config.normalization = 'LayerNorm' # use RMSNorm + config.seq_length = 1 + + config.tp_comm_overlap = False + config.sequence_parallel = False + config.temporal_patch_size = 2 + config.patch_size = 14 + config.in_channels = 3 + config.spatial_merge_size = 2 + + config.fullatt_block_indexes = [7, 15, 23, 31] + config._qwen2_5_vl_window_size = 112 + return config + + +def get_vision_projection_config(config: TransformerConfig, embed_dim: int, spatial_merge_size: int) -> TransformerConfig: + # merger: + # context_dim = hidden_size * merge_size**2 + # out_hidden_size = hidden_size + # context_dim -> context_dim -> out_hidden_size + # MLP: + # input_size -> ffn_hidden_size -> hidden_size + # spec: LN -> Linear(bias=True) -> GELU -> Linear(bias=True) + config.gated_linear_unit = False + config.bias_activation_fusion = False + config.add_bias_linear = True + config.ffn_hidden_size = embed_dim * (spatial_merge_size**2) + config.activation_func = torch.nn.functional.gelu + config.tp_comm_overlap = False + config.sequence_parallel = False + return config diff --git a/verl/models/mcore/qwen2_5_vl/vision_model.py b/verl/models/mcore/qwen2_5_vl/vision_model.py new file mode 100644 index 00000000000..5bf1c37bba6 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/vision_model.py @@ -0,0 +1,283 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from megatron.core import InferenceParams +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import nn +from torch.nn import functional as F + +from .vision_transformer_block import Qwen2_5VisionTransformerBlock as TransformerBlock + + +# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs.float() + + +class Qwen2_5VisionModel(VisionModule): + """Qwen2.5 ViT vision model. + + Args: + transformer_config (TransformerConfig): Transformer config. + transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers. + ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre. + add_class_token (bool, optional): Include a class token. Defaults to True. + class_token_len (int): Class token length. Defaults to 1 but 8 may be faster. + patch_dim (int): Image patch size. + img_h (int): Input image height. + img_w (int): Input image width. 
+ """ + + def __init__(self, transformer_config: TransformerConfig, transformer_layer_spec: ModuleSpec, projection_config: TransformerConfig, projection_layer_spec: ModuleSpec, projection_type: str = "mlp", pre_process: bool = True, post_process: bool = False) -> None: + super().__init__(config=transformer_config) + + self.spatial_merge_size = transformer_config.spatial_merge_size + + embed_dim = transformer_config.hidden_size + num_heads = transformer_config.num_attention_heads + temporal_patch_size = transformer_config.temporal_patch_size + patch_size = transformer_config.patch_size + in_channels = transformer_config.in_channels + + self.patch_size = transformer_config.patch_size + self.fullatt_block_indexes = transformer_config.fullatt_block_indexes + self.window_size = transformer_config._qwen2_5_vl_window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.max_sequence_length = transformer_config.seq_length + self.patch_embed = PatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + ) + + head_dim = embed_dim // num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.model_type = ModelType.encoder_or_decoder + self.pre_process = pre_process + self.post_process = post_process + + # Transformer layers. + # TODO: Follow-up changes will make pre and post_process configurable. They are needed for supporting pipeline parallelism. + # NOTE: a final layer norm and/or linear layer present in some implementations are omitted here. + self.decoder = TransformerBlock(config=transformer_config, spec=transformer_layer_spec, pre_process=self.pre_process, post_process=self.post_process, post_layer_norm=True) + + self.merge_hidden_size = projection_config.ffn_hidden_size + self.square_merge_size = self.merge_hidden_size // embed_dim + + if self.post_process: + self.projection = MultimodalProjector(projection_config, projection_layer_spec, projection_type, projection_config.ffn_hidden_size) + else: + self.projection = None + + self.input_tensor = None + + def set_input_tensor(self, input_tensor: torch.Tensor) -> None: + """Sets input tensor to the model. + + Args: + input_tensor (Tensor): Sets the input tensor for the model. 
+ """ + if self.pre_process: # always True + self.input_tensor = input_tensor + else: + raise NotImplementedError() + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0).to(grid_thw.device) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, + vision_data: Optional[torch.Tensor], + grid_thw: torch.Tensor, + inference_params: Optional[InferenceParams] = None, + extra_block_kwargs: dict = None, + ) -> torch.Tensor: + """Forward function of the Qwen2 Vision Model. This function passes the input tensors + through the embedding layer and then the transformer. + + Args: + x (torch.Tensor): input image/video data of shape [n_tokens, n_dims] + grid_thw (torch.Tensor): the size tensor indicates grid size of each image/frame + packed_seq_params (PackedSeqParams): parameters to build attention mask in the backend + + Returns: + x (torch.Tensor): output after final transformer block of shape [b, s, h]. 
+ """ + assert grid_thw is not None + assert self.input_tensor is None + assert inference_params is None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + vision_data = self.patch_embed(vision_data) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=vision_data.device, + dtype=torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = vision_data.size() + vision_data = vision_data.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + vision_data = vision_data[window_index, :, :] + vision_data = vision_data.reshape(seq_len, 1, -1) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, 1, 1, -1).repeat(1, 1, 1, 2) + + hidden_states = self.decoder( + hidden_states=vision_data, + attention_mask=None, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=self.build_packed_seq_params(None, cu_window_seqlens), + packed_seq_params_full=self.build_packed_seq_params(grid_thw), + fullatt_block_indexes=self.fullatt_block_indexes, + **(extra_block_kwargs or {}), + ) + + hidden_states = self.projection(hidden_states.view(-1, self.merge_hidden_size)) + reverse_indices = torch.argsort(window_index) + return hidden_states[reverse_indices, :] + + def build_packed_seq_params( + self, + grid_thw: Optional[torch.Tensor], + cu_seqlens: Optional[torch.Tensor] = None, + ) -> PackedSeqParams: + # NOTE: each frame is a sequence (rather than each grid) + if grid_thw is not None: + seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]) + cu_seqlens = seqlens.cumsum(dim=0) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).int() + else: + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + max_seqlen_q = seqlens.max() + return PackedSeqParams(cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, qkv_format="thd", max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_q) diff --git a/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py b/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py new file mode 100644 index 00000000000..883db6ce309 --- /dev/null +++ b/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py @@ -0,0 +1,241 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from megatron.core.transformer.transformer_block import * + + +class Qwen2_5VisionTransformerBlock(TransformerBlock): + def _checkpointed_forward(self, hidden_states: Tensor, attention_mask: Tensor, context: Tensor, context_mask: Tensor, rotary_pos_emb: Tensor, attention_bias: Tensor, packed_seq_params: PackedSeqParams, packed_seq_params_full: PackedSeqParams, fullatt_block_indexes): + """Forward method with activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb): + for index in range(start, end): + if index in fullatt_block_indexes: + packed_seq_params_now = packed_seq_params_full + else: + packed_seq_params_now = packed_seq_params + layer = self._get_layer(index) + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params_now, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + + if self.config.recompute_method == "uniform": + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + self.config.recompute_num_layers)) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == "block": + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. 
+ if self.config.fp8 and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if layer_idx >= recompute_skip_num_layers and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers: + hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)(hidden_states, attention_mask, context, context_mask, rotary_pos_emb) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Optional[Tensor], + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + packed_seq_params_full: PackedSeqParams = None, + fullatt_block_indexes=None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Update the inference parameters with the current batch size in case it is variable + if inference_context and not self.training: + inference_context.current_batch_size = hidden_states.size(1) + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. 
+ # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + + with rng_context, outer_fp8_context: + # Forward pass. + if self.config.recompute_granularity == "full" and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + packed_seq_params_full=packed_seq_params_full, + fullatt_block_indexes=fullatt_block_indexes, + ) + else: + for l_no, layer in enumerate(self.layers): + inner_fp8_context = get_fp8_context(self.config, layer.layer_number - 1) if use_inner_fp8_context else nullcontext() + if l_no in fullatt_block_indexes: + packed_seq_params_now = packed_seq_params_full + else: + packed_seq_params_now = packed_seq_params + with self.offload_context, inner_fp8_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params_now, + sequence_len_offset=sequence_len_offset, + ) + + if torch.is_grad_enabled() and self.config.cpu_offloading and self.group_prefetch_offload_commit_async is not None: + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. 
+ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + return hidden_states diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py index cc27cc2bc67..3997bfe064a 100644 --- a/verl/models/mcore/registry.py +++ b/verl/models/mcore/registry.py @@ -35,6 +35,7 @@ ) from .model_forward import ( gptmodel_forward, + gptmodel_forward_qwen2_5_vl, ) from .model_initializer import ( BaseModelInitializer, @@ -49,6 +50,7 @@ McoreToHFWeightConverterDense, McoreToHFWeightConverterDpskv3, McoreToHFWeightConverterMixtral, + McoreToHFWeightConverterQwen2_5_VL, McoreToHFWeightConverterQwen2Moe, McoreToHFWeightConverterQwen3Moe, ) @@ -77,6 +79,7 @@ class SupportedModel(Enum): SupportedModel.LLAMA4: hf_to_mcore_config_llama4, SupportedModel.QWEN3: hf_to_mcore_config_dense, SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe, + SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl, } # Registry for model initializers @@ -90,6 +93,7 @@ class SupportedModel(Enum): SupportedModel.LLAMA4: DenseModel, SupportedModel.QWEN3: DenseModel, SupportedModel.QWEN3_MOE: Qwen3MoEModel, + SupportedModel.QWEN2_5_VL: Qwen25VLModel, } # Registry for model forward functions @@ -103,6 +107,7 @@ class SupportedModel(Enum): SupportedModel.LLAMA4: gptmodel_forward, SupportedModel.QWEN3: gptmodel_forward, SupportedModel.QWEN3_MOE: gptmodel_forward, + SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl, SupportedModel.DEEPSEEK_V3: gptmodel_forward, } @@ -115,6 +120,7 @@ class SupportedModel(Enum): SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3, SupportedModel.QWEN3: McoreToHFWeightConverterDense, SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe, + SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL, } diff --git a/verl/models/mcore/saver.py b/verl/models/mcore/saver.py index 153c49ff67c..5d1037681f4 100644 --- a/verl/models/mcore/saver.py +++ b/verl/models/mcore/saver.py @@ -467,6 +467,10 @@ def merge_megatron_ckpt_gptmodel_qwen_moe(wrapped_models, config, dtype, is_valu raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented") +def merge_megatron_ckpt_gptmodel_qwen2_5_vl(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen2_5_vl is not implemented") + + def merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): raise NotImplementedError("merge_megatron_ckpt_gptmodel_dpskv3 is not implemented") diff --git a/verl/models/mcore/weight_converter.py b/verl/models/mcore/weight_converter.py index cd620f2a897..53a4eee250a 100644 --- a/verl/models/mcore/weight_converter.py +++ b/verl/models/mcore/weight_converter.py @@ -147,6 +147,113 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis return convert_names, params +class McoreToHFWeightConverterQwen2_5_VL(McoreToHFWeightConverterDense): + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + direct_name_mapping = { + "language_model.embedding.word_embeddings.weight": "model.embed_tokens.weight", + "language_model.decoder.final_layernorm.weight": "model.norm.weight", + "language_model.output_layer.weight": "lm_head.weight", + "vision_model.patch_embed.proj.weight": "visual.patch_embed.proj.weight", + "vision_model.decoder.final_layernorm.weight": "visual.merger.ln_q.weight", + 
"vision_model.projection.encoder.linear_fc1.weight": "visual.merger.mlp.0.weight", + "vision_model.projection.encoder.linear_fc1.bias": "visual.merger.mlp.0.bias", + "vision_model.projection.encoder.linear_fc2.weight": "visual.merger.mlp.2.weight", + "vision_model.projection.encoder.linear_fc2.bias": "visual.merger.mlp.2.bias", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params_one_group[0]] + + if "self_attention" in name: + return self._convert_attention_param(name, params_one_group) + elif "mlp" in name: + return self._convert_mlp_param(name, params_one_group) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + model_type, _, _, layer_number = name.split(".")[:4] + + convert_names = [] + if model_type == "language_model": + name_map_after_layer = { + "self_attention.linear_qkv.bias": ["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], + "self_attention.linear_qkv.weight": ["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], + "self_attention.linear_proj.weight": "self_attn.o_proj.weight", + "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"model.layers.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"model.layers.{layer_number}.{mapped_name}") + elif model_type == "vision_model": + name_map_after_layer = {"self_attention.linear_proj.weight": "attn.proj.weight", "self_attention.linear_proj.bias": "attn.proj.bias", "self_attention.linear_qkv.layer_norm_weight": "norm1.weight"} + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer, None) + if mapped_name is None: + assert "linear_qkv" in name_after_layer + assert len(params) == 3 + new_param = torch.cat(params, dim=0) + params = [new_param] + if "bias" in name_after_layer: + convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.bias") + else: + convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.weight") + else: + assert len(params) == 1 + convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") + else: + raise NotImplementedError(f"Unsupported model type: {model_type}") + return convert_names, params + + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + model_type, _, _, layer_number = name.split(".")[:4] + + convert_names = [] + if model_type == "language_model": + name_map_after_layer = { + "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], + "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"], + "mlp.linear_fc2.weight": "mlp.down_proj.weight", + "mlp.linear_fc2.bias": "mlp.down_proj.bias", + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"model.layers.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"model.layers.{layer_number}.{mapped_name}") + 
+ elif model_type == "vision_model": + name_map_after_layer = { + "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], + "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"], + "mlp.linear_fc2.weight": "mlp.down_proj.weight", + "mlp.linear_fc2.bias": "mlp.down_proj.bias", + "mlp.linear_fc1.layer_norm_weight": "norm2.weight", + } + name_after_layer = ".".join(name.split(".")[-3:]) + mapped_name = name_map_after_layer.get(name_after_layer) + if isinstance(mapped_name, list): + assert len(params) == len(mapped_name) + for one in mapped_name: + convert_names.append(f"visual.blocks.{layer_number}.{one}") + else: + assert len(params) == 1 + convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") + else: + raise NotImplementedError(f"Unsupported model type: {model_type}") + return convert_names, params + + class McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase): def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: # mcore diff --git a/verl/models/weight_loader_registry.py b/verl/models/weight_loader_registry.py index 8f1e208534d..37dc70f1f7a 100644 --- a/verl/models/weight_loader_registry.py +++ b/verl/models/weight_loader_registry.py @@ -27,13 +27,14 @@ def get_weight_loader(arch: str): def get_weight_saver(arch: str): - from verl.models.mcore.saver import merge_megatron_ckpt_gptmodel, merge_megatron_ckpt_gptmodel_dpskv3, merge_megatron_ckpt_gptmodel_mixtral, merge_megatron_ckpt_gptmodel_qwen_moe + from verl.models.mcore.saver import merge_megatron_ckpt_gptmodel, merge_megatron_ckpt_gptmodel_dpskv3, merge_megatron_ckpt_gptmodel_mixtral, merge_megatron_ckpt_gptmodel_qwen2_5_vl, merge_megatron_ckpt_gptmodel_qwen_moe _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = { "LlamaForCausalLM": merge_megatron_ckpt_gptmodel, "Qwen2ForCausalLM": merge_megatron_ckpt_gptmodel, "MixtralForCausalLM": merge_megatron_ckpt_gptmodel_mixtral, "Qwen2MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, + "Qwen2_5_VLForConditionalGeneration": merge_megatron_ckpt_gptmodel_qwen2_5_vl, "DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3, "Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel, "Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index ad3dfe5100c..3b6b42e69df 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -679,7 +679,7 @@ def broadcast_str_from_megatron_pp(obj: Any): return obj_output[0] -def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, model_config, convert_qkv_gate_up_by_simple_split=False): +def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, model_config, hf_config=None, convert_qkv_gate_up_by_simple_split=False): """ name: name of the parameter train_params: training parameters @@ -691,21 +691,31 @@ def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, m """ from megatron.core import mpu + train_tp_size = mpu.get_tensor_model_parallel_world_size() if layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name: # if the tensor is qkv, for each param on tp, split into q, k, v # concat q, k, v separately. 
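+ # For the Qwen2.5-VL vision tower ("vision_model." params) the head counts are taken from
+ # hf_config.vision_config below; the ViT has no GQA, so num_q_per_kv ends up as 1 there.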
q_lst = [] k_lst = [] v_lst = [] - assert model_config.num_attention_heads % model_config.num_key_value_heads == 0 - num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads + num_attention_heads = model_config.num_attention_heads + num_key_value_heads = model_config.num_key_value_heads + if "vision_model" in name: + num_attention_heads = hf_config.vision_config.num_heads + num_key_value_heads = hf_config.vision_config.num_heads + assert num_attention_heads % num_key_value_heads == 0 + num_q_per_kv = num_attention_heads // num_key_value_heads assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] for infer_param in infer_params: - num_query_groups_per_partition = model_config.num_key_value_heads // mpu.get_tensor_model_parallel_world_size() + num_query_groups_per_partition = num_key_value_heads // train_tp_size for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, kv_size_per_tp // num_query_groups_per_partition, kv_size_per_tp // num_query_groups_per_partition] + split_size = [ + kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, + kv_size_per_tp // num_query_groups_per_partition, + kv_size_per_tp // num_query_groups_per_partition, + ] q, k, v = chunk.split(split_size) q_lst.append(q) k_lst.append(k) @@ -715,7 +725,7 @@ def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, m v = torch.cat(v_lst, dim=0) infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v] - elif layer_name_mapping.get("gate_proj_layer_name") in name: + elif layer_name_mapping.get("gate_proj_layer_name") in name and "layer_norm" not in name and "vision_model.projection" not in name: # if the tensor is gate and proj gate_lst = [] up_lst = [] @@ -828,7 +838,7 @@ def tensor_generator(): else: params = [param] - merge_params = default_tp_concat_fn(layer_name_mapping, name, broad_pp_tensor, params, model_config, convert_qkv_gate_up_by_simple_split) + merge_params = default_tp_concat_fn(layer_name_mapping, name, broad_pp_tensor, params, model_config, weight_converter.hf_config, convert_qkv_gate_up_by_simple_split) if not isinstance(merge_params, list): merge_params = [merge_params] converted_names, converted_params = weight_converter.convert_param(name, merge_params) @@ -844,7 +854,7 @@ def tensor_generator(): else: infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)] torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group()) - infer_params = default_tp_concat_fn(layer_name_mapping, cur_name, broad_pp_tensor, infer_params, model_config, convert_qkv_gate_up_by_simple_split) + infer_params = default_tp_concat_fn(layer_name_mapping, cur_name, broad_pp_tensor, infer_params, model_config, weight_converter.hf_config, convert_qkv_gate_up_by_simple_split) else: infer_params = broad_pp_tensor diff --git a/verl/utils/model.py b/verl/utils/model.py index 24482d06213..b00ef00e039 100644 --- a/verl/utils/model.py +++ b/verl/utils/model.py @@ -422,15 +422,13 @@ def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batc def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False): 
from megatron.core import dist_checkpointing from megatron.core.dist_checkpointing.serialization import StrictHandling - from megatron.core.models.gpt.gpt_model import GPTModel + + from verl.utils.megatron_utils import unwrap_model # strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED strict = StrictHandling.ASSUME_OK_UNEXPECTED for model in parallel_model: - if isinstance(model.module, GPTModel): - ssd = model.module.sharded_state_dict() - else: - ssd = model.module.module.sharded_state_dict() + ssd = unwrap_model(model).sharded_state_dict() if is_value_model: for k in list(ssd.keys()): if "output_layer" in k: diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 5027d594fd3..af40d78e1d8 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -270,7 +270,11 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"] if self.config.use_kl_loss: select_keys.append("ref_log_prob") - data = data.select(batch_keys=select_keys) + self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + if self.has_multi_modal_inputs: + data = data.select(select_keys, ["multi_modal_inputs"]) + else: + data = data.select(batch_keys=select_keys) return data.make_iterator( mini_batch_size=self.config.ppo_mini_batch_size, epochs=self.config.ppo_epochs, @@ -290,6 +294,10 @@ def forward_backward_batch(self, data: DataProto, forward_only=False, post_proce broadcast_dict_tensor(mini_batch.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group()) # split into micro-batches mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) + self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() + if self.has_multi_modal_inputs: + mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch["multi_modal_inputs"] + mini_batch.batch["multi_modal_inputs_idx"] = torch.Tensor(list(range(len(mini_batch.non_tensor_batch["multi_modal_inputs"])))).to(torch.int64) indices = None if use_dynamic_bsz: @@ -329,7 +337,7 @@ def loss_func(output, data, meta_info): responses = data["responses"] response_length = responses.size(1) - attention_mask = data["attention_mask"] + attention_mask = data["attention_mask"].to(bool) response_mask = attention_mask[:, -response_length:] loss_agg_mode = self.config.loss_agg_mode @@ -395,9 +403,13 @@ def loss_func(output, data, meta_info): def forward_step(batch_iter, model): batch = next(batch_iter) input_ids = batch["input_ids"] - attention_mask = batch["attention_mask"] + attention_mask = batch["attention_mask"].to(bool) position_ids = batch["position_ids"] + multi_modal_inputs = {} + if "multi_modal_inputs" in batch: + for key in batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat([batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0) responses = batch["responses"] response_length = responses.size(1) label = copy.deepcopy(position_ids) @@ -427,7 +439,7 @@ def logits_processor(logits, label, label_mask): forward_fn = get_mcore_forward_fn(self.hf_config) - output = forward_fn(model, input_ids, attention_mask, position_ids, sequence_parallel=self.tf_config.sequence_parallel, logits_processor=logits_processor, logits_processor_args=logits_processor_args) + output = 
forward_fn(model, input_ids, attention_mask, position_ids, sequence_parallel=self.tf_config.sequence_parallel, multi_modal_inputs=multi_modal_inputs, logits_processor=logits_processor, logits_processor_args=logits_processor_args) if forward_only: meta_info = None @@ -466,6 +478,12 @@ def logits_processor(logits, label, label_mask): forward_only=forward_only, ) # loss_reduces contains the stats returned from loss_func + + if self.has_multi_modal_inputs: + data.batch.pop("multi_modal_inputs") + data.batch.pop("multi_modal_inputs_idx") + data.non_tensor_batch.pop("multi_modal_inputs") + losses_reduced = {"output": losses_reduced} if use_dynamic_bsz: losses_reduced["indices"] = indices diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index faf47c277d0..33187c475e8 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -214,7 +214,7 @@ def _build_rollout(self, trust_remote_code=False): layer_name_mapping = { "qkv_layer_name": "self_attention.linear_qkv.", - "gate_proj_layer_name": "linear_fc1.weight", + "gate_proj_layer_name": "linear_fc1.", } if self.config.rollout.name == "vllm": from torch.distributed.device_mesh import init_device_mesh diff --git a/verl/workers/reward_model/megatron/reward_model.py b/verl/workers/reward_model/megatron/reward_model.py index 3f64e5b24f8..a52c11a583d 100644 --- a/verl/workers/reward_model/megatron/reward_model.py +++ b/verl/workers/reward_model/megatron/reward_model.py @@ -214,6 +214,11 @@ def forward_batch(self, data: DataProto, use_dynamic_bsz=False, micro_batch_size mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) + self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() + if self.has_multi_modal_inputs: + mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch["multi_modal_inputs"] + mini_batch.batch["multi_modal_inputs_idx"] = torch.Tensor(list(range(len(mini_batch.non_tensor_batch["multi_modal_inputs"])))).to(torch.int64) + indices = None if use_dynamic_bsz: assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" @@ -247,6 +252,11 @@ def forward_step(batch_iter, model): forward_fn = get_mcore_forward_fn(self.hf_config) + multi_modal_inputs = {} + if "multi_modal_inputs" in batch: + for key in batch["multi_modal_inputs"][0].keys(): + multi_modal_inputs[key] = torch.cat([batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0) + output = forward_fn( model, input_ids, @@ -254,6 +264,7 @@ def forward_step(batch_iter, model): position_ids, sequence_parallel=self.tf_config.sequence_parallel, value_model=True, + multi_modal_inputs=multi_modal_inputs, ) return output, loss_func @@ -283,6 +294,11 @@ def forward_step(batch_iter, model): micro_batch_size=1, # in use for pp = 1 forward_only=True, ) + + if self.has_multi_modal_inputs: + data.batch.pop("multi_modal_inputs") + data.batch.pop("multi_modal_inputs_idx") + data.non_tensor_batch.pop("multi_modal_inputs") # loss_reduces contains the stats returned from loss_func losses_reduced = {"output": losses_reduced} if use_dynamic_bsz: