diff --git a/docs/source/openvino/models.mdx b/docs/source/openvino/models.mdx index e967e9d22e..422f1694e4 100644 --- a/docs/source/openvino/models.mdx +++ b/docs/source/openvino/models.mdx @@ -155,6 +155,7 @@ Here is the list of the supported architectures : - TrOCR - UniSpeech - UniSpeech-SAT +- VideoChat-Flash-Qwen - Vision Encoder Decoder - ViT - Wav2Vec2 diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 0624624a77..331a8e2bfe 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -200,6 +200,8 @@ Qwen3VLVisionEmbMergerPatcher, QwenModelPatcher, SanaTextEncoderModelPatcher, + VideoChatFlashQwenLanguageModelPatcher, + VideoChatFlashQwenVisionEmbeddingModelPatcher, XverseModelPatcher, Zamba2ModelPatcher, ) @@ -5303,6 +5305,209 @@ class SiglipTextOpenVINOConfig(SiglipTextOnnxConfig): pass +class DummyVideoChatFlashQwenInputGenerator(DummyVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ("hidden_states", "rotary_pos_emb") + + def __init__( + self, + task: str, + normalized_config: NormalizedVisionConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], + width: int = DEFAULT_DUMMY_SHAPES["width"], + height: int = DEFAULT_DUMMY_SHAPES["height"], + visual_seq_length: int = DEFAULT_DUMMY_SHAPES["visual_seq_length"], + **kwargs, + ): + super().__init__(task, normalized_config, batch_size, num_channels, width, height, visual_seq_length, **kwargs) + self.num_frames = getattr(normalized_config.config, "mm_local_num_frames", 4) + self.embed_dim = getattr(normalized_config.config, "mm_hidden_size", 1408) + # Then input image size and patch size for the vision encoder can not be got from the config, we set them to fixed values according to the original implementation. 
+ self.height = 224 + self.width = 224 + self.image_size = (self.height, self.width) + self.patch_size = 14 + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "hidden_states": + return self.random_float_tensor( + shape=[ + self.batch_size, + self.num_channels, + self.num_frames, + self.height, + self.width, + ], + framework=framework, + dtype=float_dtype, + ) + elif input_name == "rotary_pos_emb": + grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size + grid_t = self.num_frames + return self.random_float_tensor( + [1, 1 + grid_h * grid_t * grid_w, self.embed_dim], framework=framework, dtype=float_dtype + ) + + +class DummyVideoChatFlashQwenProjectorInputGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ["input"] + + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + self.task = task + self.batch_size = batch_size + self.hidden_size = normalized_config.config.mm_hidden_size + # The original implementation with projector_type 'tome16_mlp_hd64' uses a fixed number of patches (64). 
+ self.num_patches = 64 + self.normalized_config = normalized_config + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ): + shape = [self.batch_size, self.num_patches, self.hidden_size] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + + +class VideoChatFlashQwenProjectorOpenVINOConfig(OnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoChatFlashQwenProjectorInputGenerator,) + NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + return {"input": {0: "batch_size", 1: "num_patches", 2: "hidden_size"}} + + +class VideoChatFlashQwenConfigBehavior(str, enum.Enum): + LANGUAGE = "language" + VISION_EMBEDDINGS = "vision_embeddings" + VISION_PROJECTION = "vision_projection" + TEXT_EMBEDDINGS = "text_embeddings" + + +@register_in_tasks_manager("videochat_flash_qwen", *["image-text-to-text"], library_name="transformers") +class VideoChatFlashQwenOpenVINOConfig(BaseVLMOpenVINOConfig): + MIN_TRANSFORMERS_VERSION = "4.49.0" + MAX_TRANSFORMERS_VERSION = "4.57.99" + SUPPORTED_BEHAVIORS = [model_type.value for model_type in VideoChatFlashQwenConfigBehavior] + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoChatFlashQwenInputGenerator,) + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + behavior: VideoChatFlashQwenConfigBehavior = VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS, + preprocessors: Optional[List[Any]] = None, + **kwargs, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + behavior=behavior, + preprocessors=preprocessors, + ) + self._orig_config = config + if self._behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"): + self._config = config.vision_config + self._normalized_config = 
self.NORMALIZED_CONFIG_CLASS(self._config) + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + if not self._behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS: + return {} + return { + "hidden_states": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"}, + # rotary_pos_emb has a fixed leading dimension of 1 in the dummy generator, + # so we do not associate axis 0 with batch_size and keep only dynamic axes here. + "rotary_pos_emb": {1: "num_tokens", 2: "hidden_size"}, + } + + def with_behavior( + self, + behavior: Union[str, VideoChatFlashQwenConfigBehavior], + ): + """ + Creates a config for different behaviour. + + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + """ + if isinstance(behavior, str) and not isinstance(behavior, VideoChatFlashQwenConfigBehavior): + behavior = VideoChatFlashQwenConfigBehavior(behavior) + + if behavior == VideoChatFlashQwenConfigBehavior.VISION_PROJECTION: + export_config = VideoChatFlashQwenProjectorOpenVINOConfig( + self._orig_config, + task="feature-extraction", + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + ) + return export_config + + if behavior == VideoChatFlashQwenConfigBehavior.TEXT_EMBEDDINGS: + return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype) + + if behavior == VideoChatFlashQwenConfigBehavior.LANGUAGE: + return get_vlm_text_generation_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype) + + if behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + + def get_model_for_behavior(self, model, behavior: Union[str, VideoChatFlashQwenConfigBehavior]): + if isinstance(behavior, str) and not isinstance(behavior, VideoChatFlashQwenConfigBehavior): + behavior = 
VideoChatFlashQwenConfigBehavior(behavior) + + if behavior == VideoChatFlashQwenConfigBehavior.VISION_PROJECTION: + vision_projector = model.get_model().mm_projector.mlp + vision_projector.config = model.config + return vision_projector + + if behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS: + vision_tower = model.get_vision_tower().vision_tower + vision_tower.config = model.config + return vision_tower + + if behavior == VideoChatFlashQwenConfigBehavior.TEXT_EMBEDDINGS: + text_embedding = model.get_input_embeddings() + text_embedding.config = model.config + return text_embedding + + if behavior == VideoChatFlashQwenConfigBehavior.LANGUAGE: + model.model.llm_compress_layer_list = [] + return model.language_model if not hasattr(model, "lm_head") else model + + def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[Dict[str, Any]] = None): + model_kwargs = model_kwargs or {} + if self._behavior == VideoChatFlashQwenConfigBehavior.LANGUAGE: + return VideoChatFlashQwenLanguageModelPatcher(self, model, model_kwargs) + + if self._behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS: + return VideoChatFlashQwenVisionEmbeddingModelPatcher(self, model, model_kwargs) + + return super().patch_model_for_export(model, model_kwargs) + + @register_in_tasks_manager( "hunyuan_v1_dense", *[ diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 32dd2d6c6d..de0eae6355 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -7640,6 +7640,76 @@ def __exit__(self, exc_type, exc_value, traceback): del afmoe_moe.down_projs, afmoe_moe.gate_projs, afmoe_moe.up_projs +class VideoChatFlashQwenVisionEmbeddingModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: "PreTrainedModel", + model_kwargs: Dict[str, Any] = None, + ): + model.__orig_forward = model.forward + + def forward_wrap(self, hidden_states, 
rotary_pos_emb): + hidden_states = self.patch_embed(hidden_states.type(self.dtype)) + B, T, L, C = hidden_states.shape # T: temporal; L: spatial + hidden_states = hidden_states.view([B, T * L, C]) + + # append cls token + cls_tokens = self.cls_token.expand(B, -1, -1) + hidden_states = torch.cat((cls_tokens, hidden_states), dim=1) + hidden_states = hidden_states + rotary_pos_emb + hidden_states = hidden_states.reshape(B, -1, C) + + for idx, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, residual=None) + + return hidden_states + + model.forward = types.MethodType(forward_wrap, model) + super().__init__(config, model, model_kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + + +class VideoChatFlashQwenLanguageModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: "PreTrainedModel", + model_kwargs: Dict[str, Any] = None, + ): + model.__orig_forward = model.forward + + def forward_wrap( + self, + attention_mask, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + ): + outputs, _ = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + ) + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + output = (logits,) + outputs[1:] + return output + + model.forward = types.MethodType(forward_wrap, model) + super().__init__(config, model, model_kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + + # adopted from https://github.com/huggingface/transformers/blob/v4.57.6/src/transformers/models/llama/modeling_llama.py#L197 class LlamaEagle3Attention(LlamaAttention): """ diff --git a/optimum/exporters/openvino/utils.py 
b/optimum/exporters/openvino/utils.py index af2f1edaba..525b99403f 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -303,6 +303,7 @@ def get_submodels(model): "phi4_multimodal", "llama4", "minicpmo", + "videochat_flash_qwen", ] SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid", "qwen3_next"] diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py index f83b163cdd..c6b6941ae2 100644 --- a/optimum/intel/openvino/configuration.py +++ b/optimum/intel/openvino/configuration.py @@ -428,6 +428,18 @@ class OVQuantizationMethod(str, Enum): "weight_only": True, }, }, + "OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B": { + "quantization_configs": { + "lm_model": { + "bits": 4, + "sym": False, + "group_size": 128, + "ratio": 1.0, + }, + "text_embeddings_model": {"bits": 8, "sym": True, "weight_only": True}, + "vision_embeddings_model": {"bits": 8, "sym": True, "weight_only": True}, + }, + }, "qnguyen3/nanoLLaVA": { "bits": 4, "sym": False, diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index beb7b974eb..2524c76fcb 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -4,11 +4,12 @@ import logging import math import os +import re import warnings from abc import abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import openvino @@ -4802,6 +4803,955 @@ def preprocess_inputs( return inputs +class _OVVideoChatFlashQwenForCausalLM(OVModelForVisualCausalLM): + from transformers import AutoModel + + auto_model_class = AutoModel + additional_parts = ["vision_projection"] + IMAGE_TOKEN_INDEX = -200 + IGNORE_INDEX = -100 + + # Adopted 
from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/vision_tower_builder.py#L181 + def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False): + """ + grid_size: int of the grid height and width + t_size: int of the temporal size + return: + pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + assert embed_dim % 4 == 0 + embed_dim_spatial = embed_dim // 4 * 3 + embed_dim_temporal = embed_dim // 4 + + # spatial + grid_h = torch.arange(grid_size, dtype=torch.float32) + grid_w = torch.arange(grid_size, dtype=torch.float32) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first + grid = torch.stack(grid, dim=0) + + grid = grid.reshape(2, 1, grid_size, grid_size) + pos_embed_spatial = _OVVideoChatFlashQwenForCausalLM.get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + + # temporal + grid_t = torch.arange(t_size, dtype=torch.float32) + pos_embed_temporal = _OVVideoChatFlashQwenForCausalLM.get_1d_sincos_pos_embed_from_grid( + embed_dim_temporal, grid_t + ) + + # concate: [T, H, W] order + pos_embed_temporal = pos_embed_temporal[:, None, :] + pos_embed_temporal = pos_embed_temporal.repeat(1, grid_size**2, 1) # [T, H*W, D // 4] + pos_embed_spatial = pos_embed_spatial[None, :, :] + pos_embed_spatial = pos_embed_spatial.repeat(t_size, 1, 1) # [T, H*W, D // 4 * 3] + + pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1) + pos_embed = pos_embed.reshape(-1, embed_dim) # [T*H*W, D] + + if cls_token: + pos_embed = torch.cat([torch.zeros((1, embed_dim), dtype=pos_embed.dtype), pos_embed], dim=0) + return pos_embed + + # Adopted from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/vision_tower_builder.py#L141 + def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + grid = grid if isinstance(grid, torch.Tensor) else torch.as_tensor(grid, 
dtype=torch.float32) + + # use half of dimensions to encode grid_h + emb_h = _OVVideoChatFlashQwenForCausalLM.get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0] + ) # (H*W, D/2) + emb_w = _OVVideoChatFlashQwenForCausalLM.get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1] + ) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D) + return emb + + # Adopted from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/vision_tower_builder.py#L156 + def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / (10000**omega) # (D/2,) + + pos = pos if isinstance(pos, torch.Tensor) else torch.as_tensor(pos, dtype=torch.float32) + pos = pos.reshape(-1).to(dtype=torch.float32) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb + + def __init__( + self, + language_model: ov.Model, + text_embeddings: ov.Model, + vision_embeddings: ov.Model, + config: PretrainedConfig = None, + device: str = "CPU", + dynamic_shapes: bool = None, + ov_config: Optional[Dict[str, str]] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, + **kwargs, + ): + from torch import nn + + super().__init__( + language_model=language_model, + text_embeddings=text_embeddings, + vision_embeddings=vision_embeddings, + config=config, + device=device, + dynamic_shapes=dynamic_shapes, + ov_config=ov_config, + model_save_dir=model_save_dir, + quantization_config=quantization_config, + **kwargs, + ) + num_frames = getattr(config, 
"mm_local_num_frames", 8) + self.num_attention_heads = 16 + self.patch_size = 14 + self.image_size = 224 + self.grid_size = ( + num_frames, + self.image_size // self.patch_size, + self.image_size // self.patch_size, + ) # (T, H, W) + self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] + self.num_img_patches = self.grid_size[1] * self.grid_size[2] + self.embed_dim = getattr(config, "mm_hidden_size", 1408) + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, self.embed_dim)) + self.img_pos_embed = nn.Parameter(torch.zeros(1, self.num_img_patches + 1, self.embed_dim)) + pos_embed = _OVVideoChatFlashQwenForCausalLM.get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], self.grid_size[1], self.grid_size[0], cls_token=True + ) + self.pos_embed.data.copy_(pos_embed.to(dtype=self.pos_embed.dtype).unsqueeze(0)) + + img_pos_embed = _OVVideoChatFlashQwenForCausalLM.get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], self.grid_size[1], 1, cls_token=True + ) + self.img_pos_embed.data.copy_(img_pos_embed.to(dtype=self.img_pos_embed.dtype).unsqueeze(0)) + + # Adopted from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/mm_projector_builder.py#L6 + def bipartite_soft_matching( + metric: torch.Tensor, + r: int, + ) -> Tuple[Callable, Callable]: + """ + Build balanced ToMe token matching operators for vision token compression. + In this model's vision path, it is the core matching step used by + ``merge_tokens`` to progressively shrink visual token sequences before + ``vision_projection``. This reduces the token count passed into the + language-model side of the multimodal pipeline, improving memory/latency + while keeping high-similarity visual information aggregated. + + This function splits tokens into two interleaved groups (even/odd positions), + computes pairwise similarity between the two groups, and selects the top-``r`` + pairs to merge. 
It returns two closures: + + - ``merge``: merges matched source tokens into destination tokens to reduce + sequence length while preserving information. + - ``unmerge``: restores merged tokens back to the original token layout, + which is useful for shape recovery or downstream alignment + + Args: + metric (`torch.Tensor`): Token features with shape ``[batch, tokens, channels]`` + used to compute matching similarity. + r (`int`): Number of tokens to remove by merging. It is internally capped + at half of available tokens. + + Returns: + `Tuple[Callable, Callable]`: ``(merge, unmerge)`` operators for reversible + token reduction. + """ + protected = 0 + + t = metric.shape[1] + r = min(r, (t - protected) // 2) + + assert r > 0, r + + with torch.no_grad(): + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = metric[..., ::2, :], metric[..., 1::2, :] + scores = a @ b.transpose(-1, -2) + + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) + + def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: + src, dst = x[..., ::2, :], x[..., 1::2, :] + n, t1, c = src.shape + unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c)) + src = src.gather(dim=-2, index=src_idx.expand(n, r, c)) + dst = dst.scatter_add(-2, dst_idx.expand(n, r, c), src) # , reduce=mode) + + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor) -> torch.Tensor: + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + n, _, c = unm.shape + + src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c)) + + out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype) + + out[..., 1::2, :] = dst + out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm) + out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), 
src=src) + + return out + + return merge, unmerge + + # Adopted from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/mm_projector_builder.py#L62 + def merge_wavg(merge: Callable, x: torch.Tensor, size: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Applies the merge function by taking a weighted average based on token size. + Returns the merged tensor and the new token sizes. + """ + if size is None: + size = torch.ones_like(x[..., 0, None]) + + x = merge(x * size, mode="sum") + size = merge(size, mode="sum") + + x = x / size + return x, size + + def get_vision_embeddings(self, images): + # Upstream preprocessing provides BTCHW, but the vision tower expects BCHWT, + # so we permute dimensions before running the visual encoder. + # We then keep patch tokens in [B, T*L, C] (dropping cls later) because + # downstream token merging/projection operates on a flattened token sequence. + T = images.shape[1] + images = images.permute(0, 2, 1, 3, 4) + if T == 1: + pos_embeds = self.img_pos_embed.detach() + else: + pos_embeds = self.pos_embed.detach() + image_embeds = self.vision_embeddings(images, rotary_pos_emb=pos_embeds).last_hidden_state + image_embeds = image_embeds[:, 1:, :] + + videos_features = torch.from_numpy(image_embeds) if isinstance(image_embeds, np.ndarray) else image_embeds + + return videos_features + + # Adopted from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/mm_projector_builder.py#L96 + def merge_tokens(self, x, target_num_token): + r""" + x = torch.randn(10, 2560, c) + x = merge_tokens(x, r_merge_list=[1280]) + """ + size = None + b, p, c = x.shape + current_num_tokens = p + # Number of tokens to merge at each iterative ToMe step until reaching target_num_token. 
+ r_merge_list = [] + assert current_num_tokens > target_num_token, f"{current_num_tokens} should be greater than {target_num_token}" + while current_num_tokens != target_num_token: + if current_num_tokens - target_num_token <= (current_num_tokens // 2): + r_merge_list.append(current_num_tokens - target_num_token) + break + else: + r_merge_list.append(current_num_tokens // 2) + current_num_tokens = current_num_tokens - (current_num_tokens // 2) + + head = self.num_attention_heads + + dim = c // head + for r in r_merge_list: + metric = x.reshape(b, p, head, dim).mean(2) # [b, p, c//head] + merge, _ = _OVVideoChatFlashQwenForCausalLM.bipartite_soft_matching(metric, r) + x, size = _OVVideoChatFlashQwenForCausalLM.merge_wavg(merge, x, size) + _, p, _ = x.shape + + return x + + def get_vision_projection(self, x, compress=False, local_num_frames=-1): + height = width = self.image_size // self.patch_size + assert height * width == x.shape[1] + + if local_num_frames != -1 and local_num_frames != 1: + assert compress is True + if compress: + if local_num_frames != -1: + num_frames = local_num_frames + x = x.reshape(x.shape[0] // local_num_frames, -1, x.shape[-1]) + else: + num_frames = x.shape[0] + x = x.reshape(1, -1, x.shape[-1]) + num_tome_tokens = 16 * num_frames + else: + num_tome_tokens = 64 + + x = self.merge_tokens(x, target_num_token=num_tome_tokens) + x = self.vision_projection(x) + x = torch.from_numpy(x) if isinstance(x, np.ndarray) else x + return x + + # Adopted from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/mm_utils.py#L797 + def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and 
prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors == "pt": + return torch.tensor(input_ids, dtype=torch.long) + return input_ids + + # Adopted from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/vision_tower_builder.py#L681 + + def image_preprocess(images, return_tensors, target_size=None): + from functools import partial, reduce + + from PIL.Image import Image as PILImage + from transformers.image_processing_utils import BatchFeature + from transformers.image_transforms import ( + convert_to_rgb, + normalize, + rescale, + resize, + to_channel_dimension_format, + ) + from transformers.image_utils import ChannelDimension, PILImageResampling, to_numpy_array + + if isinstance(images, PILImage): + images = [images] + else: + # to adapt video data + images = [to_numpy_array(image) for image in images] + assert isinstance(images, list) + + if target_size is None: + target_size = (224, 224) + + data_format = ChannelDimension.FIRST + rescale_factor = 1 / 255 + image_mean = (0.485, 0.456, 0.406) + image_std = (0.229, 0.224, 0.225) + + transforms = [ + convert_to_rgb, + to_numpy_array, + partial(resize, size=target_size, resample=PILImageResampling.BICUBIC, data_format=data_format), + partial(rescale, scale=rescale_factor, data_format=data_format), + partial(normalize, mean=image_mean, std=image_std, data_format=data_format), + partial(to_channel_dimension_format, channel_dim=data_format, input_channel_dim=data_format), + ] + + images = reduce(lambda x, f: [*map(f, x)], transforms, images) + data = {"pixel_values": images} + + return BatchFeature(data=data, tensor_type=return_tensors) + + @staticmethod + def preprocess_inputs( + text: str, + image: Optional["Image"] = None, + processor: Optional[AutoImageProcessor] = None, + tokenizer: 
Optional[PreTrainedTokenizer] = None, + config: Optional[PretrainedConfig] = None, + video: Optional["VideoInput"] = None, + audio: Optional[np.ndarray] = None, + ): + if audio is not None: + raise ValueError("Audio input is not supported") + if tokenizer is None: + raise ValueError("Tokenizer is required.") + image_sizes = [] + frames = [] + results = {} + local_num_frames = getattr(config, "mm_local_num_frames", 4) + + # preprocess text + prompt = f"<image>\n{text}" if (image is not None or video is not None) else text + if getattr(tokenizer, "chat_template", None) is not None: + messages = [{"role": "user", "content": prompt}] + text_prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + else: + text_prompt = prompt + input_ids = _OVVideoChatFlashQwenForCausalLM.tokenizer_image_token( + text_prompt, tokenizer, _OVVideoChatFlashQwenForCausalLM.IMAGE_TOKEN_INDEX, return_tensors="pt" + ).unsqueeze(0) + results["input_ids"] = input_ids + + # preprocess video + if video is not None: + if isinstance(video, np.ndarray): + num_frames = video.shape[0] + image_size = video.shape[1:3] + if num_frames % local_num_frames != 0: + pad_frames = local_num_frames - (num_frames % local_num_frames) + pad = np.repeat(video[-1:], pad_frames, axis=0) + video = np.concatenate([video, pad], axis=0) + elif isinstance(video, list): + num_frames = len(video) + if isinstance(video[0], np.ndarray): + image_size = video[0].shape[:2] + else: + width, height = video[0].size + image_size = (height, width) + if num_frames % local_num_frames != 0: + pad_frames = local_num_frames - (num_frames % local_num_frames) + video = video + [video[-1]] * pad_frames + else: + raise ValueError("Unsupported video type: {}".format(type(video))) + + image_sizes.append(image_size) + if processor is not None: + processed_images = processor(images=video, return_tensors="pt") + else: + processed_images = _OVVideoChatFlashQwenForCausalLM.image_preprocess( + images=video, 
return_tensors="pt" + )["pixel_values"] + frames.append(processed_images) + + # preprocess image + if image is not None: + from PIL.Image import Image as PILImage + + if isinstance(image, PILImage): + width, height = image.size + image_size = (height, width) + else: + image_size = image.shape[:2] + if processor is not None: + image_frame = processor(images=image, return_tensors="pt") + else: + image_frame = _OVVideoChatFlashQwenForCausalLM.image_preprocess(images=image, return_tensors="pt")[ + "pixel_values" + ] + frames.append(image_frame) + image_sizes.append(image_size) + + if len(frames) >= 1: + results["images"] = frames + results["image_sizes"] = image_sizes + + if tokenizer.pad_token_id is None: + if "qwen" in tokenizer.name_or_path.lower(): + logger.info("Setting pad token to bos token for qwen model.") + tokenizer.pad_token_id = 151643 + attention_masks = input_ids.ne(tokenizer.pad_token_id).long() + results["attention_mask"] = attention_masks + + return results + + def encode_video_image(self, images_list, video_idx_in_batch): + # process the video encoder output using image connector + bs = len(images_list) + + concat_images = [] + concat_videos = [] + for idx, image in enumerate(images_list): + if idx in video_idx_in_batch: + concat_videos.append(image) + else: + concat_images.append(image) + # print(concat_videos[0].shape) + has_image = len(concat_images) > 0 + has_video = len(concat_videos) > 0 + + mm_local_num_frames = getattr(self.config, "mm_local_num_frames", -1) + assert mm_local_num_frames != -1 + if has_image: + image_split_sizes = [image.shape[0] for image in concat_images] + concat_images = torch.cat([image.unsqueeze(1) for image in concat_images], dim=0) + # print("input vit image.shape:", concat_images.shape) + images_features = self.get_vision_embeddings(concat_images) # B_i, N, D + images_features = torch.split(images_features, image_split_sizes) + + if has_video: + video_split_sizes = [video.shape[0] // mm_local_num_frames for video in 
concat_videos] + concat_videos = torch.cat( + [ + video.reshape( + video.shape[0] // mm_local_num_frames, + mm_local_num_frames, + video.shape[1], + video.shape[2], + video.shape[3], + ) + for video in concat_videos + ], + dim=0, + ) + # print("input vit video.shape:", concat_videos.shape) + videos_features = self.get_vision_embeddings(concat_videos) # B_v, N, D + videos_features = [ + v.reshape(-1, v.shape[-2] // mm_local_num_frames, v.shape[-1]) + for v in torch.split(videos_features, video_split_sizes) + ] + + all_videos_or_images_features = [] + img_idx = 0 + vid_idx = 0 + + for idx in range(bs): + if idx in video_idx_in_batch: + feat = self.get_vision_projection( + videos_features[vid_idx], compress=True, local_num_frames=mm_local_num_frames + ) + vid_idx += 1 + else: + feat = self.get_vision_projection(images_features[img_idx], compress=False) + img_idx += 1 + # print("video_idx_in_batch:", video_idx_in_batch) + all_videos_or_images_features.append(feat) + + if has_video: + assert vid_idx == len(videos_features), f"vid: {vid_idx} != {len(videos_features)}" + if has_image: + assert img_idx == len(images_features), f"img: {img_idx} != {len(images_features)}" + + return all_videos_or_images_features + + # Adopted from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/mm_utils.py#L502-L537 + def select_best_resolution(original_size, possible_resolutions, max_resolutions, patch_size): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). 
+ """ + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for width, height in possible_resolutions: + if max_resolutions is not None and (width * height != patch_size * patch_size): + if (width * height + patch_size * patch_size) > max_resolutions: + continue + # Calculate the downscaled size to keep the aspect ratio + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + + # Calculate effective and wasted resolutions + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + # print(f"original_size={original_size}, possible_resolutions={possible_resolutions}, max_resolutions={max_resolutions}, best_fit={best_fit}") + if best_fit is None: + raise ValueError(f"Can't find suitable fit in {possible_resolutions} at max:{max_resolutions}") + return best_fit + + # Adopted from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/mm_utils.py#L601-L631 + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size, max_resolutions=None): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (tuple): The size of the input image in the format (width, height). + grid_pinpoints (str): A string representation of a list of possible resolutions. + patch_size (int): The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). 
+ """ + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [ + (i, j) + for i in range(range_start[0], range_end[0] + 1) + for j in range(range_start[1], range_end[1] + 1) + ] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + if isinstance(grid_pinpoints, list): + possible_resolutions = grid_pinpoints + else: + pairs = re.findall(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)", grid_pinpoints) + possible_resolutions = [(int(w), int(h)) for w, h in pairs] + width, height = _OVVideoChatFlashQwenForCausalLM.select_best_resolution( + image_size, possible_resolutions, max_resolutions=max_resolutions, patch_size=patch_size + ) + + return width // patch_size, height // patch_size + + def get_text_embeddings(self, input_ids): + squeeze_batch_dim = False + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + squeeze_batch_dim = True + + text_embed = super().get_text_embeddings(input_ids) + + if squeeze_batch_dim and text_embed.ndim > 0 and text_embed.shape[0] == 1: + text_embed = text_embed[0] + + text_embed = torch.from_numpy(text_embed) if isinstance(text_embed, np.ndarray) else text_embed + return text_embed + + # Adopted from https://huggingface.co/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B/blob/main/modeling_videochat_flash.py#L183 + def get_multimodal_embeddings( + self, + input_ids, + pixel_values=None, + attention_mask=None, + position_ids=None, + modalities=None, + image_sizes=None, + **kwargs, + ): + images = pixel_values + + if images is None: + 
inputs_embeds = self.get_text_embeddings(input_ids) + return inputs_embeds, attention_mask, position_ids + + # rank_print(modalities) + if isinstance(images, list): + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + if modalities is None: + modalities = [] + for image in images: + if image.shape[0] > 1: + modalities.append("video") + else: + modalities.append("image") + + if modalities is None: + modalities = ["image"] + + video_idx_in_batch = [] + for _ in range(len(modalities)): + if modalities[_] == "video": + video_idx_in_batch.append(_) + + images_list = [] + for image in images: + if image.ndim == 4: + images_list.append(image) + else: + images_list.append(image.unsqueeze(0)) + + mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + mm_newline_position = getattr(self.config, "mm_newline_position", "nothing") + + # video backbone, process video with compress + image_features = self.encode_video_image(images_list, video_idx_in_batch=video_idx_in_batch) + + if mm_patch_merge_type == "flat": + image_features = [x.flatten(0, 1) for x in image_features] + elif mm_patch_merge_type.startswith("spatial"): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_idx in video_idx_in_batch: # video operations + frame_feature = image_feature + + if "pad" in mm_patch_merge_type: + if mm_newline_position == "one_token": + frame_feature = frame_feature.flatten(0, 1) + if "unpad" in mm_patch_merge_type: + frame_feature = torch.cat( + (frame_feature, self.model.image_newline[None].to(frame_feature.device)), dim=0 + ) + else: + frame_feature = torch.cat( + (frame_feature, self.model.frame_newline[None].to(frame_feature.device)), dim=0 + ) + elif mm_newline_position == "nothing": + frame_feature = frame_feature.flatten(0, 1) + else: + frame_feature = frame_feature.flatten(0, 1) + + # print(f"final video frame_feature.shape: 
{frame_feature.shape}") + image_feature = frame_feature + + elif image_feature.shape[0] > 1: # multi patches and multi images operations + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + height = width = 8 + assert ( + height * width == base_image_feature.shape[0] + ), f"height:{height}, width: {width}, base_image_feature: {base_image_feature.shape}" + + if "anyres" in image_aspect_ratio: + vision_tower_image_size = 224 + ( + num_patch_width, + num_patch_height, + ) = _OVVideoChatFlashQwenForCausalLM.get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + vision_tower_image_size, + max_resolutions=None, + ) + + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + if "nobase" in mm_patch_merge_type: + pass + else: + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + + else: # single image operations + image_feature = image_feature[0] + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0) + + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. 
+ _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + + input_ids = [ + cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) + ] + + new_input_embeds = [] + cur_image_idx = 0 + + mm_llm_compress = getattr(self.config, "mm_llm_compress", False) + + if mm_llm_compress: + self.language_model.model.llm_compress_type = getattr(self.config, "llm_compress_type", "attention") + self.language_model.model.llm_compress_layer_list = getattr( + self.config, "llm_compress_layer_list", [8, 16, 24] + ) + self.language_model.model.llm_image_token_ratio_list = getattr( + self.config, "llm_image_token_ratio_list", [1.0, 0.5, 0.25, 0.125] + ) + first_image_token_position = [] + text_prompt_lens = [] + else: + self.language_model.model.llm_compress_type = "attention" + self.language_model.model.llm_compress_layer_list = [] + self.language_model.model.llm_image_token_ratio_list = [] + first_image_token_position = [] + text_prompt_lens = [] + + # rank_print("Inserting Images embedding") + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == _OVVideoChatFlashQwenForCausalLM.IMAGE_TOKEN_INDEX).sum() + + if mm_llm_compress: + ####### copy from pdrop, only support single image/video NOTE ################## + # record image position for further dropping + image_index = torch.where(cur_input_ids == _OVVideoChatFlashQwenForCausalLM.IMAGE_TOKEN_INDEX)[ + 0 + ].tolist() + assert len(image_index) == 1, f"Only support single image/video: {image_index}" + if image_index == []: + first_image_token_position.append(-1) + else: + first_image_token_position.append(image_index[0]) + + # record input instruction length in inference mode + if not 
self.training: + if image_index == []: + assert num_images == 0, num_images + else: + assert num_images == 1, f"num_images={num_images}" + text_prompt_lens.append(cur_input_ids.shape[0] - num_images) # consider image placeholder + + ############################################### + + # print(f"num_images={num_images}") + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_text_embeddings(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + cur_image_idx += 1 + continue + + image_token_indices = ( + [-1] + + torch.where(cur_input_ids == _OVVideoChatFlashQwenForCausalLM.IMAGE_TOKEN_INDEX)[0].tolist() + + [cur_input_ids.shape[0]] + ) + cur_input_ids_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + split_sizes = [x.shape[0] for x in cur_input_ids_noim] + cur_input_embeds = self.get_text_embeddings(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + if i < num_images: + try: + cur_image_features = image_features[cur_image_idx] + except IndexError: + print(f"cur_image_idx={cur_image_idx} is not ok") + cur_image_features = image_features[cur_image_idx - 1] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + # import pdb; pdb.set_trace() + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + + new_input_embeds.append(cur_new_input_embeds) + + if mm_llm_compress: + self.language_model.model.first_image_token_position = first_image_token_position + self.language_model.model.text_prompt_lens = text_prompt_lens + 
self.language_model.model.num_image_token_lens = [ + image_feature.shape[0] for image_feature in image_features + ] + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None) + # rank_print("Finishing Inserting") + + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + # print("Prepare pos id") + + for i, cur_new_embed in enumerate(new_input_embeds): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append( + torch.cat( + ( + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + cur_new_embed, + ), + dim=0, + ) + ) + if cur_len > 0: + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + else: + new_input_embeds_padded.append( + torch.cat( + ( + cur_new_embed, + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + ), + dim=0, + ) + ) + if cur_len > 0: + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + 
+ return new_input_embeds, attention_mask, position_ids + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + + model_kwargs.pop("images", None) + model_kwargs.pop("image_sizes", None) + past_len = self.language_model._past_length + attn = model_kwargs.get("attention_mask") + if attn is not None and attn.shape[1] < past_len + 1: + model_kwargs["attention_mask"] = torch.ones( + (attn.shape[0], past_len + 1), + dtype=attn.dtype, + device=attn.device, + ) + + return model_kwargs + + MODEL_TYPE_TO_CLS_MAPPING = { "llava": _OVLlavaForCausalLM, "llava_next": _OVLlavaNextForCausalLM, @@ -4824,4 +5774,5 @@ def preprocess_inputs( "llama4": _OVLlama4ForCausalLM, "qwen3_vl": _OVQwen3VLForCausalLM, "minicpmo": _OVMiniCPMOForCausalLM, + "videochat_flash_qwen": _OVVideoChatFlashQwenForCausalLM, } diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 3741725e5d..dad2ec70b6 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -1472,6 +1472,13 @@ def _quantize_ovbasemodel( ov_model_name, pipeline_quantization_config.default_config ) if config is None: + if immediate_save: + # The submodels being quantized are unloaded after quantization, + # so the skipped submodels should also be unloaded to avoid keeping their IR files open on Windows. + # This can avoid later _merge_move failures caused by locked .bin files. 
+ ov_model = self.model.ov_models[ov_model_name] + if ov_model is not None: + self.model._unload_ov_model(ov_model) continue ov_model = self.model.ov_models[ov_model_name] nncf_dataset = calibration_dataset.get(ov_model_name, None) if calibration_dataset else None diff --git a/setup.py b/setup.py index 29ad551373..3d10455ea5 100644 --- a/setup.py +++ b/setup.py @@ -63,6 +63,9 @@ "vocos", "vector_quantize_pytorch", "openvino-genai", + "av", + "decord", + "imageio", ] QUALITY_REQUIRE = ["black~=23.1", "ruff==0.4.4"] diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index 9519cea1ec..38dc7dedfb 100644 --- a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -99,7 +99,9 @@ class ExportModelTest(unittest.TestCase): SUPPORTED_ARCHITECTURES.update({"cohere2": OVModelForCausalLM}) if is_transformers_version(">=", "4.49"): - SUPPORTED_ARCHITECTURES.update({"zamba2": OVModelForCausalLM}) + SUPPORTED_ARCHITECTURES.update( + {"zamba2": OVModelForCausalLM, "videochat_flash_qwen": OVModelForVisualCausalLM} + ) if is_transformers_version(">=", "4.53.0"): SUPPORTED_ARCHITECTURES.update({"granitemoehybrid": OVModelForCausalLM}) @@ -146,7 +148,7 @@ def _openvino_export( model_class = TasksManager.get_model_class_for_task(task, library=library_name) model = model_class(f"hf_hub:{model_name}", pretrained=True, exportable=True) TasksManager.standardize_model_attributes(model_name, model, library_name=library_name) - elif model_type == "llava": + elif model_type in ["llava", "videochat_flash_qwen"]: model = MODEL_TYPE_TO_CLS_MAPPING[model_type].auto_model_class.from_pretrained( model_name, **loading_kwargs ) diff --git a/tests/openvino/test_seq2seq.py b/tests/openvino/test_seq2seq.py index 73e12b5584..9f32f0aa96 100644 --- a/tests/openvino/test_seq2seq.py +++ b/tests/openvino/test_seq2seq.py @@ -556,10 +556,22 @@ class OVModelForVisualCausalLMIntegrationTest(OVSeq2SeqTestMixin): SUPPORT_VIDEO += ["qwen3_vl"] if is_transformers_version(">=", 
"4.54.0"): + # the layers attribute of DynamicCache is used in videochat_flash_qwen model + SUPPORTED_ARCHITECTURES.append("videochat_flash_qwen") + SUPPORT_VIDEO.append("videochat_flash_qwen") # remote code models differs after transformers v4.54 SUPPORTED_ARCHITECTURES = set(SUPPORTED_ARCHITECTURES) - {"llava-qwen2", "phi3_v", "phi4mm"} - REMOTE_CODE_MODELS = ["internvl_chat", "minicpmv", "minicpmo", "llava-qwen2", "phi3_v", "maira2", "phi4mm"] + REMOTE_CODE_MODELS = [ + "internvl_chat", + "minicpmv", + "minicpmo", + "llava-qwen2", + "phi3_v", + "maira2", + "phi4mm", + "videochat_flash_qwen", + ] IMAGE = Image.open( requests.get( TEST_IMAGE_URL, @@ -600,6 +612,10 @@ def get_transformer_model_class(self, model_arch): from transformers import Qwen2VLForConditionalGeneration return Qwen2VLForConditionalGeneration + if model_arch == "videochat_flash_qwen": + from transformers import AutoModel + + return AutoModel return AutoModelForCausalLM def _check_device_and_request(self, ov_model, expected_device, has_request): @@ -630,11 +646,19 @@ def test_find_untested_architectures(self): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): - def compare_outputs(inputs, ov_model, transformers_model, generation_config): + def compare_outputs(inputs, ov_model, transformers_model, generation_config, has_image=False, has_video=True): transformers_inputs = copy.deepcopy(inputs) + if model_arch == "videochat_flash_qwen": + input_ids = transformers_inputs.pop("input_ids") + transformers_inputs["inputs"] = input_ids + transformers_inputs["modalities"] = [] + if has_video: + transformers_inputs["modalities"].append("video") + if has_image: + transformers_inputs["modalities"].append("image") ov_outputs = ov_model.generate(**inputs, generation_config=generation_config) # original minicpmv, internvl always skip input tokens in generation results, while transformers based approach provide them - if model_arch in ["minicpmv", "minicpmo", 
"internvl_chat"]: + if model_arch in ["minicpmv", "minicpmo", "internvl_chat", "videochat_flash_qwen"]: ov_outputs = ov_outputs[:, inputs["input_ids"].shape[1] :] with torch.no_grad(): transformers_outputs = transformers_model.generate( @@ -699,6 +723,10 @@ def compare_outputs(inputs, ov_model, transformers_model, generation_config): from transformers.cache_utils import DynamicCache transformers_inputs["past_key_values"] = DynamicCache() + if model_arch == "videochat_flash_qwen": + input_ids = transformers_inputs.pop("input_ids") + transformers_inputs["inputs"] = input_ids + transformers_inputs["modalities"] = ["image"] test_device = "AUTO" ov_model.to(test_device) @@ -711,7 +739,7 @@ def compare_outputs(inputs, ov_model, transformers_model, generation_config): self._check_device_and_request(ov_model, test_device, False) # pytorch minicpmv and internvl_chat are not designed to be used via forward - if model_arch not in ["minicpmv", "minicpmo", "internvl_chat"]: + if model_arch not in ["minicpmv", "minicpmo", "internvl_chat", "videochat_flash_qwen"]: set_seed(SEED) ov_outputs = ov_model(**inputs) set_seed(SEED) @@ -723,8 +751,10 @@ def compare_outputs(inputs, ov_model, transformers_model, generation_config): ) ov_model.generation_config.eos_token_id = None - transformers_model.generation_config.eos_token_id = None - transformers_model.generation_config.do_sample = False + # For videochat_flash_qwen, generation_config is None in transformers model, so we need to check it before setting eos_token_id + if transformers_model.generation_config is not None: + transformers_model.generation_config.eos_token_id = None + transformers_model.generation_config.do_sample = False ov_model.config.eos_token_id = None transformers_model.config.eos_token_id = None ov_model.generation_config.do_sample = False @@ -772,7 +802,7 @@ def compare_outputs(inputs, ov_model, transformers_model, generation_config): transformers_outputs = transformers_outputs[1].sequences # original minicpmv, 
internvl always skip input tokens in generation results, while transformers based approach provide them - if model_arch in ["minicpmv", "minicpmo", "internvl_chat"]: + if model_arch in ["minicpmv", "minicpmo", "internvl_chat", "videochat_flash_qwen"]: ov_outputs = ov_outputs[:, inputs["input_ids"].shape[1] :] self.assertTrue( torch.equal(ov_outputs, transformers_outputs), @@ -799,7 +829,7 @@ def compare_outputs(inputs, ov_model, transformers_model, generation_config): # check video+image scenario inputs = ov_model.preprocess_inputs(**preprocessors, text=question, video=input_video, image=image) - compare_outputs(inputs, ov_model, transformers_model, gen_config) + compare_outputs(inputs, ov_model, transformers_model, gen_config, has_image=True) if model_arch in self.SUPPORT_AUDIO: input_audio = self._generate_random_audio_data() @@ -875,8 +905,9 @@ def test_generate_utils(self, model_arch): model_id, export=True, trust_remote_code=trust_remote_code, device=OPENVINO_DEVICE ) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code) - question = "Describe image" preprocessors = self.get_preprocessors(model_arch) + + question = "Describe image" inputs = model.preprocess_inputs(**preprocessors, text=question, image=self.IMAGE.resize((600, 600))) # General case outputs = model.generate(**inputs, max_new_tokens=10) @@ -948,7 +979,7 @@ def get_preprocessors(self, model_arch): model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS ) preprocessors = {"processor": processor, "tokenizer": tokenizer, "config": config} - elif model_arch == "internvl_chat": + elif model_arch in ["internvl_chat", "videochat_flash_qwen"]: tokenizer = AutoTokenizer.from_pretrained( model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS ) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index fe6d584d2f..c340a3949f 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -228,6 +228,7 @@ 
"ltx-video": "optimum-intel-internal-testing/tiny-random-ltx-video", "zamba2": "optimum-intel-internal-testing/tiny-random-zamba2", "qwen3_eagle3": "AngelSlim/Qwen3-1.7B_eagle3", + "videochat_flash_qwen": "optimum-intel-internal-testing/tiny-videochat-flash-qwen", } EAGLE3_MODELS = {"qwen3_eagle3": ("AngelSlim/Qwen3-1.7B_eagle3", "Qwen/Qwen3-1.7B")} @@ -399,6 +400,7 @@ "minicpm3", "deepseek", "qwen3_eagle3", + "videochat_flash_qwen", )