Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
524 changes: 524 additions & 0 deletions optimum/exporters/openvino/model_configs.py

Large diffs are not rendered by default.

281 changes: 280 additions & 1 deletion optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2248,7 +2248,7 @@ def _persimmon_self_attn_sdpa_forward(
fused_qkv = self.query_key_value(hidden_states)

# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_states, key_states, value_states) = self._split_heads(fused_qkv)
query_states, key_states, value_states = self._split_heads(fused_qkv)

if self.qk_layernorm:
query_states = self.q_layernorm(query_states)
Expand Down Expand Up @@ -4090,6 +4090,24 @@ def __exit__(self, exc_type, exc_value, traceback):
self._model.forward = self._model.__orig_forward


def _deepstack_process_patched(self, hidden_states, visual_pos_masks, visual_embeds):
    """Trace-friendly replacement for _deepstack_process that avoids boolean indexing.

    The original uses hidden_states[visual_pos_masks, :], which produces data-dependent
    shapes that OpenVINO cannot handle. Instead, every sequence position is mapped to a
    row of ``visual_embeds`` via cumsum + index_select, and the contribution is gated by
    the (float) mask so non-visual positions are left untouched.
    """
    mask = visual_pos_masks.to(hidden_states.device)
    embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
    bsz, seq_len, hidden = hidden_states.shape

    flat_mask = mask.reshape(-1)
    # Running count of visual positions gives each masked slot its row in `embeds`;
    # clamping maps non-visual slots to row 0 (zeroed out by the gate below).
    gather_idx = torch.clamp(flat_mask.long().cumsum(0) - 1, min=0)
    scattered = torch.index_select(embeds, 0, gather_idx).reshape(bsz, seq_len, hidden)
    gate = flat_mask.to(hidden_states.dtype).reshape(bsz, seq_len, 1)
    return hidden_states + scattered * gate


class Qwen3VLLanguageModelPatcher(OVDecoderModelPatcher):
def __init__(
self,
Expand Down Expand Up @@ -4129,11 +4147,18 @@ def lm_forward(

model.__orig_forward = model.forward
model.forward = types.MethodType(lm_forward, model)

language_model = model.model.language_model
language_model.__orig_deepstack_process = language_model._deepstack_process
language_model._deepstack_process = types.MethodType(_deepstack_process_patched, language_model)

super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
language_model = self._model.model.language_model
language_model._deepstack_process = language_model.__orig_deepstack_process


def patch_qwen2vl_vision_blocks(model, force_new_behaviour=False):
Expand Down Expand Up @@ -4392,6 +4417,260 @@ def __exit__(self, exc_type, exc_value, traceback):
block.attn.forward = block.attn._orig_forward


class Qwen3OmniVisionMergerPatcher(ModelPatcher):
    """Patches the Qwen3-Omni vision tower for export.

    The replacement forward returns both the merged last hidden state and the
    stacked deepstack features collected at the configured intermediate layers.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel"],
        model_kwargs: Dict[str, Any] = None,
    ):
        model.__orig_forward = model.forward

        def image_embed_forward(
            self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor
        ):
            # Run the transformer blocks sequentially, tapping the hidden states
            # at each deepstack index through its dedicated merger.
            deepstack_features = []
            for idx, block in enumerate(self.blocks):
                hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
                if idx in self.deepstack_visual_indexes:
                    merger = self.merger_list[self.deepstack_visual_indexes.index(idx)]
                    deepstack_features.append(merger(hidden_states))
            return self.merger(hidden_states), torch.stack(deepstack_features, dim=0)

        model.forward = types.MethodType(image_embed_forward, model)
        super().__init__(config, model, model_kwargs)

    def __enter__(self):
        # Vision blocks share the Qwen2-VL patching logic.
        patch_qwen2vl_vision_blocks(self._model)
        super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Restore the patched forwards on the model and on every vision block.
        self._model.forward = self._model.__orig_forward
        for blk in self._model.blocks:
            blk.forward = blk._orig_forward
            blk.attn.forward = blk.attn._orig_forward


class Qwen3OmniAudioEncoderPatcher(ModelPatcher):
    # Patches the Qwen3-Omni audio tower: forward is replaced with an
    # export-oriented variant that consumes pre-padded features together with
    # the post-CNN mask and per-sample lengths.

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel"],
        model_kwargs: Dict[str, Any] = None,
    ):
        """Swap ``model.forward`` for ``audio_forward`` for the duration of export."""
        model.__orig_forward = model.forward

        def audio_forward(
            self,
            padded_feature: torch.Tensor,
            padded_mask_after_cnn: torch.Tensor,
            aftercnn_lens: torch.Tensor,
        ) -> torch.Tensor:
            # Add a channel axis so the features can go through the 2D conv stack.
            padded_feature = padded_feature.unsqueeze(1)
            padded_embed = torch.nn.functional.gelu(self.conv2d1(padded_feature))
            padded_embed = torch.nn.functional.gelu(self.conv2d2(padded_embed))
            padded_embed = torch.nn.functional.gelu(self.conv2d3(padded_embed))
            # Fold channel and frequency axes into a single embedding dim per timestep.
            b, c, f, t = padded_embed.size()
            padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f))

            # Positional embedding table sliced to the actual sequence length.
            positional_embedding = (
                self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
                .unsqueeze(0)
                .to(padded_embed.dtype)
            )
            padded_embed = padded_embed + positional_embedding
            # TODO: Boolean indexing and the cu_seqlens loop below produce data-dependent shapes
            # that get baked in during export tracing. This works when inference input lengths match
            # export dummy shapes but may fail for different audio lengths. A proper fix requires
            # restructuring to use padded computation throughout the transformer layers.
            hidden_states = padded_embed[padded_mask_after_cnn]

            # Build cumulative sequence lengths that split each sample into
            # attention chunks of at most `window_aftercnn` frames (plus a
            # shorter remainder chunk when the length is not a multiple).
            window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
            cu_chunk_lens = [0]
            for cnn_len in aftercnn_lens:
                cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
                remainder = cnn_len % window_aftercnn
                if remainder != 0:
                    cu_chunk_lens += [remainder]
            cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)

            # Transformer layers consume the flattened tokens plus chunk boundaries.
            for encoder_layer in self.layers:
                layer_outputs = encoder_layer(hidden_states, cu_seqlens)
                hidden_states = layer_outputs[0]

            # Final norm followed by the two-layer projection head.
            hidden_states = self.ln_post(hidden_states)
            hidden_states = self.proj1(hidden_states)
            hidden_states = self.act(hidden_states)
            hidden_states = self.proj2(hidden_states)
            return hidden_states

        model.forward = types.MethodType(audio_forward, model)
        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original forward when leaving the patching context.
        super().__exit__(exc_type, exc_value, traceback)
        self._model.forward = self._model.__orig_forward


class Qwen3OmniLanguageModelPatcher(OVDecoderModelPatcher):
    """Patches the Qwen3-Omni Thinker language model for export.

    On top of the standard decoder patching, the forward is replaced so that it
    additionally returns the hidden states of the intermediate layer consumed by
    the Talker, and the trace-unfriendly deepstack merge is swapped for the
    masked-add implementation in ``_deepstack_process_patched``.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # The Talker consumes one specific thinker layer's hidden states.
        talker_config = getattr(model.config, "talker_config", None)
        accept_hidden_layer = getattr(talker_config, "accept_hidden_layer", None)
        need_hidden_states = accept_hidden_layer is not None

        def lm_forward(
            self,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            visual_pos_masks,
            deepstack_visual_embeds,
            use_cache=True,
        ):
            # Rebuild a DynamicCache from the legacy tuple layout used at the
            # export boundary, and request all hidden states only when needed.
            cache = DynamicCache.from_legacy_cache(past_key_values)
            outputs = self.thinker.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=use_cache,
                past_key_values=cache,
                visual_pos_masks=visual_pos_masks,
                deepstack_visual_embeds=deepstack_visual_embeds,
                output_hidden_states=need_hidden_states,
            )
            last_hidden = outputs[0]
            logits = self.thinker.lm_head(last_hidden)
            legacy_cache = outputs.past_key_values.to_legacy_cache()
            if need_hidden_states:
                return (logits, last_hidden, outputs.hidden_states[accept_hidden_layer], legacy_cache)
            return (logits, last_hidden, legacy_cache)

        model.__orig_forward = model.forward
        model.forward = types.MethodType(lm_forward, model)

        thinker_model = model.thinker.model
        thinker_model.__orig_deepstack_process = thinker_model._deepstack_process
        thinker_model._deepstack_process = types.MethodType(_deepstack_process_patched, thinker_model)

        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Undo both the forward patch and the deepstack patch.
        self._model.forward = self._model.__orig_forward
        thinker_model = self._model.thinker.model
        thinker_model._deepstack_process = thinker_model.__orig_deepstack_process


class Qwen3OmniTalkerLanguageModelPatcher(OVDecoderModelPatcher):
    """Patches the Qwen3-Omni Talker decoder for export: the forward accepts a
    legacy-format KV cache and returns codec-head logits, the last hidden
    states, and the updated legacy cache."""

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        def lm_forward(self, inputs_embeds, attention_mask, position_ids, past_key_values, use_cache=True):
            # Convert the legacy tuple cache into a DynamicCache for the decoder.
            cache = DynamicCache.from_legacy_cache(past_key_values)
            decoder_out = self.talker.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=use_cache,
                past_key_values=cache,
            )
            last_hidden = decoder_out[0]
            return (
                self.talker.codec_head(last_hidden),
                last_hidden,
                decoder_out.past_key_values.to_legacy_cache(),
            )

        model.__orig_forward = model.forward
        model.forward = types.MethodType(lm_forward, model)
        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Restore the original forward on context exit.
        self._model.forward = self._model.__orig_forward


class Qwen3OmniCodePredictorPatcher(OVDecoderModelPatcher):
    """Patches the Talker code predictor for export.

    The per-step ``lm_head`` selection is rewritten as an index into a
    pre-stacked weight tensor so the dispatch stays traceable, and the
    predictor is cast to float32 to match the dummy input dtypes.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        predictor = model.talker.code_predictor
        # Cast to float32 to match dummy input dtypes during tracing.
        predictor.float()

        # Pre-stack every lm_head weight; cp_forward indexes this bank by step.
        head_bank = torch.stack([head.weight for head in predictor.lm_head])

        def cp_forward(
            self,
            inputs_embeds,
            attention_mask,
            position_ids,
            past_key_values,
            generation_steps,
            use_cache=True,
        ):
            cache = DynamicCache.from_legacy_cache(past_key_values)
            decoder_out = self.talker.code_predictor.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=use_cache,
                past_key_values=cache,
            )
            last_hidden = decoder_out[0]
            # Select the head weights for the current generation step by index.
            logits = torch.nn.functional.linear(last_hidden, head_bank[generation_steps])
            return (logits, last_hidden, decoder_out.past_key_values.to_legacy_cache())

        model.__orig_forward = model.forward
        model.forward = types.MethodType(cp_forward, model)
        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Restore the original forward on context exit.
        self._model.forward = self._model.__orig_forward


class Qwen3OmniCode2WavPatcher(ModelPatcher):
    """Patches Code2Wav export by neutralizing Qwen3OmniCausalConvNet's
    extra-padding computation, whose math.ceil on dynamic shapes breaks
    tracing. For all conv configs used by Code2Wav the padding is always 0."""

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs=model_kwargs or {})

    def __enter__(self):
        super().__enter__()
        # Replace the class-level method with a constant-zero stand-in,
        # keeping a reference so __exit__ can restore it.
        import transformers.models.qwen3_omni.modeling_qwen3_omni as qwen3_omni_module

        conv_cls = qwen3_omni_module.Qwen3OmniCausalConvNet
        self._orig_get_extra_padding = conv_cls._get_extra_padding_for_conv1d

        def _zero_extra_padding(conv_self, hidden_state):
            return 0

        conv_cls._get_extra_padding_for_conv1d = _zero_extra_padding
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Put back the original class-level method patched in __enter__.
        import transformers.models.qwen3_omni.modeling_qwen3_omni as qwen3_omni_module

        qwen3_omni_module.Qwen3OmniCausalConvNet._get_extra_padding_for_conv1d = self._orig_get_extra_padding


# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L321
def _granite_moe_topk_gating_forward(self, hidden_states):
# compute the top_k routing decision
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def get_submodels(model):
"phi4_multimodal",
"llama4",
"minicpmo",
"qwen3_omni",
]

SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid", "qwen3_next"]
Expand Down
Loading