Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,18 @@ def main_export(
config.audio_processor["config"]["activation_checkpointing"] = ""
config._attn_implementation = "sdpa"
loading_kwargs["config"] = config

# Handle FP8 quantized models (e.g. Ministral-3B FP8) by dequantizing to BF16.
# FP8 checkpoints cannot be exported as-is, so ask transformers to dequantize
# on load; if that is impossible, strip the quantization config so loading
# does not error out.
quant_cfg = getattr(config, "quantization_config", None)
if quant_cfg is not None and getattr(quant_cfg, "quant_method", None) == "fp8":
    try:
        # Only available in recent transformers releases.
        from transformers import FineGrainedFP8Config

        loading_kwargs["quantization_config"] = FineGrainedFP8Config(dequantize=True)
    except Exception:
        # Broad catch is deliberate (best-effort): `(ImportError, Exception)`
        # was redundant since Exception already covers ImportError. Any
        # failure falls back to removing quantization to avoid load errors.
        config.quantization_config = None
        loading_kwargs["config"] = config
# there are some difference between remote and in library representation of past key values for some models,
# for avoiding confusion we disable remote code for them
if (
Expand Down
115 changes: 115 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@
MiniCPMModelPatcher,
MiniCPMVImageEmbeddingsModelPatcher,
MiniCPMVResamplerModelPatcher,
Mistral3ImageEmbeddingModelPatcher,
Mistral3LanguageModelPatcher,
MistralModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
Expand Down Expand Up @@ -288,6 +290,22 @@ def init_model_configs():
"AutoModelForImageTextToText",
)

TasksManager._CUSTOM_CLASSES[("pt", "mistral3", "image-text-to-text")] = (
"transformers",
"Mistral3ForConditionalGeneration",
)

# Register "ministral3" text config type so Mistral3Config can instantiate its text sub-config
try:
from transformers.models.auto.configuration_auto import CONFIG_MAPPING

if "ministral3" not in CONFIG_MAPPING:
from transformers.models.ministral.configuration_ministral import MinistralConfig

CONFIG_MAPPING.register("ministral3", MinistralConfig, exist_ok=True)
except Exception:
pass

if is_diffusers_available() and "fill" not in TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS:
TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["fill"] = "FluxFillPipeline"
TasksManager._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS["fill"] = {"flux": "FluxFillPipeline"}
Expand Down Expand Up @@ -4547,6 +4565,103 @@ def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[
return Llama4ImageEmbeddingsModelPatcher(self, model, model_kwargs)


@register_in_tasks_manager("mistral3", *["image-text-to-text"], library_name="transformers")
class Mistral3OpenVINOConfig(BaseVLMOpenVINOConfig):
    """OpenVINO export config for Mistral3 VLMs (Pixtral vision tower + Ministral LM).

    Splits the model into the standard VLM submodels (vision embeddings,
    text embeddings, language model) via ``with_behavior``.
    """

    MIN_TRANSFORMERS_VERSION = "4.50.0"

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        behavior: VLMConfigBehavior = VLMConfigBehavior.VISION_EMBEDDINGS,
        preprocessors: Optional[List[Any]] = None,
        **kwargs,
    ):
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
        )
        # Fix: `behavior` was accepted but never stored, so instances silently
        # kept the base-class default regardless of the argument passed (e.g.
        # from the `with_behavior` re-construction below).
        self._behavior = behavior
        self._orig_config = config
        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
            # Vision submodel is exported against the vision sub-config.
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)

    def with_behavior(
        self,
        behavior: Union[str, VLMConfigBehavior],
    ):
        """Return an export config for the requested submodel behavior."""
        if isinstance(behavior, str) and not isinstance(behavior, VLMConfigBehavior):
            behavior = VLMConfigBehavior(behavior)

        if behavior == VLMConfigBehavior.TEXT_EMBEDDINGS:
            model_type = self._orig_config.text_config.model_type
            # "ministral"-style text configs may not be registered; fall back
            # to the architecturally-compatible "mistral" config.
            if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
                model_type = "mistral"
            return get_vlm_text_embeddings_config(
                model_type, self._orig_config.text_config, self.int_dtype, self.float_dtype
            )

        if behavior == VLMConfigBehavior.LANGUAGE:
            model_type = self._orig_config.text_config.model_type
            if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
                model_type = "mistral"
            return get_vlm_text_generation_config(
                model_type, self._orig_config.text_config, self.int_dtype, self.float_dtype,
                model_patcher=Mistral3LanguageModelPatcher,
            )

        if behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
            return self.__class__(
                self._orig_config,
                task=self.task,
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
                behavior=behavior,
                preprocessors=self._preprocessors,
            )

        # Previously fell through returning None; fail loudly instead.
        raise ValueError(f"Unsupported behavior: {behavior}")

    def get_model_for_behavior(self, model, behavior: Union[str, VLMConfigBehavior]):
        """Return the (sub)module of `model` matching the requested behavior."""
        if isinstance(behavior, str) and not isinstance(behavior, VLMConfigBehavior):
            behavior = VLMConfigBehavior(behavior)

        # Vision and language exports both trace the full wrapped model.
        if behavior in (VLMConfigBehavior.LANGUAGE, VLMConfigBehavior.VISION_EMBEDDINGS):
            return model

        if behavior == VLMConfigBehavior.TEXT_EMBEDDINGS:
            # Module layout differs across transformers versions.
            if hasattr(model, "model") and hasattr(model.model, "language_model"):
                text_embedding = model.model.language_model.get_input_embeddings()
                text_embedding.config = model.model.language_model.config
            else:
                text_embedding = model.get_input_embeddings()
                text_embedding.config = model.config.text_config
            return text_embedding

    def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[Dict[str, Any]] = None):
        model_kwargs = model_kwargs or {}
        if self._behavior != VLMConfigBehavior.VISION_EMBEDDINGS:
            return super().patch_model_for_export(model, model_kwargs)
        return Mistral3ImageEmbeddingModelPatcher(self, model, model_kwargs)

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
            # Patched vision forward returns a flat patch sequence.
            return {"last_hidden_state": {0: "num_patches"}}
        return super().outputs

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict:
        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
            # Vision export traces a single image at a time.
            kwargs["batch_size"] = 1
        return super().generate_dummy_inputs(framework, **kwargs)


class MambaCacheDummyInputGenerator(DummyInputGenerator):
"""
Generates dummy past_ssm_states, past_conv_states and cache_position inputs for Mamba architectures.
Expand Down
125 changes: 125 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8319,3 +8319,128 @@ def __exit__(self, exc_type, exc_value, traceback):
sparse_moe_block = decoder_layer.mlp
decoder_layer.mlp.forward = decoder_layer.mlp._orig_forward
del sparse_moe_block.down_projs, sparse_moe_block.gate_projs, sparse_moe_block.up_projs


def _mistral3_vision_embed_forward(self, pixel_values):
    """
    Full vision pipeline for Mistral3 export: vision_tower + multi_modal_projector.

    Inlines PixtralVisionModel + PatchMerger so that every shape is derived
    from pixel_values.shape, keeping dimensions dynamic in the OpenVINO IR.
    The stock PixtralVisionModel.forward uses Python lists, torch.arange on
    Python ints taken from image_sizes, and split_with_sizes, all of which get
    baked into constants during torch tracing; this rewrite avoids them.
    """
    tower = self.model.vision_tower
    psize = tower.patch_size
    row_stride = tower.config.image_size // psize

    # Patch convolution: [1, 3, H, W] -> [1, hidden, h_patches, w_patches]
    embeds = tower.patch_conv(pixel_values)
    grid_h = embeds.shape[2]
    grid_w = embeds.shape[3]

    # Flatten to a token sequence [1, n_patches, hidden] and pre-normalize.
    embeds = tower.ln_pre(embeds.flatten(2).transpose(1, 2))

    # 2D position IDs computed from the runtime patch grid (stays dynamic).
    rows = torch.arange(grid_h, device=pixel_values.device)
    cols = torch.arange(grid_w, device=pixel_values.device)
    rr, cc = torch.meshgrid(rows, cols, indexing="ij")
    position_ids = rr.reshape(-1) * row_stride + cc.reshape(-1)

    # RoPE embeddings for those positions.
    rope = tower.patch_positional_embedding(embeds, position_ids)

    # Transformer stack; no attention mask is needed for a single image.
    out = tower.transformer(
        embeds,
        attention_mask=None,
        position_embeddings=rope,
        output_hidden_states=True,
        return_dict=True,
    )

    # Pick the configured vision feature layer(s).
    feature_layer = self.config.vision_feature_layer
    if isinstance(feature_layer, int):
        selected = out.hidden_states[feature_layer]
    else:
        selected = torch.cat([out.hidden_states[idx] for idx in feature_layer], dim=-1)

    # Multi-modal projector with an inlined PatchMerger.
    proj = self.model.multi_modal_projector
    feats = proj.norm(selected.squeeze(0))

    merge = proj.patch_merger.spatial_merge_size
    hidden = feats.shape[-1]

    # Regroup spatially-adjacent patches via unfold instead of PatchMerger's
    # list-based gather (unfold traces with dynamic shapes).
    as_grid = feats.view(grid_h, grid_w, hidden).permute(2, 0, 1).unsqueeze(0)
    folded = torch.nn.functional.unfold(as_grid, kernel_size=merge, stride=merge)
    folded = folded.view(hidden * merge**2, -1).t()

    merged = proj.patch_merger.merging_layer(folded)
    merged = proj.linear_1(merged)
    merged = proj.act(merged)
    return proj.linear_2(merged)


class Mistral3ImageEmbeddingModelPatcher(ModelPatcher):
    """Swap the model's forward for the trace-friendly Mistral3 vision pipeline
    during export, restoring the original forward on exit.

    NOTE: ``model.__orig_forward`` is name-mangled by Python to
    ``model._Mistral3ImageEmbeddingModelPatcher__orig_forward``. This is safe
    only because ``__init__`` and ``__exit__`` live in the same class and thus
    mangle identically — do not move either access outside this class.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: "PreTrainedModel",
        model_kwargs: Dict[str, Any],
    ):
        # Stash the original forward (mangled attribute, see class docstring)
        # and bind the export-friendly vision forward in its place.
        model.__orig_forward = model.forward
        model.forward = types.MethodType(_mistral3_vision_embed_forward, model)
        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Restore the forward saved in __init__ (same mangled attribute).
        self._model.forward = self._model.__orig_forward


class Mistral3LanguageModelPatcher(OVDecoderModelPatcher):
    """
    Patcher for the language model part of Mistral3 VLM.

    Fixes sliding_window=None crash — MinistralModel.forward unconditionally creates
    sliding window masks even when all layers use full_attention. While patched, a
    None sliding_window is replaced with max_position_embeddings (a window spanning
    the full context, i.e. effectively no window); the original value is restored
    on exit.
    """

    def _find_language_model(self):
        # The wrapped module layout differs across transformers versions, so
        # probe the known attribute paths in order of likelihood.
        if hasattr(self._model, "model") and hasattr(self._model.model, "language_model"):
            return self._model.model.language_model
        if hasattr(self._model, "language_model"):
            return self._model.language_model
        if hasattr(self._model, "model") and hasattr(self._model.model, "config"):
            return self._model.model
        return None

    def __enter__(self):
        super().__enter__()
        # Always defined so __exit__ can never hit an AttributeError, even if
        # no language model (or config) is found below.
        self._orig_sliding_window = None

        lang_model = self._find_language_model()
        if lang_model is not None and hasattr(lang_model, "config"):
            cfg = lang_model.config
            self._orig_sliding_window = getattr(cfg, "sliding_window", None)
            if self._orig_sliding_window is None:
                max_pos = getattr(cfg, "max_position_embeddings", 32768)
                cfg.sliding_window = max_pos

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        lang_model = self._find_language_model()
        if lang_model is not None and hasattr(lang_model, "config"):
            # Restore the pre-patch value (a no-op if it was never changed).
            lang_model.config.sliding_window = self._orig_sliding_window

        super().__exit__(exc_type, exc_value, traceback)
1 change: 1 addition & 0 deletions optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def get_submodels(model):
"phi4_multimodal",
"llama4",
"minicpmo",
"mistral3",
]

SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid", "qwen3_next"]
Expand Down
54 changes: 54 additions & 0 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -4802,6 +4802,59 @@ def preprocess_inputs(
return inputs


class _OVMistral3ForCausalLM(OVModelForVisualCausalLM):
    """OpenVINO runtime wrapper for Mistral3 image-text-to-text models."""

    def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
        # Decode steps (single-token input) need no image features.
        if input_ids is not None and input_ids.shape[1] == 1:
            return None
        if pixel_values.dtype != torch.float32:
            pixel_values = pixel_values.to(torch.float32)
        return self.vision_embeddings(pixel_values).last_hidden_state

    def merge_vision_text_embeddings(
        self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
    ):
        """Scatter image features into the text embeddings at image-token positions."""
        if isinstance(vision_embeds, np.ndarray):
            vision_embeds = torch.from_numpy(vision_embeds)
        if isinstance(inputs_embeds, np.ndarray):
            inputs_embeds = torch.from_numpy(inputs_embeds)

        # Configs expose the placeholder id under either name; 10 is the Mistral3 default.
        token_id = getattr(self.config, "image_token_index", getattr(self.config, "image_token_id", 10))
        placeholder_mask = (input_ids == token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        flat_features = vision_embeds.view(-1, vision_embeds.shape[-1]).to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(placeholder_mask, flat_features)

        return inputs_embeds, attention_mask, position_ids

    @staticmethod
    def preprocess_inputs(
        text: str,
        image: Optional["Image"] = None,
        processor: Optional[AutoImageProcessor] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        config: Optional[PretrainedConfig] = None,
        video: Optional["VideoInput"] = None,
        audio: Optional[np.ndarray] = None,
    ):
        """Build processor inputs from a text prompt and optional image."""
        if processor is None:
            raise ValueError("Processor is required.")
        if video is not None:
            raise ValueError("Video input is not supported")
        if audio is not None:
            raise ValueError("Audio input is not supported")

        content = [{"type": "text", "text": text}]
        if image is not None:
            # Image placeholder goes before the text, matching the chat template.
            content.insert(0, {"type": "image"})
        conversation = [{"role": "user", "content": content}]

        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        return processor(images=image, text=prompt, return_tensors="pt")


MODEL_TYPE_TO_CLS_MAPPING = {
"llava": _OVLlavaForCausalLM,
"llava_next": _OVLlavaNextForCausalLM,
Expand All @@ -4824,4 +4877,5 @@ def preprocess_inputs(
"llama4": _OVLlama4ForCausalLM,
"qwen3_vl": _OVQwen3VLForCausalLM,
"minicpmo": _OVMiniCPMOForCausalLM,
"mistral3": _OVMistral3ForCausalLM,
}