Skip to content
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
c4eb689
support videochat_flash_qwen
xufang-lisa Feb 9, 2026
3748931
fix error
xufang-lisa Feb 10, 2026
64fb029
remove unused function
xufang-lisa Mar 5, 2026
e748336
fix hidden_size for vision projection
xufang-lisa Mar 6, 2026
4001249
Merge branch 'main' into xufang/support_videochat
xufang-lisa Mar 6, 2026
553c49d
add preprocess_inputs
xufang-lisa Mar 12, 2026
97c1226
set default quantization config for videochat model
xufang-lisa Mar 14, 2026
74a8e5f
add rotary_pos_embed to vision_embedding
xufang-lisa Mar 14, 2026
a0af467
update vision projection input name
xufang-lisa Mar 14, 2026
70056d0
use mm_hidden_size as embed_dim
xufang-lisa Mar 16, 2026
c5d0807
Merge branch 'main' into xufang/support_videochat
xufang-lisa Mar 17, 2026
67f33c2
add check for videochat
xufang-lisa Mar 17, 2026
b44b15d
Add pipeline for VideoChat
xufang-lisa Mar 18, 2026
7bb536d
support text only
xufang-lisa Mar 18, 2026
2cf85fd
remove unused code
xufang-lisa Mar 18, 2026
2ad5818
add videochat test
xufang-lisa Mar 18, 2026
74dcc9d
add test dependencies
xufang-lisa Mar 19, 2026
80df5c2
fix style check issue
xufang-lisa Mar 19, 2026
f7835a0
fix style check issue
xufang-lisa Mar 23, 2026
82e4c22
Apply suggestions from code review
xufang-lisa Mar 24, 2026
3a0e310
apply code review comments
xufang-lisa Mar 24, 2026
190a796
fix code style
xufang-lisa Mar 24, 2026
45724c4
fix fail in 4.45
xufang-lisa Mar 24, 2026
77482af
update _DEFAULT_4BIT_WQ_CONFIGS
xufang-lisa Mar 24, 2026
87f5b70
add test to compare inference results with transformers
xufang-lisa Mar 25, 2026
c7245a6
apply comments
xufang-lisa Mar 25, 2026
e0b4251
fix test fail in transformers 4.45
xufang-lisa Mar 26, 2026
60c9a4b
pad frame number to a multiple of 4
xufang-lisa Mar 26, 2026
33ca875
fix file locking issue on windows
xufang-lisa Mar 27, 2026
303d875
apply comments
xufang-lisa Mar 31, 2026
7379ab3
update comments
xufang-lisa Mar 31, 2026
554820d
Merge branch 'main' into xufang/support_videochat
xufang-lisa Mar 31, 2026
2cd0214
fix code style
xufang-lisa Apr 1, 2026
e1dba19
add default image_preprocess
xufang-lisa Apr 3, 2026
0f1907d
apply comments
xufang-lisa Apr 3, 2026
8a2811d
fix code style
xufang-lisa Apr 3, 2026
9bca64a
update tests
xufang-lisa Apr 3, 2026
ea1a5a0
update tests
xufang-lisa Apr 3, 2026
33b17e8
remove NotImplemented exceptions
xufang-lisa Apr 3, 2026
b67cdf0
fix code style
xufang-lisa Apr 3, 2026
b4bdb50
test videochat export when transformers>=4.49
xufang-lisa Apr 4, 2026
72c94da
fix code style
xufang-lisa Apr 4, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,23 @@ def run(self):
get_default_quantization_config,
)
from ...intel.openvino.utils import TemporaryDirectory
from ...intel.utils.import_utils import is_nncf_available
from ...intel.utils.import_utils import is_nncf_available, is_transformers_version
from ...intel.utils.modeling_utils import _infer_library_from_model_name_or_path
import os

is_local = os.path.isdir(self.args.model)
if (
"OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B" in self.args.model
and not is_local
and (is_transformers_version(">=", "4.49"))
):
raise ValueError(
"The model OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B in hugging face "
"contains custom code and requires transformers version prior to 4.49. "
"It is recommended to install transformers version 4.48 in your environment or download "
"https://modelscope.cn/models/OpenGVLab/VideoChat-Flash-Qwen2_5-7B_InternVideo2-1B "
"to your local path and use local path to convert."
)

if self.args.library is None:
# TODO: add revision, subfolder and token to args
Expand Down
3 changes: 2 additions & 1 deletion optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ def export_pytorch(
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)
if hasattr(model, "config"):
setattr(model.config, override_config_key, override_config_value)

if input_shapes is None:
input_shapes = {} # will use the defaults from DEFAULT_DUMMY_SHAPES
Expand Down
202 changes: 202 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@
Qwen3VLVisionEmbMergerPatcher,
QwenModelPatcher,
SanaTextEncoderModelPatcher,
VideochatFlashQwenLanguageModelPatcher,
VideochatFlashQwenVisionEmbeddingModelPatcher,
VideochatFlashQwenVisionProjectionModelPatcher,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
VideochatFlashQwenLanguageModelPatcher,
VideochatFlashQwenVisionEmbeddingModelPatcher,
VideochatFlashQwenVisionProjectionModelPatcher,
VideoChatFlashQwenLanguageModelPatcher,
VideoChatFlashQwenVisionEmbeddingModelPatcher,
VideoChatFlashQwenVisionProjectionModelPatcher,

for consistency

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

XverseModelPatcher,
Zamba2ModelPatcher,
)
Expand Down Expand Up @@ -5303,6 +5306,205 @@ class SiglipTextOpenVINOConfig(SiglipTextOnnxConfig):
pass


class DummyVideoChatFlashQwenInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("hidden_states", "rotary_pos_emb")

def __init__(
self,
task: str,
normalized_config: NormalizedVisionConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = DEFAULT_DUMMY_SHAPES["width"],
height: int = DEFAULT_DUMMY_SHAPES["height"],
visual_seq_length: int = DEFAULT_DUMMY_SHAPES["visual_seq_length"],
**kwargs,
):
super().__init__(task, normalized_config, batch_size, num_channels, width, height, visual_seq_length, **kwargs)
if hasattr(normalized_config, "config") and hasattr(normalized_config.config, "mm_local_num_frames"):
self.num_frames = normalized_config.config.mm_local_num_frames
self.height = 224
self.width = 224
Comment on lines +5326 to +5327
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please make a comment why they are fixed

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

self.image_size = (self.height, self.width)
self.patch_size = 14
self.embed_dim = normalized_config.config.mm_hidden_size

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
    """Generate a dummy input tensor for the vision-embeddings export.

    ``hidden_states`` is a random video clip shaped
    (batch, channels, frames, height, width); ``rotary_pos_emb`` is a random
    positional-embedding tensor with one leading cls token plus one token per
    spatio-temporal patch, in ``embed_dim`` channels.
    """
    if input_name == "hidden_states":
        return self.random_float_tensor(
            shape=[
                self.batch_size,
                self.num_channels,
                self.num_frames,
                self.height,
                self.width,
            ],
            framework=framework,
            dtype=float_dtype,
        )
    elif input_name == "rotary_pos_emb":
        # one token per (frame, patch-row, patch-col) plus the leading cls token
        grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size
        grid_t = self.num_frames
        return self.random_float_tensor(
            [1, 1 + grid_h * grid_t * grid_w, self.embed_dim], framework=framework, dtype=float_dtype
        )
    # NOTE(review): any other input name falls through and implicitly returns None.


class DummyVideoChatFlashQwenProjectorInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ["hidden_states"]

def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
self.task = task
self.batch_size = batch_size
self.hidden_size = normalized_config.mm_hidden_size
self.num_patches = 64
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please make a comment why it is fixed.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

self.normalized_config = normalized_config

def generate(
    self,
    input_name: str,
    framework: str = "pt",
    int_dtype: str = "int64",
    float_dtype: str = "fp32",
):
    """Generate dummy projector input: (batch, num_patches, mm_hidden_size) features."""
    # the same tensor is returned for any input_name; the projector takes a
    # single input ("hidden_states")
    shape = [self.batch_size, self.num_patches, self.hidden_size]
    return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)


class VideoChatFlashQWENProjectorOpenVINOConfig(OnnxConfig):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class VideoChatFlashQWENProjectorOpenVINOConfig(OnnxConfig):
class VideoChatFlashQwenProjectorOpenVINOConfig(OnnxConfig):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoChatFlashQwenProjectorInputGenerator,)
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
    # all three axes of the projected visual features are dynamic at runtime
    return {"hidden_states": {0: "batch_size", 1: "num_patches", 2: "hidden_size"}}

def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[Dict[str, Any]] = None):
    # wrap forward so the projector accepts the exported input name "hidden_states"
    model_kwargs = model_kwargs or {}
    return VideochatFlashQwenVisionProjectionModelPatcher(self, model, model_kwargs)


class VideoChatFlashQwenConfigBehavior(str, enum.Enum):
    """Submodel selector for exporting VideoChat-Flash-Qwen as separate OpenVINO models.

    Each member names one independently exported component of the VLM
    (see ``get_model_for_behavior`` for the matching submodule).
    """

    LANGUAGE = "language"  # the text-generation decoder
    VISION_EMBEDDINGS = "vision_embeddings"  # the vision tower
    VISION_PROJECTION = "vision_projection"  # MLP projecting vision features to the LLM space
    TEXT_EMBEDDINGS = "text_embeddings"  # the language model's input embedding lookup


@register_in_tasks_manager("videochat_flash_qwen", *["image-text-to-text"], library_name="transformers")
class VideoChatFlashQwenOpenVINOConfig(BaseVLMOpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "4.42.0"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Model from HF: Min?, Max 4.48.x
Model from modelscope: Min 4.46?, Max 4.57.x

Suggested change
MIN_TRANSFORMERS_VERSION = "4.42.0"
MIN_TRANSFORMERS_VERSION = "For both"
MAX_TRANSFORMERS_VERSION = "4.57.99"

Throw exception in constructor

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update MIN_TRANSFORMERS_VERSION and MAX_TRANSFORMERS_VERSION.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add exception in constructor.

SUPPORTED_BEHAVIORS = [model_type.value for model_type in VideoChatFlashQwenConfigBehavior]
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoChatFlashQwenInputGenerator,)

def __init__(
    self,
    config: "PretrainedConfig",
    task: str = "feature-extraction",
    int_dtype: str = "int64",
    float_dtype: str = "fp32",
    behavior: VideoChatFlashQwenConfigBehavior = VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS,
    preprocessors: Optional[List[Any]] = None,
    **kwargs,
):
    """Export config for VideoChat-Flash-Qwen; defaults to the vision-embeddings submodel.

    NOTE(review): ``**kwargs`` is accepted but not forwarded to the parent
    constructor — confirm this is intentional.
    """
    super().__init__(
        config=config,
        task=task,
        int_dtype=int_dtype,
        float_dtype=float_dtype,
        behavior=behavior,
        preprocessors=preprocessors,
    )
    # keep the full multimodal config so with_behavior() can build sibling configs
    self._orig_config = config
    if self._behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
        # the vision tower is described by the nested vision_config
        self._config = config.vision_config
        self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if not self._behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS:
return {}
return {
"hidden_states": {0: "batch_size", 1: "num_channels", 2: "num_frames", 3: "height", 4: "width"},
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strange that num_channels is dynamic, please double-check

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_channels has been removed.

"rotary_pos_emb": {0: "batch_size", 1: "num_tokens", 2: "hidden_size"},
}

def with_behavior(
    self,
    behavior: Union[str, VideoChatFlashQwenConfigBehavior],
):
    """
    Creates a config for different behaviour.

    Args:
        behavior ([`ConfigBehavior`]):
            The behavior to use for the new instance.

    Returns:
        An export config for the requested submodel.
    """
    if isinstance(behavior, str) and not isinstance(behavior, VideoChatFlashQwenConfigBehavior):
        behavior = VideoChatFlashQwenConfigBehavior(behavior)

    if behavior == VideoChatFlashQwenConfigBehavior.VISION_PROJECTION:
        export_config = VideoChatFlashQWENProjectorOpenVINOConfig(
            self._orig_config,
            task="feature-extraction",
            int_dtype=self.int_dtype,
            float_dtype=self.float_dtype,
        )
        return export_config

    if behavior == VideoChatFlashQwenConfigBehavior.TEXT_EMBEDDINGS:
        # the text side is handled by the generic qwen2 export configs
        return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)

    if behavior == VideoChatFlashQwenConfigBehavior.LANGUAGE:
        return get_vlm_text_generation_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)

    if behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS:
        return self.__class__(
            self._orig_config,
            task=self.task,
            int_dtype=self.int_dtype,
            float_dtype=self.float_dtype,
            behavior=behavior,
            preprocessors=self._preprocessors,
        )
    # NOTE(review): an unrecognized behavior falls through and returns None.

def get_model_for_behavior(self, model, behavior: Union[str, VideoChatFlashQwenConfigBehavior]):
    """Return the submodule of the composite model matching *behavior*."""
    if isinstance(behavior, str) and not isinstance(behavior, VideoChatFlashQwenConfigBehavior):
        behavior = VideoChatFlashQwenConfigBehavior(behavior)

    if behavior == VideoChatFlashQwenConfigBehavior.VISION_PROJECTION:
        return model.get_model().mm_projector.mlp

    if behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS:
        return model.get_vision_tower().vision_tower

    if behavior == VideoChatFlashQwenConfigBehavior.TEXT_EMBEDDINGS:
        text_embedding = model.get_input_embeddings()
        # attach the top-level config so downstream export code can read it
        text_embedding.config = model.config
        return text_embedding

    if behavior == VideoChatFlashQwenConfigBehavior.LANGUAGE:
        # NOTE(review): presumably clears token-compression layers so the
        # decoder traces as a plain Qwen2 LM — confirm against the model code
        model.model.llm_compress_layer_list = []
        return model.language_model if not hasattr(model, "lm_head") else model

def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[Dict[str, Any]] = None):
    """Select the patcher that rewrites ``forward`` for the current behavior."""
    model_kwargs = model_kwargs or {}
    if self._behavior == VideoChatFlashQwenConfigBehavior.LANGUAGE:
        return VideochatFlashQwenLanguageModelPatcher(self, model, model_kwargs)

    if self._behavior == VideoChatFlashQwenConfigBehavior.VISION_EMBEDDINGS:
        return VideochatFlashQwenVisionEmbeddingModelPatcher(self, model, model_kwargs)

    # remaining behaviors use the default patching of the base class
    return super().patch_model_for_export(model, model_kwargs)


@register_in_tasks_manager(
"hunyuan_v1_dense",
*[
Expand Down
134 changes: 134 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7640,6 +7640,140 @@ def __exit__(self, exc_type, exc_value, traceback):
del afmoe_moe.down_projs, afmoe_moe.gate_projs, afmoe_moe.up_projs


class VideochatFlashQwenVisionEmbeddingModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: "PreTrainedModel",
model_kwargs: Dict[str, Any] = None,
):
model.__orig_forward = model.forward

def forward_wrap(self, hidden_states, rotary_pos_emb=None, mask=None, use_image=False):
hidden_states = self.patch_embed(hidden_states.type(self.dtype))
B, T, L, C = hidden_states.shape # T: temporal; L: spatial
hidden_states = hidden_states.view([B, T * L, C])

# append cls token
cls_tokens = self.cls_token.expand(B, -1, -1)
hidden_states = torch.cat((cls_tokens, hidden_states), dim=1)

# add pos_embed
if self.sep_pos_embed:
raise NotImplementedError
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure that this is needed. Please clean code in this patcher.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

else:
if use_image:
if self.sep_image_video_pos_embed:
rotary_pos_emb = self.img_pos_embed
else:
# (1, num_img_patches + 1, embed_dim)
cls_pos_embed = self.pos_embed[:, 0:1, :]

img_pos_embed = (
self.pos_embed[:, 1:, :]
.view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim)
.mean(dim=1)
)

rotary_pos_emb = torch.cat([cls_pos_embed, img_pos_embed], dim=1)
else:
if rotary_pos_emb is None:
rotary_pos_emb = self.pos_embed

hidden_states = hidden_states + rotary_pos_emb

# mask tokens, ~mask means visible
if mask is not None:
hidden_states = hidden_states[~mask].reshape(B, -1, C)
else:
hidden_states = hidden_states.reshape(B, -1, C)

residual = None

for idx, blk in enumerate(self.blocks):
if isinstance(hidden_states, tuple) and len(hidden_states) == 2:
hidden_states, residual = hidden_states
hidden_states = blk(hidden_states, residual=residual)

if isinstance(hidden_states, tuple) and len(hidden_states) == 2:
hidden_states, residual = hidden_states
if residual is not None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this check? Is any model were residial is None

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

residual is always None and has been removed. The type of hidden_states is tensor not tuple and this check has also been removed.

hidden_states = hidden_states + residual

x_vis = hidden_states
if self.x_vis_only:
return x_vis
else:
x_pool_vis = self.clip_projector(x_vis)
return x_vis, x_pool_vis, None, None

model.forward = types.MethodType(forward_wrap, model)
super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
    # restore the original forward saved in __init__
    super().__exit__(exc_type, exc_value, traceback)
    self._model.forward = self._model.__orig_forward


class VideochatFlashQwenVisionProjectionModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: "PreTrainedModel",
model_kwargs: Dict[str, Any] = None,
):
model.__orig_forward = model.forward
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need this patching? You just need to use original names and patching is not needed

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove this patch

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


def forward_wrap(self, hidden_states):
return self.__orig_forward(input=hidden_states)

model.forward = types.MethodType(forward_wrap, model)
super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
    # restore the original forward saved in __init__
    super().__exit__(exc_type, exc_value, traceback)
    self._model.forward = self._model.__orig_forward


class VideochatFlashQwenLanguageModelPatcher(ModelPatcher):
    """Patch the VideoChat-Flash-Qwen language model for OpenVINO export.

    Replaces ``forward`` with a wrapper that drives the decoder with
    ``inputs_embeds`` only (``input_ids`` stays ``None`` since token
    embeddings come from the separately exported text-embeddings submodel)
    and applies ``lm_head`` to produce logits.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: "PreTrainedModel",
        model_kwargs: Dict[str, Any] = None,
    ):
        model.__orig_forward = model.forward

        def forward_wrap(
            self,
            attention_mask,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=None,
        ):
            # fix: removed an unused `from transformers.cache_utils import
            # DynamicCache` that was imported here but never referenced
            # NOTE(review): the inner model returns an (outputs, labels) pair;
            # labels are discarded since export only needs decoder outputs.
            outputs, labels = self.model(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
            )
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
            # cast to fp32 so logit precision is independent of the model dtype
            logits = logits.float()
            output = (logits,) + outputs[1:]
            return output

        model.forward = types.MethodType(forward_wrap, model)
        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        # restore the original forward saved in __init__
        super().__exit__(exc_type, exc_value, traceback)
        self._model.forward = self._model.__orig_forward


# adopted from https://github.com/huggingface/transformers/blob/v4.57.6/src/transformers/models/llama/modeling_llama.py#L197
class LlamaEagle3Attention(LlamaAttention):
"""
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def get_submodels(model):
"phi4_multimodal",
"llama4",
"minicpmo",
"videochat_flash_qwen",
]

SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid", "qwen3_next"]
Expand Down
Loading
Loading