Skip to content

Commit ef5905e

Browse files
Author: huangjiaming (committed)
Commit message: fix qwen2_vl on plain-text data and plain-text-image-mixed data
1 parent a1a152e commit ef5905e

File tree

2 files changed: +7 lines added, -5 lines removed

2 files changed

+7
-5
lines changed

verl/models/mcore/model_forward.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def gptmodel_forward_qwen2_5_vl(
7979
assert mpu.get_context_parallel_world_size() == 1, "qwen2_5_vl's context parallel is not accurate yet"
8080
pre_process = unwrap_model(model).pre_process
8181
post_process = unwrap_model(model).post_process
82+
pixel_values = multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None
83+
image_grid_thw = multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None
8284
if pack_seqs:
8385
batch_size, seq_len = attention_mask.shape[:2]
8486
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True)
@@ -88,8 +90,8 @@ def gptmodel_forward_qwen2_5_vl(
8890
attention_mask=None,
8991
position_ids=position_ids,
9092
packed_seq_params=packed_seq_params,
91-
pixel_values=multi_modal_inputs["pixel_values"].to(input_ids.device),
92-
image_grid_thw=multi_modal_inputs["image_grid_thw"].to(input_ids.device),
93+
pixel_values=pixel_values,
94+
image_grid_thw=image_grid_thw,
9395
)
9496

9597
if post_process and logits_processor is not None:
@@ -105,8 +107,8 @@ def gptmodel_forward_qwen2_5_vl(
105107
input_ids=new_input_ids,
106108
position_ids=new_position_ids,
107109
attention_mask=new_attention_mask,
108-
pixel_values=multi_modal_inputs["pixel_values"].to(input_ids.device),
109-
image_grid_thw=multi_modal_inputs["image_grid_thw"].to(input_ids.device),
110+
pixel_values=pixel_values,
111+
image_grid_thw=image_grid_thw,
110112
)
111113
output = recover_left_padding(output, new_attention_mask, attention_mask, sequence_length, post_process=post_process)
112114
if value_model and post_process:

verl/workers/actor/megatron_actor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def forward_step(batch_iter, model):
409409
multi_modal_inputs = {}
410410
if "multi_modal_inputs" in batch:
411411
for key in batch["multi_modal_inputs"][0].keys():
412-
multi_modal_inputs[key] = torch.cat([batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0)
412+
multi_modal_inputs[key] = torch.cat(list(filter(lambda x: x != None, [batch["multi_modal_inputs"][i].get(key) for i in batch["multi_modal_inputs_idx"]])), dim=0)
413413
responses = batch["responses"]
414414
response_length = responses.size(1)
415415
label = copy.deepcopy(position_ids)

0 commit comments

Comments (0)