-
Notifications
You must be signed in to change notification settings - Fork 213
support videochat #1637
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
support videochat #1637
Changes from 30 commits
c4eb689
3748931
64fb029
e748336
4001249
553c49d
97c1226
74a8e5f
a0af467
70056d0
c5d0807
67f33c2
b44b15d
7bb536d
2cf85fd
2ad5818
74dcc9d
80df5c2
f7835a0
82e4c22
3a0e310
190a796
45724c4
77482af
87f5b70
c7245a6
e0b4251
60c9a4b
33ca875
303d875
7379ab3
554820d
2cd0214
e1dba19
0f1907d
8a2811d
9bca64a
ea1a5a0
33b17e8
b67cdf0
b4bdb50
72c94da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -200,6 +200,9 @@ | |||||
| Qwen3VLVisionEmbMergerPatcher, | ||||||
| QwenModelPatcher, | ||||||
| SanaTextEncoderModelPatcher, | ||||||
| VideochatFlashQwenLanguageModelPatcher, | ||||||
| VideochatFlashQwenVisionEmbeddingModelPatcher, | ||||||
| VideochatFlashQwenVisionProjectionModelPatcher, | ||||||
| XverseModelPatcher, | ||||||
| Zamba2ModelPatcher, | ||||||
| ) | ||||||
|
|
@@ -5303,6 +5306,211 @@ class SiglipTextOpenVINOConfig(SiglipTextOnnxConfig): | |||||
| pass | ||||||
|
|
||||||
|
|
||||||
| class DummyVideoChatFlashQwenInputGenerator(DummyVisionInputGenerator): | ||||||
| SUPPORTED_INPUT_NAMES = ("hidden_states", "rotary_pos_emb") | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| task: str, | ||||||
| normalized_config: NormalizedVisionConfig, | ||||||
| batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], | ||||||
| num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], | ||||||
| width: int = DEFAULT_DUMMY_SHAPES["width"], | ||||||
| height: int = DEFAULT_DUMMY_SHAPES["height"], | ||||||
| visual_seq_length: int = DEFAULT_DUMMY_SHAPES["visual_seq_length"], | ||||||
| **kwargs, | ||||||
| ): | ||||||
| super().__init__(task, normalized_config, batch_size, num_channels, width, height, visual_seq_length, **kwargs) | ||||||
| self.num_frames = getattr(normalized_config.config, "mm_local_num_frames", 4) | ||||||
| self.embed_dim = getattr(normalized_config.config, "mm_hidden_size", 1408) | ||||||
| self.height = 224 | ||||||
| self.width = 224 | ||||||
|
Comment on lines
+5326
to
+5327
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please make a comment why they are fixed.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||||||
| self.image_size = (self.height, self.width) | ||||||
| self.patch_size = 14 | ||||||
|
|
||||||
| def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): | ||||||
| if input_name == "hidden_states": | ||||||
| return self.random_float_tensor( | ||||||
| shape=[ | ||||||
| self.batch_size, | ||||||
| self.num_channels, | ||||||
| self.num_frames, | ||||||
| self.height, | ||||||
| self.width, | ||||||
| ], | ||||||
| framework=framework, | ||||||
| dtype=float_dtype, | ||||||
| ) | ||||||
| elif input_name == "rotary_pos_emb": | ||||||
| grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size | ||||||
| grid_t = self.num_frames | ||||||
| return self.random_float_tensor( | ||||||
| [1, 1 + grid_h * grid_t * grid_w, self.embed_dim], framework=framework, dtype=float_dtype | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| class DummyVideoChatFlashQwenProjectorInputGenerator(DummyInputGenerator): | ||||||
| SUPPORTED_INPUT_NAMES = ["hidden_states"] | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| task: str, | ||||||
| normalized_config: NormalizedTextConfig, | ||||||
| batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], | ||||||
| random_batch_size_range: Optional[Tuple[int, int]] = None, | ||||||
| **kwargs, | ||||||
| ): | ||||||
| self.task = task | ||||||
| self.batch_size = batch_size | ||||||
| self.hidden_size = normalized_config.config.mm_hidden_size | ||||||
| self.num_patches = 64 | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please make a comment why it is fixed.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||||||
| self.normalized_config = normalized_config | ||||||
|
|
||||||
| def generate( | ||||||
| self, | ||||||
| input_name: str, | ||||||
| framework: str = "pt", | ||||||
| int_dtype: str = "int64", | ||||||
| float_dtype: str = "fp32", | ||||||
| ): | ||||||
| shape = [self.batch_size, self.num_patches, self.hidden_size] | ||||||
| return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) | ||||||
|
|
||||||
|
|
||||||
| class VideoChatFlashQwenProjectorOpenVINOConfig(OnnxConfig): | ||||||
| DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoChatFlashQwenProjectorInputGenerator,) | ||||||
| NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig | ||||||
|
|
||||||
| @property | ||||||
| def inputs(self) -> Dict[str, Dict[int, str]]: | ||||||
| return {"hidden_states": {0: "batch_size", 1: "num_patches", 2: "hidden_size"}} | ||||||
|
|
||||||
| def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[Dict[str, Any]] = None): | ||||||
| model_kwargs = model_kwargs or {} | ||||||
| return VideochatFlashQwenVisionProjectionModelPatcher(self, model, model_kwargs) | ||||||
|
|
||||||
|
|
||||||
| class VideoChatFlashQwenConfigBehavior(str, enum.Enum): | ||||||
| LANGUAGE = "language" | ||||||
| VISION_EMBEDDINGS = "vision_embeddings" | ||||||
| VISION_PROJECTION = "vision_projection" | ||||||
| TEXT_EMBEDDINGS = "text_embeddings" | ||||||
|
|
||||||
|
|
||||||
| @register_in_tasks_manager("videochat_flash_qwen", *["image-text-to-text"], library_name="transformers") | ||||||
| class VideoChatFlashQwenOpenVINOConfig(BaseVLMOpenVINOConfig): | ||||||
| MIN_TRANSFORMERS_VERSION = "4.45.0" | ||||||
|
||||||
| MIN_TRANSFORMERS_VERSION = "4.45.0" | |
| MIN_TRANSFORMERS_VERSION = "4.49.0" |
let use this limitation and remove below check with exception
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
strange that num_channels is dynamic, please double-check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
num_channels has been removed.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7640,6 +7640,138 @@ def __exit__(self, exc_type, exc_value, traceback): | |
| del afmoe_moe.down_projs, afmoe_moe.gate_projs, afmoe_moe.up_projs | ||
|
|
||
|
|
||
| class VideochatFlashQwenVisionEmbeddingModelPatcher(ModelPatcher): | ||
| def __init__( | ||
| self, | ||
| config: "OnnxConfig", | ||
| model: "PreTrainedModel", | ||
| model_kwargs: Dict[str, Any] = None, | ||
| ): | ||
| model.__orig_forward = model.forward | ||
|
|
||
| def forward_wrap(self, hidden_states, rotary_pos_emb=None, mask=None, use_image=False): | ||
| hidden_states = self.patch_embed(hidden_states.type(self.dtype)) | ||
| B, T, L, C = hidden_states.shape # T: temporal; L: spatial | ||
| hidden_states = hidden_states.view([B, T * L, C]) | ||
|
|
||
| # append cls token | ||
| cls_tokens = self.cls_token.expand(B, -1, -1) | ||
| hidden_states = torch.cat((cls_tokens, hidden_states), dim=1) | ||
|
|
||
| # add pos_embed | ||
| if self.sep_pos_embed: | ||
| raise NotImplementedError | ||
|
||
| else: | ||
| if use_image: | ||
| if self.sep_image_video_pos_embed: | ||
| rotary_pos_emb = self.img_pos_embed | ||
| else: | ||
| # (1, num_img_patches + 1, embed_dim) | ||
| cls_pos_embed = self.pos_embed[:, 0:1, :] | ||
|
|
||
| img_pos_embed = ( | ||
| self.pos_embed[:, 1:, :] | ||
| .view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim) | ||
| .mean(dim=1) | ||
| ) | ||
|
|
||
| rotary_pos_emb = torch.cat([cls_pos_embed, img_pos_embed], dim=1) | ||
| else: | ||
| if rotary_pos_emb is None: | ||
| rotary_pos_emb = self.pos_embed | ||
|
|
||
| hidden_states = hidden_states + rotary_pos_emb | ||
|
|
||
| # mask tokens, ~mask means visible | ||
| if mask is not None: | ||
| hidden_states = hidden_states[~mask].reshape(B, -1, C) | ||
| else: | ||
| hidden_states = hidden_states.reshape(B, -1, C) | ||
|
|
||
| residual = None | ||
|
|
||
| for idx, blk in enumerate(self.blocks): | ||
| if isinstance(hidden_states, tuple) and len(hidden_states) == 2: | ||
| hidden_states, residual = hidden_states | ||
| hidden_states = blk(hidden_states, residual=residual) | ||
|
|
||
| if isinstance(hidden_states, tuple) and len(hidden_states) == 2: | ||
| hidden_states, residual = hidden_states | ||
| if residual is not None: | ||
|
||
| hidden_states = hidden_states + residual | ||
|
|
||
| x_vis = hidden_states | ||
| if self.x_vis_only: | ||
| return x_vis | ||
| else: | ||
| x_pool_vis = self.clip_projector(x_vis) | ||
| return x_vis, x_pool_vis, None, None | ||
|
|
||
| model.forward = types.MethodType(forward_wrap, model) | ||
| super().__init__(config, model, model_kwargs) | ||
|
|
||
| def __exit__(self, exc_type, exc_value, traceback): | ||
| super().__exit__(exc_type, exc_value, traceback) | ||
| self._model.forward = self._model.__orig_forward | ||
|
|
||
|
|
||
| class VideochatFlashQwenVisionProjectionModelPatcher(ModelPatcher): | ||
| def __init__( | ||
| self, | ||
| config: "OnnxConfig", | ||
| model: "PreTrainedModel", | ||
| model_kwargs: Dict[str, Any] = None, | ||
| ): | ||
| model.__orig_forward = model.forward | ||
|
||
|
|
||
| def forward_wrap(self, hidden_states): | ||
| return self.__orig_forward(input=hidden_states) | ||
|
|
||
| model.forward = types.MethodType(forward_wrap, model) | ||
| super().__init__(config, model, model_kwargs) | ||
|
|
||
| def __exit__(self, exc_type, exc_value, traceback): | ||
| super().__exit__(exc_type, exc_value, traceback) | ||
| self._model.forward = self._model.__orig_forward | ||
|
|
||
|
|
||
| class VideochatFlashQwenLanguageModelPatcher(ModelPatcher): | ||
| def __init__( | ||
| self, | ||
| config: "OnnxConfig", | ||
| model: "PreTrainedModel", | ||
| model_kwargs: Dict[str, Any] = None, | ||
| ): | ||
| model.__orig_forward = model.forward | ||
|
|
||
| def forward_wrap( | ||
| self, | ||
| attention_mask, | ||
| position_ids=None, | ||
| past_key_values=None, | ||
| inputs_embeds=None, | ||
| ): | ||
| outputs, _ = self.model( | ||
| input_ids=None, | ||
| attention_mask=attention_mask, | ||
| position_ids=position_ids, | ||
| past_key_values=past_key_values, | ||
| inputs_embeds=inputs_embeds, | ||
| ) | ||
| hidden_states = outputs[0] | ||
| logits = self.lm_head(hidden_states) | ||
| logits = logits.float() | ||
| output = (logits,) + outputs[1:] | ||
| return output | ||
|
|
||
| model.forward = types.MethodType(forward_wrap, model) | ||
| super().__init__(config, model, model_kwargs) | ||
|
|
||
| def __exit__(self, exc_type, exc_value, traceback): | ||
| super().__exit__(exc_type, exc_value, traceback) | ||
| self._model.forward = self._model.__orig_forward | ||
|
|
||
|
|
||
| # adopted from https://github.com/huggingface/transformers/blob/v4.57.6/src/transformers/models/llama/modeling_llama.py#L197 | ||
| class LlamaEagle3Attention(LlamaAttention): | ||
| """ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for consistency
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.