 
 import torch
 
+from ...utils.py_functional import is_transformers_version_greater_than
 from .flash_attention_utils import flash_attention_forward
 
 
-try:
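+# transformers >= 4.52.0 refactors Qwen2-VL (Qwen2VLModel becomes the multimodal wrapper), hence the split imports.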
+if is_transformers_version_greater_than("4.52.0"):
     from transformers.models.qwen2_vl.modeling_qwen2_vl import (
         Qwen2VLAttention,
+        Qwen2VLCausalLMOutputWithPast,
+        Qwen2VLForConditionalGeneration,
+        Qwen2VLModel,
+        Qwen2VLModelOutputWithPast,
         apply_multimodal_rotary_pos_emb,
         repeat_kv,
     )
     from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
-except ImportError:
-    pass
+else:
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+        Qwen2VLAttention,
+        Qwen2VLCausalLMOutputWithPast,
+        Qwen2VLForConditionalGeneration,
+        apply_multimodal_rotary_pos_emb,
+        repeat_kv,
+    )
 
 
 def get_rope_index(
@@ -183,3 +194,184 @@ def qwen2_vl_attn_forward( |
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
     attn_output = self.o_proj(attn_output)
     return attn_output, None, None
+
+
+def _get_input_embeds(
+    model: "Qwen2VLModel",
+    input_ids: torch.LongTensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+):
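+    # Shared helper: embed input_ids, then scatter vision features into the image/video placeholder token positions.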
+    inputs_embeds = model.get_input_embeddings()(input_ids)
+    if pixel_values is not None:
+        pixel_values = pixel_values.type(model.visual.dtype)
+        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
+        n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
+        n_image_features = image_embeds.shape[0]
+        if n_image_tokens != n_image_features:
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+
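+        # Broadcast the image-token mask over the hidden dim and scatter image features into those positions.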
+        mask = input_ids == model.config.image_token_id
+        mask_unsqueezed = mask.unsqueeze(-1)
+        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+        image_mask = mask_expanded.to(inputs_embeds.device)
+
+        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+    if pixel_values_videos is not None:
+        pixel_values_videos = pixel_values_videos.type(model.visual.dtype)
+        video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
+        n_video_tokens = (input_ids == model.config.video_token_id).sum().item()
+        n_video_features = video_embeds.shape[0]
+        if n_video_tokens != n_video_features:
+            raise ValueError(
+                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+            )
+
+        mask = input_ids == model.config.video_token_id
+        mask_unsqueezed = mask.unsqueeze(-1)
+        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+        video_mask = mask_expanded.to(inputs_embeds.device)
+
+        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
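+    # Text-only batch: run the visual tower on a dummy 1x4x4 grid so it stays in the autograd graph (adds zero).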
+    if pixel_values is None and pixel_values_videos is None:
+        pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
+        image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
+        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
+        inputs_embeds += 0.0 * image_embeds.mean()
+
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(inputs_embeds.device)
+
+    return inputs_embeds, attention_mask
+
+
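+# Pre-4.52 path: vision features are merged into inputs_embeds before the decoder call, with the KV cache disabled.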
+def qwen2_vl_forward_old(
+    self: "Qwen2VLForConditionalGeneration",
+    input_ids: torch.LongTensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> "Qwen2VLCausalLMOutputWithPast":
+    inputs_embeds, attention_mask = _get_input_embeds(
+        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
+    )
+    outputs = self.model(
+        input_ids=None,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=None,
+        inputs_embeds=inputs_embeds,
+        use_cache=False,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        cache_position=None,
+    )
+    hidden_states = outputs[0]
+    logits = self.lm_head(hidden_states)
+
+    return Qwen2VLCausalLMOutputWithPast(
+        loss=None,
+        logits=logits,
+        past_key_values=None,
+        hidden_states=None,
+        attentions=None,
+        rope_deltas=None,
+    )
+
+
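+# Base-model forward for the >= 4.52.0 layout: Qwen2VLModel owns the visual encoder and the language_model decoder.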
+def qwen2_vl_base_forward_new(
+    self: "Qwen2VLModel",
+    input_ids: torch.LongTensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    **kwargs,
+):
+    inputs_embeds, attention_mask = _get_input_embeds(
+        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
+    )
+    outputs = self.language_model(
+        input_ids=None,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=None,
+        inputs_embeds=inputs_embeds,
+        use_cache=False,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        cache_position=None,
+    )
+
+    output = Qwen2VLModelOutputWithPast(
+        last_hidden_state=outputs.last_hidden_state,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        rope_deltas=None,
+    )
+    return output
+
+
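+# LM-head forward for >= 4.52.0: the patched base model above handles the multimodal fusion, so input_ids pass through.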
+def qwen2_vl_forward_new(
+    self: "Qwen2VLForConditionalGeneration",
+    input_ids: torch.LongTensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> "Qwen2VLCausalLMOutputWithPast":
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=None,
+        inputs_embeds=None,
+        use_cache=False,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        cache_position=None,
+    )
+    hidden_states = outputs[0]
+    logits = self.lm_head(hidden_states)
+
+    return Qwen2VLCausalLMOutputWithPast(
+        loss=None,
+        logits=logits,
+        past_key_values=None,
+        hidden_states=None,
+        attentions=None,
+        rope_deltas=None,
+    )