
Commit ab88ddc

fix image-text mixed data
1 parent dae06ce commit ab88ddc

9 files changed: +293 −34 lines changed

verl/models/monkey_patch.py

Lines changed: 32 additions & 7 deletions
@@ -15,24 +15,49 @@
 
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
+from ..utils.py_functional import is_transformers_version_greater_than
 from .transformers.flash_attention_utils import flash_attention_forward
-from .transformers.qwen2_vl import qwen2_vl_attn_forward
+from .transformers.qwen2_vl import (
+    qwen2_vl_attn_forward,
+    qwen2_vl_base_forward_new,
+    qwen2_vl_forward_new,
+    qwen2_vl_forward_old,
+)
 
 
 def apply_ulysses_patch(model_type: str) -> None:
     if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe"):
         ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
     elif model_type in ("qwen2_vl", "qwen2_5_vl"):
-        try:
+        if is_transformers_version_greater_than("4.53.0"):
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
+
+            Qwen2VLAttention.forward = qwen2_vl_attn_forward
+            Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward
+        else:
             from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
             from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
-        except ImportError:  # transformers >= 4.52.4
+
+            Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
+            Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
+
+        if is_transformers_version_greater_than("4.52.0"):
             from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-                Qwen2_5_VLAttention as Qwen2_5_VLFlashAttention2,
+                Qwen2_5_VLForConditionalGeneration,
+                Qwen2_5_VLModel,
             )
-            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention as Qwen2VLFlashAttention2
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2VLModel
+
+            Qwen2VLModel.forward = qwen2_vl_base_forward_new
+            Qwen2_5_VLModel.forward = qwen2_vl_base_forward_new
+            Qwen2VLForConditionalGeneration.forward = qwen2_vl_forward_new
+            Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_forward_new
+        else:
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
 
-        Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
-        Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
+            Qwen2VLForConditionalGeneration.forward = qwen2_vl_forward_old
+            Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_forward_old
     else:
         raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
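
For orientation, a minimal usage sketch (not part of the commit) of how the patched entry point is typically driven; the checkpoint name and the surrounding driver code are illustrative assumptions, only apply_ulysses_patch comes from the file above.

# Hypothetical driver: apply the monkey patch before the model is instantiated so the
# version-gated overrides above are already installed on the transformers classes.
from transformers import AutoConfig

from verl.models.monkey_patch import apply_ulysses_patch

model_path = "Qwen/Qwen2.5-VL-7B-Instruct"  # illustrative checkpoint
config = AutoConfig.from_pretrained(model_path)

# model_type == "qwen2_5_vl" takes the VL branch: attention is patched on
# Qwen2VLAttention / Qwen2_5_VLAttention (transformers >= 4.53.0) or on the
# *FlashAttention2 classes otherwise, and the model / causal-LM forwards are swapped
# for the "new" (>= 4.52.0) or "old" variants defined in verl/models/transformers/qwen2_vl.py.
apply_ulysses_patch(config.model_type)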

verl/models/transformers/qwen2_vl.py

Lines changed: 195 additions & 3 deletions
@@ -19,18 +19,29 @@
 
 import torch
 
+from ...utils.py_functional import is_transformers_version_greater_than
 from .flash_attention_utils import flash_attention_forward
 
 
-try:
+if is_transformers_version_greater_than("4.52.0"):
     from transformers.models.qwen2_vl.modeling_qwen2_vl import (
         Qwen2VLAttention,
+        Qwen2VLCausalLMOutputWithPast,
+        Qwen2VLForConditionalGeneration,
+        Qwen2VLModel,
+        Qwen2VLModelOutputWithPast,
         apply_multimodal_rotary_pos_emb,
         repeat_kv,
     )
     from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
-except ImportError:
-    pass
+else:
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+        Qwen2VLAttention,
+        Qwen2VLCausalLMOutputWithPast,
+        Qwen2VLForConditionalGeneration,
+        apply_multimodal_rotary_pos_emb,
+        repeat_kv,
+    )
 
 
 def get_rope_index(
@@ -183,3 +194,184 @@ def qwen2_vl_attn_forward(
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
     attn_output = self.o_proj(attn_output)
     return attn_output, None, None
+
+
+def _get_input_embeds(
+    model: "Qwen2VLModel",
+    input_ids: torch.LongTensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+):
+    inputs_embeds = model.get_input_embeddings()(input_ids)
+    if pixel_values is not None:
+        pixel_values = pixel_values.type(model.visual.dtype)
+        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
+        n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
+        n_image_features = image_embeds.shape[0]
+        if n_image_tokens != n_image_features:
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+
+        mask = input_ids == model.config.image_token_id
+        mask_unsqueezed = mask.unsqueeze(-1)
+        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+        image_mask = mask_expanded.to(inputs_embeds.device)
+
+        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+    if pixel_values_videos is not None:
+        pixel_values_videos = pixel_values_videos.type(model.visual.dtype)
+        video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
+        n_video_tokens = (input_ids == model.config.video_token_id).sum().item()
+        n_video_features = video_embeds.shape[0]
+        if n_video_tokens != n_video_features:
+            raise ValueError(
+                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+            )
+
+        mask = input_ids == model.config.video_token_id
+        mask_unsqueezed = mask.unsqueeze(-1)
+        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+        video_mask = mask_expanded.to(inputs_embeds.device)
+
+        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+    if pixel_values is None and pixel_values_videos is None:
+        pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
+        image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
+        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
+        inputs_embeds += 0.0 * image_embeds.mean()
+
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(inputs_embeds.device)
+
+    return inputs_embeds, attention_mask
+
+
+def qwen2_vl_forward_old(
+    self: "Qwen2VLForConditionalGeneration",
+    input_ids: torch.LongTensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> "Qwen2VLCausalLMOutputWithPast":
+    inputs_embeds, attention_mask = _get_input_embeds(
+        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
+    )
+    outputs = self.model(
+        input_ids=None,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=None,
+        inputs_embeds=inputs_embeds,
+        use_cache=False,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        cache_position=None,
+    )
+    hidden_states = outputs[0]
+    logits = self.lm_head(hidden_states)
+
+    return Qwen2VLCausalLMOutputWithPast(
+        loss=None,
+        logits=logits,
+        past_key_values=None,
+        hidden_states=None,
+        attentions=None,
+        rope_deltas=None,
+    )
+
+
+def qwen2_vl_base_forward_new(
+    self: "Qwen2VLModel",
+    input_ids: torch.LongTensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    **kwargs,
+):
+    inputs_embeds, attention_mask = _get_input_embeds(
+        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
+    )
+    outputs = self.language_model(
+        input_ids=None,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=None,
+        inputs_embeds=inputs_embeds,
+        use_cache=False,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        cache_position=None,
+    )
+
+    output = Qwen2VLModelOutputWithPast(
+        last_hidden_state=outputs.last_hidden_state,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        rope_deltas=None,
+    )
+    return output
+
+
+def qwen2_vl_forward_new(
+    self: "Qwen2VLForConditionalGeneration",
+    input_ids: torch.LongTensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> "Qwen2VLCausalLMOutputWithPast":
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=None,
+        inputs_embeds=None,
+        use_cache=False,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        cache_position=None,
+    )
+    hidden_states = outputs[0]
+    logits = self.lm_head(hidden_states)
+
+    return Qwen2VLCausalLMOutputWithPast(
+        loss=None,
+        logits=logits,
+        past_key_values=None,
+        hidden_states=None,
+        attentions=None,
+        rope_deltas=None,
+    )
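
The text-only branch of _get_input_embeds is the core of the mixed image-text fix: when a micro-batch carries no pixels, a small dummy input is still pushed through model.visual and folded in with weight 0.0, so the vision tower takes part in autograd on every rank and DDP/FSDP gradient synchronization sees the same parameter set everywhere. A self-contained sketch of that pattern with toy modules (not the Qwen2-VL classes):

import torch
import torch.nn as nn


class ToyVisionTower(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 16)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.proj(pixel_values)


vision = ToyVisionTower()
text_embeds = torch.randn(2, 5, 16, requires_grad=True)

# Text-only micro-batch: run the vision tower on a dummy input and add its mean with
# zero weight. The embeddings are numerically unchanged, but the visual parameters are
# now part of the graph, so backward() produces (zero) gradients for them on this rank
# too -- matching ranks whose micro-batches do contain images.
dummy_pixels = torch.zeros(4, 8)
inputs_embeds = text_embeds + 0.0 * vision(dummy_pixels).mean()

inputs_embeds.sum().backward()
print(vision.proj.weight.grad.abs().sum())  # tensor(0.) -- gradients exist, but are zero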

verl/utils/dataset.py

Lines changed: 15 additions & 5 deletions
@@ -156,10 +156,20 @@ def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]:
 
     def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool:
         messages = self._build_messages(example)
-        processing_class = self.processor if self.processor is not None else self.tokenizer
-        return (
-            len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length
-        )
+        if self.image_key in example:
+            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+            images = example[self.image_key] or []
+            if self.image_dir is not None and len(images) != 0 and isinstance(images[0], str):  # image paths
+                images = [os.path.join(self.image_dir, image) for image in images]
+
+            resized_images = [
+                process_image(image, min_pixels=self.min_pixels, max_pixels=self.max_pixels) for image in images
+            ] or None
+            model_inputs = self.processor(resized_images, [prompt], add_special_tokens=False, return_tensors="pt")
+            return model_inputs["input_ids"].size(-1) <= self.max_prompt_length
+        else:
+            input_ids = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+            return len(input_ids) <= self.max_prompt_length
 
     def __len__(self):
         return len(self.dataset)
@@ -176,7 +186,7 @@ def __getitem__(self, index):
 
             resized_images = [
                 process_image(image, min_pixels=self.min_pixels, max_pixels=self.max_pixels) for image in images
-            ]
+            ] or None
             model_inputs = self.processor(resized_images, [prompt], add_special_tokens=False, return_tensors="pt")
             input_ids = model_inputs.pop("input_ids")[0]
            attention_mask = model_inputs.pop("attention_mask")[0]
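
Both the new prompt-length filter and __getitem__ lean on the "[...] or None" idiom: a text-only row produces an empty image list, which is converted to None so the processor builds plain text inputs instead of receiving an empty image batch. A rough sketch of the idea, assuming a Qwen2-VL style processor; the checkpoint name and the 2048 limit are placeholders, not values from the commit:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")  # placeholder checkpoint
prompt = processor.apply_chat_template(
    [{"role": "user", "content": "How are you?"}], add_generation_prompt=True, tokenize=False
)

images = []                                         # text-only row: no images in the dataset example
resized_images = [img for img in images] or None    # empty list -> None

# With images=None the processor only tokenizes the text; with a non-empty list it also
# returns pixel_values / image_grid_thw, mirroring the __getitem__ path above.
model_inputs = processor(resized_images, [prompt], add_special_tokens=False, return_tensors="pt")
keep = model_inputs["input_ids"].size(-1) <= 2048   # stands in for self.max_prompt_length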

verl/utils/py_functional.py

Lines changed: 14 additions & 0 deletions
@@ -15,6 +15,7 @@
 Contain small python utility functions
 """
 
+import importlib.metadata
 import importlib.util
 import re
 from contextlib import contextmanager
@@ -24,6 +25,7 @@
 import numpy as np
 import yaml
 from codetiming import Timer
+from packaging import version
 from yaml import Dumper
 
 
@@ -53,6 +55,18 @@ def is_package_available(name: str) -> bool:
     return importlib.util.find_spec(name) is not None
 
 
+def get_package_version(name: str) -> "version.Version":
+    try:
+        return version.parse(importlib.metadata.version(name))
+    except Exception:
+        return version.parse("0.0.0")
+
+
+@lru_cache
+def is_transformers_version_greater_than(content: str):
+    return get_package_version("transformers") >= version.parse(content)
+
+
 def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
     """Union two dict. Will throw an error if there is an item not the same object with the same key."""
     for key in dict2.keys():
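
A quick usage sketch of the two helpers added above (the module path is the one the monkey patch imports from). Note that despite its name, is_transformers_version_greater_than is an inclusive ">=" check, and @lru_cache means each version string is compared only once per process.

from verl.utils.py_functional import get_package_version, is_transformers_version_greater_than

print(get_package_version("transformers"))             # e.g. 4.52.4; falls back to 0.0.0 if not installed
print(is_transformers_version_greater_than("4.52.0"))  # True for 4.52.0 and newer (>=, not strictly >)
print(is_transformers_version_greater_than("4.53.0"))  # gates the Qwen2VLAttention patch path above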

verl/utils/tokenizer.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def get_processor(model_path: str, override_chat_template: Optional[str] = None,
         processor.chat_template = override_chat_template
 
     # Avoid load tokenizer, see:
-    # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
+    # https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/auto/processing_auto.py#L386
     if processor is not None and "Processor" not in processor.__class__.__name__:
         processor = None
 
verl/workers/actor/dp_actor.py

Lines changed: 10 additions & 5 deletions
@@ -71,12 +71,17 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature
         if position_ids.dim() == 3:  # qwen2vl mrope
             position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)
 
-        multi_modal_inputs = {}
+        multi_modal_inputs = defaultdict(list)
         if "multi_modal_inputs" in micro_batch:
-            for key in micro_batch["multi_modal_inputs"][0].keys():
-                multi_modal_inputs[key] = torch.cat(
-                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                )
+            for input_dict in micro_batch["multi_modal_inputs"]:
+                for key, value in input_dict.items():
+                    multi_modal_inputs[key].append(value)
+
+            for key, value in multi_modal_inputs.items():
+                if len(value) != 0:
+                    multi_modal_inputs[key] = torch.cat(value, dim=0)
+                else:
+                    multi_modal_inputs[key] = None
 
         if self.config.padding_free:
            input_ids_rmpad, indices, *_ = unpad_input(
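
A standalone sketch of the collection pattern introduced here: the old code indexed multi_modal_inputs[0].keys() and assumed every sample carried every key, which breaks as soon as a text-only sample lands in the micro-batch. The new code gathers values per sample and concatenates only keys that actually appeared. The tensors below are illustrative shapes, not real data.

from collections import defaultdict

import torch

# One dict per sample, as stored under micro_batch["multi_modal_inputs"]; the second
# sample is text-only and therefore contributes nothing.
per_sample_inputs = [
    {"pixel_values": torch.randn(16, 1176), "image_grid_thw": torch.tensor([[1, 4, 4]])},
    {},
]

multi_modal_inputs = defaultdict(list)
for input_dict in per_sample_inputs:
    for key, value in input_dict.items():
        multi_modal_inputs[key].append(value)

for key, value in multi_modal_inputs.items():
    multi_modal_inputs[key] = torch.cat(value, dim=0) if len(value) != 0 else None

print({k: None if v is None else tuple(v.shape) for k, v in multi_modal_inputs.items()})
# {'pixel_values': (16, 1176), 'image_grid_thw': (1, 3)}

The critic worker below receives the identical change, so actor and critic handle mixed micro-batches the same way.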

verl/workers/critic/dp_critic.py

Lines changed: 10 additions & 5 deletions
@@ -59,12 +59,17 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Te
         if position_ids.dim() == 3:  # qwen2vl mrope
             position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)
 
-        multi_modal_inputs = {}
+        multi_modal_inputs = defaultdict(list)
         if "multi_modal_inputs" in micro_batch:
-            for key in micro_batch["multi_modal_inputs"][0].keys():
-                multi_modal_inputs[key] = torch.cat(
-                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
-                )
+            for input_dict in micro_batch["multi_modal_inputs"]:
+                for key, value in input_dict.items():
+                    multi_modal_inputs[key].append(value)
+
+            for key, value in multi_modal_inputs.items():
+                if len(value) != 0:
+                    multi_modal_inputs[key] = torch.cat(value, dim=0)
+                else:
+                    multi_modal_inputs[key] = None
 
         if self.config.padding_free:
            input_ids_rmpad, indices, *_ = unpad_input(
