Merged

Commits (37)
592d619
offline weight converter
ISEEKYAN Apr 24, 2025
ff9d367
forward with geo3k.
ISEEKYAN Apr 24, 2025
87f508e
scrips
ISEEKYAN Apr 24, 2025
959ce5a
qwen25vl mcore weight converter
ISEEKYAN Apr 26, 2025
c7f177a
fix qwen2.5vl model converter
ISEEKYAN Apr 26, 2025
5b6eaad
sequence packing
ISEEKYAN Apr 27, 2025
ca05a00
tmp
ISEEKYAN Apr 27, 2025
d53de0e
clean
ISEEKYAN Apr 27, 2025
068d82f
flash
ISEEKYAN Apr 27, 2025
1d3a39b
qwen pp
ISEEKYAN May 8, 2025
b89c6b7
fix PP
ISEEKYAN May 8, 2025
406d80e
support 7b and more qwen25vl models
ISEEKYAN May 9, 2025
dd16c7f
enable sp
ISEEKYAN May 9, 2025
dc0205c
align some configs
ISEEKYAN May 12, 2025
bf541b3
Merge commit '867d3024bf7af6aee2cd785cfd573aec561f212d' into mcore_qw…
ISEEKYAN May 27, 2025
75dd567
Merge commit '04acd09d65900521e8019adefd10308220cb7ee2' into mcore_qw…
ISEEKYAN May 27, 2025
d1f5320
Merge commit '02862103babdd0df4fe70d9b236926fcc02bac27' into mcore_qw…
ISEEKYAN May 27, 2025
95ebb55
Merge commit '7d26d7359e17937d2590093f51b3e9de2e5e131d' into mcore_qw…
ISEEKYAN May 27, 2025
cffa9c1
fix
ISEEKYAN May 27, 2025
a2a6dba
Merge branch 'main' into mcore_qwen25vl_tmp_update0527_v6
ISEEKYAN May 27, 2025
8d6ac6c
clean the implementation of qwen25vl
ISEEKYAN May 31, 2025
ecc7c9f
clean
ISEEKYAN May 31, 2025
4ffe705
add copyright
ISEEKYAN May 31, 2025
d3b829d
clean
ISEEKYAN May 31, 2025
91f9692
Merge branch 'main' into mcore_qwen25vl_clean
ISEEKYAN Jun 3, 2025
b50e6be
add example
ISEEKYAN Jun 3, 2025
ae83ce5
add ci
ISEEKYAN Jun 3, 2025
96ad63a
fix ci
ISEEKYAN Jun 4, 2025
eeba24d
small fix
ISEEKYAN Jun 4, 2025
94871b1
change the way converter_hf_to_mcore loading model
ISEEKYAN Jun 4, 2025
e4527a0
Merge branch 'main' into mcore_qwen25vl
ISEEKYAN Jun 5, 2025
aadd5dc
Merge branch 'main' into mcore_qwen25vl
ISEEKYAN Jun 5, 2025
b23a6fa
Merge branch 'main' into mcore_qwen25vl
ISEEKYAN Jun 6, 2025
21dcbd8
fix vpp for ci
ISEEKYAN Jun 7, 2025
c9820a8
Merge branch 'main' into mcore_qwen25vl
ISEEKYAN Jun 9, 2025
c0b61e1
fix ci
ISEEKYAN Jun 9, 2025
c89c54f
fix pipeline parallel
ISEEKYAN Jun 9, 2025
33 changes: 33 additions & 0 deletions .github/workflows/e2e_ppo_trainer_megatron.yml
@@ -26,6 +26,7 @@ on:
# Entrypoints
- ".github/workflows/e2e_ppo_trainer_megatron.yml"
- "examples/data_preprocess/gsm8k.py"
- "examples/data_preprocess/geo3k.py"
- "tests/e2e/run_ppo_trainer_megatron.sh"
- "verl/trainer/main_ppo.py"
- "verl/trainer/config/ppo_megatron_trainer.yaml"
@@ -230,4 +231,36 @@ jobs:
- name: clean up
run: |
rm -rf checkpoints
e2e_ppo_trainer_megatron-qwen2_5vl-3b:
runs-on: [L20x8]
timeout-minutes: 60 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install --no-deps -e .[test]
- name: Prepare Geo3k dataset
run: |
python3 examples/data_preprocess/geo3k.py
- name: Prepare dist_ckpt of Qwen2.5-VL-3B (only dist_ckpt loading is supported)
run: |
python3 scripts/converter_hf_to_mcore.py --hf_model_path ${HOME}/models/Qwen/Qwen2.5-VL-3B-Instruct --output_path checkpoints/verl-test/qwen2.5-vl-3b-megatron
- name: Running Geo3k E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen2.5-VL)
run: |
ray stop --force
TRAIN_FILES=${HOME}/data/geo3k/train.parquet VAL_FILES=${HOME}/data/geo3k/test.parquet MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False SKIP_SAVE_HF_MODEL=1 COMMON_PP=4 COMMON_CP=1 COMMON_TP=2 bash tests/e2e/run_ppo_trainer_megatron.sh actor_rollout_ref.actor.megatron.use_dist_checkpointing=true actor_rollout_ref.ref.megatron.use_dist_checkpointing=true actor_rollout_ref.actor.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-vl-3b-megatron actor_rollout_ref.ref.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-vl-3b-megatron
- name: clean up
run: |
rm -rf checkpoints

56 changes: 56 additions & 0 deletions examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh
@@ -0,0 +1,56 @@
set -x
ENGINE=${1:-vllm}
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping

HF_MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct
DIST_CKPT_PATH=${DIST_CKPT_PATH}

# convert the HF model to a Megatron dist_ckpt
# python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH

train_path=/data/geo3k/train.parquet
test_path=/data/geo3k/test.parquet

python3 -m verl.trainer.main_ppo --config-path=config \
--config-name='ppo_megatron_trainer.yaml'\
algorithm.adv_estimator=grpo \
data.train_files="$train_path" \
data.val_files="$test_path" \
data.train_batch_size=512 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$HF_MODEL_PATH \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.01 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=$ENGINE \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \
actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_geo3k' \
trainer.experiment_name='qwen2_5_vl_7b_megatron' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
125 changes: 109 additions & 16 deletions scripts/converter_hf_to_mcore.py
@@ -23,7 +23,7 @@
from megatron.core.dist_checkpointing.serialization import StrictHandling
from megatron.core.models.gpt.gpt_model import ModelType
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import AutoConfig

from verl.models.mcore import hf_to_mcore_config
from verl.utils.megatron_utils import get_model
@@ -146,22 +146,108 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
model.output_layer.weight.copy_(hf_model.lm_head.weight)


def safe_copy(
src_tensor: torch.Tensor,
dst_tensor: torch.Tensor,
skip_dtype_assert: bool = False,
):
if not skip_dtype_assert:
if src_tensor.dtype != dst_tensor.dtype:
raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}")
assert src_tensor.shape == dst_tensor.shape
dst_tensor.data.copy_(src_tensor.data)
return src_tensor.numel()


@torch.inference_mode()
def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel, hf_config):
mgmodel = mgmodel.bfloat16()
hfmodel = hfmodel.bfloat16()
num_attention_heads = hf_config.num_attention_heads
num_query_groups = hf_config.num_key_value_heads
hidden_size = hf_config.hidden_size
head_dim = hidden_size // num_attention_heads

# 1. vision model
hfvision = hfmodel.visual
mgvision = mgmodel.vision_model
vision_hidden_size = mgvision.config.hidden_size
vision_num_query_groups = mgvision.config.num_query_groups
vision_head_dim = vision_hidden_size // mgvision.config.num_attention_heads
copied_numel = 0
safe_copy(hfvision.rotary_pos_emb.inv_freq, mgvision.rotary_pos_emb.inv_freq)
copied_numel += safe_copy(hfvision.patch_embed.proj.weight, mgvision.patch_embed.proj.weight)
for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers):
# norm1 --> linear_qkv.norm
copied_numel += safe_copy(hfblock.norm1.weight, mgblock.self_attention.linear_qkv.layer_norm_weight)
# norm2 --> mlp.linear_fc1.norm
copied_numel += safe_copy(hfblock.norm2.weight, mgblock.mlp.linear_fc1.layer_norm_weight)
# qkv --> self_attention.linear_qkv
converted_weight = hfblock.attn.qkv.weight.view(3, vision_num_query_groups, -1, vision_head_dim, vision_hidden_size).transpose(0, 1).flatten(1, 2).reshape(-1, vision_hidden_size).contiguous()
copied_numel += safe_copy(converted_weight, mgblock.self_attention.linear_qkv.weight)
converted_bias = hfblock.attn.qkv.bias.view(3, vision_num_query_groups, -1).transpose(0, 1).flatten(1, 2).view(-1).contiguous()
copied_numel += safe_copy(converted_bias, mgblock.self_attention.linear_qkv.bias)
# proj --> self_attention.linear_proj
copied_numel += safe_copy(hfblock.attn.proj.weight, mgblock.self_attention.linear_proj.weight)
copied_numel += safe_copy(hfblock.attn.proj.bias, mgblock.self_attention.linear_proj.bias)
# mlp --> mlp: gate
fc1_weight = torch.cat([hfblock.mlp.gate_proj.weight, hfblock.mlp.up_proj.weight])
fc1_bias = torch.cat([hfblock.mlp.gate_proj.bias, hfblock.mlp.up_proj.bias])
copied_numel += safe_copy(fc1_weight, mgblock.mlp.linear_fc1.weight)
copied_numel += safe_copy(fc1_bias, mgblock.mlp.linear_fc1.bias)
copied_numel += safe_copy(hfblock.mlp.down_proj.weight, mgblock.mlp.linear_fc2.weight)
copied_numel += safe_copy(hfblock.mlp.down_proj.bias, mgblock.mlp.linear_fc2.bias)

# 2. vision projector
hfprojector = hfvision.merger
mgprojector = mgvision.projection
copied_numel += safe_copy(hfprojector.ln_q.weight, mgvision.decoder.final_layernorm.weight)

copied_numel += safe_copy(hfprojector.mlp[0].weight, mgprojector.encoder.linear_fc1.weight)
copied_numel += safe_copy(hfprojector.mlp[0].bias, mgprojector.encoder.linear_fc1.bias)
copied_numel += safe_copy(hfprojector.mlp[2].weight, mgprojector.encoder.linear_fc2.weight)
copied_numel += safe_copy(hfprojector.mlp[2].bias, mgprojector.encoder.linear_fc2.bias)
n_params = sum([t.numel() for t in hfvision.state_dict().values()])
assert n_params == copied_numel
# 3. llm [just Qwen2]
hfllm = hfmodel.model
mgllm = mgmodel.language_model
copied_numel = 0
copied_numel += safe_copy(hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight)
for mglayer, hflayer in zip(mgllm.decoder.layers, hfllm.layers):
copied_numel += safe_copy(hflayer.input_layernorm.weight, mglayer.self_attention.linear_qkv.layer_norm_weight)

q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)
k_proj_weight = hflayer.self_attn.k_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)
v_proj_weight = hflayer.self_attn.v_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)
qkv_proj = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=1).view(-1, hidden_size).contiguous()
copied_numel += safe_copy(qkv_proj, mglayer.self_attention.linear_qkv.weight)

q_proj_bias = hflayer.self_attn.q_proj.bias.view(num_query_groups, -1)
k_proj_bias = hflayer.self_attn.k_proj.bias.view(num_query_groups, -1)
v_proj_bias = hflayer.self_attn.v_proj.bias.view(num_query_groups, -1)
qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=1).view(-1).contiguous()
copied_numel += safe_copy(qkv_bias, mglayer.self_attention.linear_qkv.bias)
copied_numel += safe_copy(hflayer.self_attn.o_proj.weight, mglayer.self_attention.linear_proj.weight)

fc1_weight = torch.cat([hflayer.mlp.gate_proj.weight, hflayer.mlp.up_proj.weight])
copied_numel += safe_copy(fc1_weight, mglayer.mlp.linear_fc1.weight)

copied_numel += safe_copy(hflayer.mlp.down_proj.weight, mglayer.mlp.linear_fc2.weight)
copied_numel += safe_copy(hflayer.post_attention_layernorm.weight, mglayer.mlp.linear_fc1.layer_norm_weight)

copied_numel += safe_copy(hfllm.norm.weight, mgllm.decoder.final_layernorm.weight)
if not hf_config.tie_word_embeddings:
safe_copy(hfmodel.lm_head.weight, mgllm.output_layer.weight)

n_params = sum([t.numel() for t in hfllm.state_dict().values()])

assert n_params == copied_numel


@torch.no_grad()
def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_config, tfconfig):
warnings.warn("MTP model is not supported yet", stacklevel=2)

def safe_copy(
src_tensor: torch.Tensor,
dst_tensor: torch.Tensor,
skip_dtype_assert: bool = False,
):
if not skip_dtype_assert:
if src_tensor.dtype != dst_tensor.dtype:
raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}")
assert src_tensor.shape == dst_tensor.shape
dst_tensor.data.copy_(src_tensor.data)
return src_tensor.numel()

model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight)
for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers)):
print(layer_idx)
@@ -264,14 +350,21 @@ def megatron_model_provider(pre_process, post_process):

with warnings.catch_warnings():
warnings.simplefilter("ignore")
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText

# init hf model
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code)
hf_state_dict = hf_model.state_dict()
try:
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code)
hf_state_dict = hf_model.state_dict()
except Exception:
hf_model = AutoModelForImageTextToText.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code)
hf_state_dict = hf_model.state_dict()

# load hf state dict to megatron model
if "Qwen2MoeForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hf_model, model[0].module, hf_config)
elif "DeepseekV3ForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig)
elif "Qwen3MoeForCausalLM" in hf_config.architectures:
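The trickiest step in the Qwen2.5-VL converter above is packing the separate HF q/k/v projections into Megatron's fused linear_qkv layout, which interleaves heads per query group. A minimal standalone sketch of that reshape on toy shapes (hypothetical sizes; only PyTorch is assumed):

import torch

# Hypothetical toy sizes chosen only to illustrate the fused-QKV layout; the real values
# come from the HF config (num_attention_heads, num_key_value_heads, hidden_size).
hidden_size = 16
num_attention_heads = 4
num_query_groups = 2                                        # GQA: 2 KV heads
head_dim = hidden_size // num_attention_heads               # 4
heads_per_group = num_attention_heads // num_query_groups   # 2

q = torch.randn(num_attention_heads * head_dim, hidden_size)   # like HF q_proj.weight, [16, 16]
k = torch.randn(num_query_groups * head_dim, hidden_size)      # like HF k_proj.weight, [8, 16]
v = torch.randn(num_query_groups * head_dim, hidden_size)      # like HF v_proj.weight, [8, 16]

# Per query group: stack that group's q heads, then its single k head, then its v head.
qg = q.view(num_query_groups, heads_per_group, head_dim, hidden_size)
kg = k.view(num_query_groups, 1, head_dim, hidden_size)
vg = v.view(num_query_groups, 1, head_dim, hidden_size)
fused = torch.cat([qg, kg, vg], dim=1).view(-1, hidden_size).contiguous()

# Megatron's linear_qkv.weight is [(heads_per_group + 2) * num_query_groups * head_dim, hidden_size]:
# rows 0..15 hold group 0 (q head 0, q head 1, k, v), rows 16..31 hold group 1.
assert fused.shape == ((heads_per_group + 2) * num_query_groups * head_dim, hidden_size)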
15 changes: 13 additions & 2 deletions verl/models/mcore/config_converter.py
@@ -19,6 +19,7 @@

import torch
import torch.nn.functional as F
from megatron.core import parallel_state as mpu
from megatron.core.transformer import MLATransformerConfig, TransformerConfig
from transformers import PretrainedConfig

@@ -36,7 +37,6 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
Returns:
TransformerConfig with common parameters
"""
from megatron.core import parallel_state as mpu

# Common parallel state parameters
overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1
Expand All @@ -54,6 +54,7 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
"hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0),
"kv_channels": getattr(hf_config, "head_dim", None),
"layernorm_epsilon": hf_config.rms_norm_eps,
"add_bias_linear": False,
# Activation and normalization
"activation_func": F.silu,
"normalization": "RMSNorm",
@@ -266,7 +267,17 @@ def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, *

def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
# Qwen2_5_VLForConditionalGeneration
raise NotImplementedError("Qwen2_5_VLForConditionalGeneration is not supported yet")

args = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
add_bias_linear=False,
# qwen specific
add_qkv_bias=True,
mrope_section=hf_config.rope_scaling["mrope_section"],
**override_transformer_config_kwargs,
)
return TransformerConfig(**args)


def hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
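The new hf_to_mcore_config_qwen2_5_vl builds its TransformerConfig from the shared base helper, disabling linear biases while keeping QKV biases and forwarding the mRoPE sections from rope_scaling. A minimal sketch of the HF fields this path consumes (assuming the transformers version used here, which exposes them at the top level of the VL config; the real call also reads parallel sizes from megatron.core.parallel_state and so needs an initialized Megatron process group):

from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

print(hf_config.architectures)                  # ["Qwen2_5_VLForConditionalGeneration"] -> routed to the VL branch
print(hf_config.rope_scaling["mrope_section"])  # forwarded as mrope_section
print(hf_config.num_attention_heads, hf_config.num_key_value_heads, hf_config.rms_norm_eps)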
55 changes: 52 additions & 3 deletions verl/models/mcore/model_forward.py
@@ -29,6 +29,7 @@ def gptmodel_forward(
pack_seqs=True,
logits_processor=None,
logits_processor_args: dict = None,
**kwargs,
):
"""Default forward pass for GPT models with optional sequence packing."""
pre_process = unwrap_model(model).pre_process
@@ -60,6 +61,54 @@
return output


def gptmodel_forward_qwen2_5_vl(*args, **kwargs):
"""Forward pass for Qwen2.5 VL model (not implemented)."""
raise NotImplementedError("VLM is not supported yet")
def gptmodel_forward_qwen2_5_vl(
model,
input_ids,
attention_mask,
position_ids,
sequence_parallel,
value_model=False,
pack_seqs=True,
multi_modal_inputs=None,
logits_processor=None,
logits_processor_args: dict = None,
**kwargs,
):
from megatron.core import parallel_state as mpu

assert mpu.get_context_parallel_world_size() == 1, "qwen2_5_vl's context parallel is not accurate yet"
pre_process = unwrap_model(model).pre_process
post_process = unwrap_model(model).post_process
if pack_seqs:
batch_size, seq_len = attention_mask.shape[:2]
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True)
input_ids_rmpad = input_ids_rmpad.contiguous()
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids,
packed_seq_params=packed_seq_params,
pixel_values=multi_modal_inputs["pixel_values"].to(input_ids.device),
image_grid_thw=multi_modal_inputs["image_grid_thw"].to(input_ids.device),
)

if post_process and logits_processor is not None:
args = {k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] for k, v in logits_processor_args.items()}
output_dict = logits_processor(output_orig, **args)
output = {k: postprocess_packed_seqs(v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process) for k, v in output_dict.items()}
else:
output = postprocess_packed_seqs(output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process)
else:
batch_size, sequence_length = attention_mask.shape
new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process)
output = model(
input_ids=new_input_ids,
position_ids=new_position_ids,
attention_mask=new_attention_mask,
pixel_values=multi_modal_inputs["pixel_values"].to(input_ids.device),
image_grid_thw=multi_modal_inputs["image_grid_thw"].to(input_ids.device),
)
output = recover_left_padding(output, new_attention_mask, attention_mask, sequence_length, post_process=post_process)
if value_model and post_process:
output = output[..., 0]
return output
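The packed-sequence path above expects multi_modal_inputs to carry the vision tensors produced by the HF processor. A minimal sketch of assembling that dict on the caller side (the prompt format and dummy image are illustrative assumptions; in verl the dataset/worker builds this dict):

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
image = Image.new("RGB", (448, 448), color=(127, 127, 127))  # dummy stand-in image
prompt = "<|vision_start|><|image_pad|><|vision_end|>Describe the image."
inputs = processor(text=[prompt], images=[image], return_tensors="pt")

multi_modal_inputs = {
    "pixel_values": inputs["pixel_values"],      # flattened vision patches
    "image_grid_thw": inputs["image_grid_thw"],  # (t, h, w) patch grid per image
}
# gptmodel_forward_qwen2_5_vl moves these two tensors to input_ids.device and passes them
# to the Megatron Qwen2.5-VL model together with the packed input_ids and packed_seq_params.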