Commit ac4032a
[worker] fix: get all multi_modal_inputs keys within a microbatch (volcengine#3315)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Addresses the first issue in volcengine#3281 (comment). More work on top of volcengine#1999.

Currently, the code gets the `multi_modal_inputs` keys from the first row within the microbatch. This can go wrong when the dataset is a mixture of pure-text and multi-modal data: if the first row in the microbatch is pure-text (no `pixel_values` or `image_grid_thw` in its keys) while the microbatch still contains multi-modal data, those keys are missed. This PR fixes the issue by collecting all available `multi_modal_inputs` keys within the microbatch, so that the multi-modal tensors can be concatenated together without ignoring some of them in the situation above.

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

Signed-off-by: Hollow Man <[email protected]>
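The failure mode described above can be sketched with plain dicts (lists stand in for tensors here; the keys and values are illustrative, not taken from a real batch):

```python
# A microbatch mixing a pure-text row with a multi-modal row.
micro_batch = [
    {},  # pure-text sample: no multi-modal keys at all
    {"pixel_values": [[0.1, 0.2]], "image_grid_thw": [[1, 2, 2]]},
]

# Old behavior: take the key set from the FIRST row only.
old_keys = list(micro_batch[0].keys())

# Fixed behavior: collect the union of keys across the whole microbatch.
new_keys = set()
for row in micro_batch:
    new_keys.update(k for k, v in row.items() if v is not None)

print(old_keys)          # [] -> pixel_values / image_grid_thw silently dropped
print(sorted(new_keys))  # ['image_grid_thw', 'pixel_values']
```

With the old logic, a pure-text first row makes the whole microbatch look pure-text; the fix only changes which rows contribute keys, not how the tensors are combined.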
1 parent 03fabaa commit ac4032a

File tree: 7 files changed, +64 −44 lines changed


verl/utils/model.py

Lines changed: 41 additions & 0 deletions
```diff
@@ -696,6 +696,47 @@ def get_hf_auto_model_class(hf_config):
     return actor_module_class
 
 
+def extract_multi_modal_inputs(
+    batch_data: list[dict[str, torch.Tensor]],
+    indices: Optional[list[int]] = None,
+) -> dict[str, torch.Tensor | list[torch.Tensor]]:
+    """
+    Extract and process multi-modal inputs from a batch.
+
+    Args:
+        batch_data (list[dict[str, torch.Tensor]]): The batch containing potential multi-modal inputs
+        indices (Optional[list[int]]): If provided, only extract inputs at these indices
+
+    Returns:
+        dict[str, torch.Tensor | list[torch.Tensor]]: Processed multi-modal inputs ready for model consumption
+
+    """
+    multi_modal_inputs = {}
+    multi_modal_inputs_collected = {}
+    has_image_bound = False
+
+    selected_batch_data = batch_data
+    if indices is not None:
+        selected_batch_data = [batch_data[i] for i in indices if i < len(batch_data)]
+
+    for inputs in selected_batch_data:
+        if "image_bound" in inputs:
+            has_image_bound = True
+        for key, value in inputs.items():
+            if value is not None:
+                if key not in multi_modal_inputs_collected:
+                    multi_modal_inputs_collected[key] = []
+                multi_modal_inputs_collected[key].append(value)
+
+    for key, values in multi_modal_inputs_collected.items():
+        if has_image_bound:  # minicpm-o logic
+            multi_modal_inputs[key] = values
+        else:
+            multi_modal_inputs[key] = torch.cat(values, dim=0)
+
+    return multi_modal_inputs
+
+
 @dataclass
 class CausalLMOutputForPPO(CausalLMOutputWithPast):
     log_probs: Optional[torch.FloatTensor] = None
```
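To illustrate how the new helper behaves on a mixed microbatch, here is a simplified, self-contained stand-in (nested lists replace `torch.Tensor`, and list concatenation replaces `torch.cat`; the real implementation lives in `verl/utils/model.py` above):

```python
def extract_multi_modal_inputs_sketch(batch_data, indices=None):
    """Simplified stand-in for the real helper: lists instead of tensors."""
    collected = {}
    has_image_bound = False

    selected = batch_data
    if indices is not None:
        # Defensive filtering: out-of-range indices are skipped.
        selected = [batch_data[i] for i in indices if i < len(batch_data)]

    for inputs in selected:
        if "image_bound" in inputs:
            has_image_bound = True
        for key, value in inputs.items():
            if value is not None:
                collected.setdefault(key, []).append(value)

    if has_image_bound:  # minicpm-o logic: keep per-sample lists
        return collected
    # Otherwise concatenate along dim 0 (list concatenation here).
    return {k: sum(v, []) for k, v in collected.items()}


batch = [
    {},  # pure-text row first: previously caused keys to be missed entirely
    {"pixel_values": [[1.0], [2.0]]},
    {"pixel_values": [[3.0]]},
]
result = extract_multi_modal_inputs_sketch(batch)
print(result)  # {'pixel_values': [[1.0], [2.0], [3.0]]}
```

Because keys are collected as a union over all rows, the leading pure-text row no longer hides `pixel_values` from the rest of the microbatch.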

verl/workers/actor/dp_actor.py

Lines changed: 3 additions & 8 deletions
```diff
@@ -98,14 +98,9 @@ def _forward_micro_batch(
         response_length = micro_batch["responses"].size(-1)
         multi_modal_inputs = {}
         if "multi_modal_inputs" in micro_batch.keys():
-            if "image_bound" in micro_batch["multi_modal_inputs"][0]:  # minicpm-o logic
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]]
-            else:
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = torch.cat(
-                        [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                    )
+            from verl.utils.model import extract_multi_modal_inputs
+
+            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])
 
         with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
             input_ids = micro_batch["input_ids"]
```

verl/workers/actor/megatron_actor.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -466,12 +466,10 @@ def forward_step(batch_iter, model):
 
             multi_modal_inputs = {}
             if "multi_modal_inputs" in batch:
-                for key in batch["multi_modal_inputs"][0].keys():
-                    idxs = batch["multi_modal_inputs_idx"]
-                    mmi = batch["multi_modal_inputs"]
-                    multi_modal_inputs[key] = torch.cat(
-                        [mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0
-                    )
+                from verl.utils.model import extract_multi_modal_inputs
+
+                indices = batch.get("multi_modal_inputs_idx", None)
+                multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
             responses = batch["responses"]
             response_length = responses.size(1)
             label = position_ids.clone()
```
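In the Megatron paths, `multi_modal_inputs_idx` selects which rows of the batch feed this forward step. The index filtering inside the helper can be sketched with plain dicts (illustrative values, lists standing in for tensors):

```python
batch_mm = [
    {"pixel_values": [[1.0]]},
    {},  # pure-text row
    {"pixel_values": [[3.0]]},
]
indices = [0, 2]

# Row selection as done inside extract_multi_modal_inputs:
# out-of-range indices are skipped defensively.
selected = [batch_mm[i] for i in indices if i < len(batch_mm)]

# The union of keys is then taken over the selected rows only.
keys = set()
for row in selected:
    keys.update(k for k, v in row.items() if v is not None)

print(sorted(keys))  # ['pixel_values']
```

Passing `indices=None` (the default, used by the FSDP callers) processes the whole batch instead.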

verl/workers/critic/dp_critic.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -58,10 +58,9 @@ def _forward_micro_batch(self, micro_batch):
         response_length = micro_batch["responses"].size(-1)
         multi_modal_inputs = {}
         if "multi_modal_inputs" in micro_batch.keys():
-            for key in micro_batch["multi_modal_inputs"][0].keys():
-                multi_modal_inputs[key] = torch.cat(
-                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                )
+            from verl.utils.model import extract_multi_modal_inputs
+
+            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])
 
         with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):
             input_ids = micro_batch["input_ids"]
```

verl/workers/engine/fsdp/transformer_impl.py

Lines changed: 3 additions & 8 deletions
```diff
@@ -644,14 +644,9 @@ def prepare_model_inputs(self, micro_batch: TensorDict):
 
         multi_modal_inputs = {}
         if "multi_modal_inputs" in micro_batch.keys():
-            if "image_bound" in micro_batch["multi_modal_inputs"][0]:  # minicpm-o logic
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]]
-            else:
-                for key in micro_batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = torch.cat(
-                        [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                    )
+            from verl.utils.model import extract_multi_modal_inputs
+
+            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])
 
         input_ids = micro_batch["input_ids"]
         attention_mask = micro_batch["attention_mask"]
```

verl/workers/engine/megatron/transformer_impl.py

Lines changed: 7 additions & 14 deletions
```diff
@@ -165,19 +165,14 @@ def _build_megatron_module(self):
         return module
 
     def _build_optimizer(self):
-        from verl.utils.megatron.optimizer import (
-            get_megatron_optimizer,
-            init_megatron_optim_config,
-        )
+        from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config
 
         optim_config_megatron = init_megatron_optim_config(self.optimizer_config)
         optimizer = get_megatron_optimizer(model=self.module, config=optim_config_megatron)
         return optimizer
 
     def _build_lr_scheduler(self):
-        from verl.utils.megatron.optimizer import (
-            get_megatron_optimizer_param_scheduler,
-        )
+        from verl.utils.megatron.optimizer import get_megatron_optimizer_param_scheduler
 
         optimizer_scheduler = get_megatron_optimizer_param_scheduler(
             optimizer=self.optimizer, config=self.optimizer_config
@@ -495,13 +490,11 @@ def prepare_model_inputs(self, batch: TensorDict):
         ]  # mcore patch recompute qwen2vl's pos ids during forward
 
         multi_modal_inputs = {}
-        if "multi_modal_inputs" in batch.keys():
-            for key in batch["multi_modal_inputs"][0].keys():
-                idxs = batch["multi_modal_inputs_idx"]
-                mmi = batch["multi_modal_inputs"]
-                multi_modal_inputs[key] = torch.cat(
-                    [mmi[idx].get(key).to(input_ids.device) for idx in idxs if mmi[idx].get(key) is not None], dim=0
-                )
+        if "multi_modal_inputs" in batch:
+            from verl.utils.model import extract_multi_modal_inputs
+
+            indices = batch.get("multi_modal_inputs_idx", None)
+            multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
 
         return {
             "input_ids": input_ids,
```

verl/workers/reward_model/megatron/reward_model.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -281,11 +281,10 @@ def forward_step(batch_iter, model):
 
             multi_modal_inputs = {}
             if "multi_modal_inputs" in batch:
-                for key in batch["multi_modal_inputs"][0].keys():
-                    multi_modal_inputs[key] = torch.cat(
-                        [batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0
-                    )
+                from verl.utils.model import extract_multi_modal_inputs
 
+                indices = batch.get("multi_modal_inputs_idx", None)
+                multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices)
             output = forward_fn(
                 model,
                 input_ids,
```
