From c4eb6895c2c2b0ddd870832701e49c40a3794157 Mon Sep 17 00:00:00 2001 From: xufang Date: Mon, 9 Feb 2026 13:38:42 +0800 Subject: [PATCH 01/39] support videochat_flash_qwen --- optimum/exporters/openvino/model_configs.py | 188 +++++++++++++ optimum/exporters/openvino/model_patcher.py | 59 +++++ optimum/exporters/openvino/utils.py | 1 + .../openvino/modeling_visual_language.py | 250 ++++++++++++++++++ 4 files changed, 498 insertions(+) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index ca12d455be..6cb23fbadb 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -197,6 +197,8 @@ Qwen3MoeModelPatcher, QwenModelPatcher, SanaTextEncoderModelPatcher, + VideochatFlashQwenLanguageModelPatcher, + VideochatFlashQwenVisionEmbeddingModelPatcher, XverseModelPatcher, Zamba2ModelPatcher, ) @@ -4964,3 +4966,189 @@ class SiglipTextWithProjectionOpenVINOConfig(SiglipTextWithProjectionOnnxConfig) @register_in_tasks_manager("siglip-text", *["feature-extraction"]) class SiglipTextOpenVINOConfig(SiglipTextOnnxConfig): pass + + +class DummyVideoChatFlashQwenInputGenerator(DummyVisionInputGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedVisionConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], + width: int = DEFAULT_DUMMY_SHAPES["width"], + height: int = DEFAULT_DUMMY_SHAPES["height"], + visual_seq_length: int = DEFAULT_DUMMY_SHAPES["visual_seq_length"], + **kwargs, + ): + super().__init__(task, normalized_config, batch_size, num_channels, width, height, visual_seq_length, **kwargs) + if hasattr(normalized_config, "config") and hasattr(normalized_config.config, "mm_local_num_frames"): + self.num_frames = normalized_config.config.mm_local_num_frames + self.height = 224 + self.width = 224 + self.image_size = (self.height, self.width) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "pixel_values": + return self.random_float_tensor( + shape=[ + self.batch_size, + self.num_channels, + self.num_frames, + self.height, + self.width, + ], + framework=framework, + dtype=float_dtype, + ) + + +class DummyVideoChatFlashQwenProjectorInputGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ["input"] + + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + self.task = task + self.batch_size = batch_size + self.hidden_size = normalized_config.hidden_size + self.num_patches = 64 + self.normalized_config = normalized_config + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ): + shape = [self.batch_size, self.num_patches, self.hidden_size] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + + +class VideoChatFlashQWENProjectorOpenVINOConfig(OnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoChatFlashQwenProjectorInputGenerator,) + NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + return {"input": {0: "batch_size", 1: "num_patches", 2: "hidden_size"}} + + +class VideoChatFlashQwenConfigBehavior(str, enum.Enum): + LANGUAGE = "language" + VISION_EMBEDDINGS = "vision_embeddings" + VISION_PROJECTION = "vision_projection" + TEXT_EMBEDDINGS = "text_embeddings" + + +@register_in_tasks_manager("videochat_flash_qwen", *["image-text-to-text"], library_name="transformers") +class VideoChatFlashQwenOpenVINOConfig(BaseVLMOpenVINOConfig): + MIN_TRANSFORMERS_VERSION = "4.42.0" + SUPPORTED_BEHAVIORS = [model_type.value for model_type in VideoChatFlashQwenConfigBehavior] + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoChatFlashQwenInputGenerator,) + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + behavior: VideoChatFlashQwenConfigBehavior = VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS, + preprocessors: Optional[List[Any]] = None, + **kwargs, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + behavior=behavior, + preprocessors=preprocessors, + ) + self._orig_config = config + if self._behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"): + self._config = config.vision_config + self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + if not self._behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS: + return {} + return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "num_frames", 3: "height", 4: "width"}} + + def with_behavior( + self, + behavior: Union[str, VideoChatFlashQwenConfigBehavior], + ): + """ + Creates a config for different behaviour. + + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + """ + if isinstance(behavior, str) and not isinstance(behavior, VideoChatFlashQwenConfigBehavior): + behavior = VideoChatFlashQwenConfigBehavior(behavior) + + if behavior == VideoChatFlashQwenConfigBehavior.VISION_PROJECTION: + export_config = VideoChatFlashQWENProjectorOpenVINOConfig( + self._orig_config, + task="feature-extraction", + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + ) + return export_config + + if behavior == VideoChatFlashQwenConfigBehavior.TEXT_EMBEDDINGS: + return get_vlm_text_embeddings_config( + "qwen2", self._orig_config, self.int_dtype, self.float_dtype + ) + + if behavior == VideoChatFlashQwenConfigBehavior.LANGUAGE: + return get_vlm_text_generation_config( + "qwen2", self._orig_config, self.int_dtype, self.float_dtype + ) + + if behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + + def get_model_for_behavior(self, model, behavior: Union[str, VideoChatFlashQwenConfigBehavior]): + if isinstance(behavior, str) and not isinstance(behavior, VideoChatFlashQwenConfigBehavior): + behavior = VideoChatFlashQwenConfigBehavior(behavior) + + if behavior == VideoChatFlashQwenConfigBehavior.VISION_PROJECTION: + return model.get_model().mm_projector.mlp + + if behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS: + return model.get_vision_tower().vision_tower + + if behavior == VideoChatFlashQwenConfigBehavior.TEXT_EMBEDDINGS: + text_embedding = model.get_input_embeddings() + text_embedding.config = model.config + return text_embedding + + if behavior == VideoChatFlashQwenConfigBehavior.LANGUAGE: + model.model.llm_compress_layer_list = [] + return model.language_model if not hasattr(model, "lm_head") else model + + def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[Dict[str, Any]] = None): + model_kwargs = model_kwargs or {} + if self._behavior == VideoChatFlashQwenConfigBehavior.LANGUAGE: + return VideochatFlashQwenLanguageModelPatcher(self, model, model_kwargs) + + if self._behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS: + return VideochatFlashQwenVisionEmbeddingModelPatcher(self, model, model_kwargs) + + return super().patch_model_for_export(model, model_kwargs) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 56e550858c..163c8136a1 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -7519,3 +7519,62 @@ def __exit__(self, exc_type, exc_value, traceback): afmoe_moe = layer.mlp afmoe_moe.forward = afmoe_moe._orig_forward del afmoe_moe.down_projs, afmoe_moe.gate_projs, afmoe_moe.up_projs + + +class VideochatFlashQwenVisionEmbeddingModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: "PreTrainedModel", + model_kwargs: Dict[str, Any] = None, + ): + model.__orig_forward = model.forward + + def forward_wrap(self, pixel_values): + return self.__orig_forward(x=pixel_values) + + model.forward = types.MethodType(forward_wrap, model) + super().__init__(config, model, model_kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + + +class VideochatFlashQwenLanguageModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: "PreTrainedModel", + model_kwargs: Dict[str, Any] = None, + ): + model.__orig_forward = model.forward + + def forward_wrap( + self, + attention_mask, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + ): + from transformers.cache_utils import DynamicCache + + outputs, labels = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + ) + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + output = (logits,) + outputs[1:] + return output + + model.forward = types.MethodType(forward_wrap, model) + super().__init__(config, model, model_kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py index de92645017..c293783d77 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -302,6 +302,7 @@ def get_submodels(model): "phi4_multimodal", "llama4", "minicpmo", + "videochat_flash_qwen", ] SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid"] diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index d7f1ffe7d2..de9851b53a 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -4386,6 +4386,255 @@ def preprocess_inputs( return inputs +class _OVVideoChatFlashQwenForCausalLM(OVModelForVisualCausalLM): + additional_parts = ["vision_projection"] + + def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs): + if input_ids is not None and input_ids.shape[1] == 1: + return None + image_features = self.vision_embeddings(pixel_values).last_hidden_state + image_features = self.multi_modal_projector(image_features) + return image_features + + def pack_image_features(self, image_features, image_sizes, image_newline=None): + """ + Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. + + Args: + image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`) + List of image feature tensor, each contains all the visual feature of all patches. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + vision_feature_select_strategy (`str`) + The feature selection strategy used to select the vision feature from the vision backbone. + image_newline (`torch.Tensor` of shape `(embed_dim)`) + New line embedding vector. + Returns: + image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`) + feature_lens (`List[int]`) + token length of each image in image_features + """ + from transformers.models.llava_next_video.modeling_llava_next_video import ( + get_anyres_image_grid_shape, + unpad_image, + ) + + new_image_features = [] + feature_lens = [] + vision_feature_select_strategy = self.config.vision_feature_select_strategy + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + + if ( + np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0 + and vision_feature_select_strategy == "default" + ): + logger.warning_once( + "Image feature shape does not line up with the provided patch size. " + "You may be using the `default` vision_feature_select_strategy with a" + " visual encoder that does not have CLS." + ) + + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + if image_newline is not None: + image_feature = torch.cat( + ( + image_feature, + image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device, image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if image_newline is not None: + image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) + new_image_features.append(image_feature) + feature_lens.append(image_feature.size(0)) + image_features = torch.cat(new_image_features, dim=0) + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) + return image_features, feature_lens + + @staticmethod + def preprocess_inputs( + text: str, + image: Optional["Image"] = None, + processor: Optional[AutoImageProcessor] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + config: Optional[PretrainedConfig] = None, + video: Optional["VideoInput"] = None, + audio: Optional[np.ndarray] = None, + ): + if processor is None: + raise ValueError("Processor is required.") + if audio is not None: + raise ValueError("Audio input is not supported") + if getattr(processor, "chat_template", None) is not None: + chat_prompt = [{"role": "user", "content": [{"type": "text", "text": text}]}] + if image is not None: + chat_prompt[0]["content"].append({"type": "image"}) + if video is not None: + chat_prompt[0]["content"].append({"type": "video"}) + prompt = processor.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False) + else: + prompt = text + if image is not None and "" not in prompt: + prompt = "\n" + prompt + if video is not None and "