Skip to content

Commit 4f9d2b4

Browse files
authored
[megatron] fix: fix qwen2_vl on plain-text data and mix data of plain-text and image-text (volcengine#1999)
### Checklist Before Starting - [ ] Searched for similar PR(s). - [ ] Checked PR Title format - [ ] In format of: [modules] type: Title - [ ] modules are in `fsdp, megatron, sglang, vllm, rollout, trainer, tests, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc` - [ ] type is in `feat, fix, refactor, chore` - [ ] can involve multiple modules, separated by `,` or space, like `[megatron, fsdp, doc] feat: xxx` ### What does this PR do? fix qwen2_vl on plain-text data and mix data of plain-text and image-text, refer to volcengine#1286 ### Test test on gsm8k dataset and mix data of gsm8k and geo3k. ### High-Level Design > Demonstrate the high-level design if this PR is complex. ### Specific Changes > List the specific changes. ### API > Demonstrate how the API changes if any. ### Usage Example > Provide usage example(s) for easier usage. ```python # Add code snippet or script demonstrating how to use this ``` ### Checklist Before Submitting - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting). - [ ] Add `[BREAKING]` to the PR title `description` if it breaks any API. - [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs). - [ ] New CI unit test(s) are added to cover the code path. - [ ] Rely on existing unit tests on CI that covers the code path.
1 parent cc6c9c0 commit 4f9d2b4

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

verl/models/mcore/model_forward.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def gptmodel_forward_qwen2_5_vl(
7979
assert mpu.get_context_parallel_world_size() == 1, "qwen2_5_vl's context parallel is not accurate yet"
8080
pre_process = unwrap_model(model).pre_process
8181
post_process = unwrap_model(model).post_process
82+
pixel_values = multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None
83+
image_grid_thw = multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None
8284
if pack_seqs:
8385
batch_size, seq_len = attention_mask.shape[:2]
8486
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True)
@@ -88,8 +90,8 @@ def gptmodel_forward_qwen2_5_vl(
8890
attention_mask=None,
8991
position_ids=position_ids,
9092
packed_seq_params=packed_seq_params,
91-
pixel_values=multi_modal_inputs["pixel_values"].to(input_ids.device),
92-
image_grid_thw=multi_modal_inputs["image_grid_thw"].to(input_ids.device),
93+
pixel_values=pixel_values,
94+
image_grid_thw=image_grid_thw,
9395
)
9496

9597
if post_process and logits_processor is not None:
@@ -105,8 +107,8 @@ def gptmodel_forward_qwen2_5_vl(
105107
input_ids=new_input_ids,
106108
position_ids=new_position_ids,
107109
attention_mask=new_attention_mask,
108-
pixel_values=multi_modal_inputs["pixel_values"].to(input_ids.device),
109-
image_grid_thw=multi_modal_inputs["image_grid_thw"].to(input_ids.device),
110+
pixel_values=pixel_values,
111+
image_grid_thw=image_grid_thw,
110112
)
111113
output = recover_left_padding(output, new_attention_mask, attention_mask, sequence_length, post_process=post_process)
112114
if value_model and post_process:

verl/workers/actor/megatron_actor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,9 @@ def forward_step(batch_iter, model):
413413
multi_modal_inputs = {}
414414
if "multi_modal_inputs" in batch:
415415
for key in batch["multi_modal_inputs"][0].keys():
416-
multi_modal_inputs[key] = torch.cat([batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0)
416+
idxs = batch["multi_modal_inputs_idx"]
417+
mmi = batch["multi_modal_inputs"]
418+
multi_modal_inputs[key] = torch.cat([mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0)
417419
responses = batch["responses"]
418420
response_length = responses.size(1)
419421
label = copy.deepcopy(position_ids)

0 commit comments

Comments
 (0)