Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@
Qwen2VLLanguageModelPatcher,
Qwen2VLVisionEmbMergerPatcher,
Qwen3MoeModelPatcher,
Qwen3_5MoeModelPatcher,
Qwen3_5MoeTextModelPatcher,
Qwen3_5TextModelPatcher,
Qwen3NextModelPatcher,
Qwen3VLLanguageModelPatcher,
Qwen3VLVisionEmbMergerPatcher,
Expand Down Expand Up @@ -5451,3 +5454,41 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
)

return dummy_inputs


# Qwen3.5 export configs. The architecture is structurally identical to
# Qwen3Next (GatedDeltaNet + full attention layers), so each config reuses
# Qwen3Next's hybrid cache/input handling and only swaps in a Qwen3.5 patcher.
@register_in_tasks_manager(
    "qwen3_5",
    "text-generation",
    "text-generation-with-past",
    library_name="transformers",
)
class Qwen3_5OpenVINOConfig(Qwen3NextOpenVINOConfig):
    """OpenVINO export config for the dense Qwen3.5 causal-LM architecture."""

    _MODEL_PATCHER = Qwen3_5TextModelPatcher


@register_in_tasks_manager(
    "qwen3_5_text",
    "text-generation",
    "text-generation-with-past",
    library_name="transformers",
)
class Qwen3_5TextOpenVINOConfig(Qwen3NextOpenVINOConfig):
    """OpenVINO export config for the `qwen3_5_text` (text-backbone) model type."""

    _MODEL_PATCHER = Qwen3_5TextModelPatcher


@register_in_tasks_manager(
    "qwen3_5_moe",
    "text-generation",
    "text-generation-with-past",
    library_name="transformers",
)
class Qwen3_5MoEOpenVINOConfig(Qwen3NextOpenVINOConfig):
    """OpenVINO export config for the Qwen3.5 MoE causal-LM architecture."""

    _MODEL_PATCHER = Qwen3_5MoeTextModelPatcher


@register_in_tasks_manager(
    "qwen3_5_moe_text",
    "text-generation",
    "text-generation-with-past",
    library_name="transformers",
)
class Qwen3_5MoETextOpenVINOConfig(Qwen3NextOpenVINOConfig):
    """OpenVINO export config for the `qwen3_5_moe_text` (text-backbone) model type."""

    _MODEL_PATCHER = Qwen3_5MoeTextModelPatcher
303 changes: 303 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8319,3 +8319,306 @@ def __exit__(self, exc_type, exc_value, traceback):
sparse_moe_block = decoder_layer.mlp
decoder_layer.mlp.forward = decoder_layer.mlp._orig_forward
del sparse_moe_block.down_projs, sparse_moe_block.gate_projs, sparse_moe_block.up_projs

# Qwen3.5 MoE forward patch — replaces dynamic expert dispatch with a static
# loop over all experts. The original uses torch.no_grad() + nonzero() + a
# data-dependent loop, which breaks OV tracing.
def qwen3_5_moe_experts_forward_patched(self, hidden_states, top_k_index, top_k_weights):
    """Traceable MoE experts forward.

    Args:
        hidden_states: [num_tokens, hidden_dim] flattened token activations.
        top_k_index: [num_tokens, top_k] long tensor of selected expert ids.
        top_k_weights: [num_tokens, top_k] routing weights for those experts.

    Returns:
        [num_tokens, hidden_dim] routing-weighted sum of expert outputs.
    """
    import torch.nn as nn

    final_hidden_states = torch.zeros_like(hidden_states)
    # one_hot: [tokens, top_k, experts] → permute: [experts, top_k, tokens]
    expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
    expert_mask = expert_mask.permute(2, 1, 0)
    for expert_idx in range(self.num_experts):
        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
        # Deliberately NO early-continue when token_idx is empty: a
        # data-dependent Python branch breaks tracing (the branch taken at
        # trace time gets baked in), and gather / index_add_ on empty index
        # tensors are numerical no-ops anyway.
        current_state = hidden_states[token_idx]
        # gate_up_proj[expert_idx] packs [gate; up] rows — chunk splits them.
        gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
        current_hidden_states = self.act_fn(gate) * up
        current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
        current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
        final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
    return final_hidden_states


# Patched GatedDeltaNet forward for Qwen3.5 — removes dynamic control flow for OV tracing.
# Key differences from Qwen3Next version:
# - Uses separate projections: in_proj_qkv, in_proj_z, in_proj_b, in_proj_a
# - No fix_query_key_value_ordering step
# - Cache uses self.layer_idx directly (not remapped through linear_attn_mapping)
def qwen3_5_gated_delta_net_forward(
    self,
    hidden_states: torch.Tensor,
    cache_params=None,
    cache_position: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
):
    """Branch-free GatedDeltaNet forward for Qwen3.5, safe for OV tracing.

    Replaces the linear-attention layer's forward. Differences from the
    Qwen3Next variant (per the notes above this function's call sites):
    separate in_proj_qkv/in_proj_z/in_proj_b/in_proj_a projections, and the
    cache is indexed by ``self.layer_idx`` directly.

    Args:
        hidden_states: [batch, seq, hidden] activations.
        cache_params: decomposed cache object exposing per-layer
            ``conv_states`` / ``recurrent_states`` lists; ``None`` disables
            caching (stateless prefill).
        cache_position: unused here; kept for signature compatibility.
        attention_mask: assumed [batch, seq] 0/1 padding mask — TODO confirm
            against caller.

    Returns:
        [batch, seq, hidden] layer output. When ``cache_params`` is given,
        its conv/recurrent states for this layer are updated in place.
    """
    # Zero out padded positions only for batched multi-token input; a
    # single-row or single-token mask is skipped (decode / no padding).
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

    batch_size, seq_len, _ = hidden_states.shape

    # Qwen3.5 cache is indexed by the real layer index (no remapping table).
    layer_idx = self.layer_idx
    recurrent_state = None
    if cache_params is not None:
        conv_state = cache_params.conv_states[layer_idx]
        recurrent_state = cache_params.recurrent_states[layer_idx]

    # Qwen3.5 uses separate projections (not combined like Qwen3Next)
    mixed_qkv = self.in_proj_qkv(hidden_states)
    mixed_qkv = mixed_qkv.transpose(1, 2)  # [batch, channels, seq] for conv1d

    z = self.in_proj_z(hidden_states)
    z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)

    b = self.in_proj_b(hidden_states)  # beta (gate) pre-activation
    a = self.in_proj_a(hidden_states)  # decay pre-activation

    # Convolution — use ov_causal_conv1d for both prefill and decode (no branching)
    if cache_params is not None:
        new_mixed_qkv, new_conv_state = ov_causal_conv1d(conv_state, mixed_qkv, self.conv1d.weight, self.conv1d.bias)
        mixed_qkv = F.silu(new_mixed_qkv)
        cache_params.conv_states[layer_idx] = new_conv_state
    else:
        # Stateless path: plain causal conv, trimmed back to seq_len.
        mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

    mixed_qkv = mixed_qkv.transpose(1, 2)
    # Split QKV: the flat dim is [key_dim, key_dim, value_dim] contiguous blocks.
    # Reshape to 4D BEFORE split to avoid 3D bf16 ops that the GPU plugin can't handle.
    # Insert a dummy head dim of 1: [batch, seq, 1, total] → split on dim 3 → reshape per-tensor.
    mixed_qkv = mixed_qkv.unsqueeze(2)  # [batch, seq, 1, 2*key_dim+value_dim] — now 4D
    query = mixed_qkv[:, :, :, :self.key_dim]  # [batch, seq, 1, key_dim]
    key = mixed_qkv[:, :, :, self.key_dim:2 * self.key_dim]  # [batch, seq, 1, key_dim]
    value = mixed_qkv[:, :, :, 2 * self.key_dim:]  # [batch, seq, 1, value_dim]

    query = query.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
    key = key.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
    value = value.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)

    beta = b.sigmoid()
    # Per-head decay: -exp(A_log) * softplus(a + dt_bias), computed in fp32.
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
    # GQA-style expansion of q/k heads up to the value-head count.
    if self.num_v_heads // self.num_k_heads > 1:
        query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
        key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

    # Always use recurrent path — it handles both prefill and decode correctly
    # and avoids the chunk_gated_delta_rule which has tracing issues.
    # NOTE: recurrent_gated_delta_rule is a plain function stored as an
    # instance attribute (see the patcher's __enter__), hence the explicit
    # `self` first argument.
    core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
        self,
        query,
        key,
        value,
        g=g,
        beta=beta,
        initial_state=recurrent_state,
        output_final_state=cache_params is not None,
        use_qk_l2norm_in_kernel=True,
    )

    if cache_params is not None:
        cache_params.recurrent_states[layer_idx] = last_recurrent_state

    # Gated RMS-norm over flattened heads, then back to [batch, seq, hidden].
    core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
    z = z.reshape(-1, self.head_v_dim)
    core_attn_out = self.norm(core_attn_out, z)
    core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

    output = self.out_proj(core_attn_out)
    return output


class Qwen3_5TextModelPatcher(OVDecoderModelPatcher):
    """
    Patcher for Qwen3.5 hybrid (GatedDeltaNet + Attention) text models.

    Handles three concerns, following the same pattern as Qwen3NextModelPatcher:

    1. Cache decomposition: the exported model exchanges the cache as a flat
       list of tensors (``cache_params`` in / ``present_key_values`` out);
       ``patched_forward`` re-assembles it into the structured cache object
       the HF model expects and flattens it back on the way out.
    2. GatedDeltaNet forward patching: each linear-attention layer's forward
       is swapped for the traceable ``qwen3_5_gated_delta_net_forward``.
    3. RecurrentAttentionCell OV conversion via module/conversion extensions.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: "PreTrainedModel",
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        from openvino.frontend.pytorch import ConversionExtension, ModuleExtension

        super().__init__(config, model, model_kwargs)

        # Detect cache class from the model's module: the MoE and dense
        # variants each define their own DynamicCache class.
        model_module = type(model).__module__
        if "qwen3_5_moe" in model_module:
            from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeDynamicCache as CacheClass
        else:
            from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DynamicCache as CacheClass

        # Build the cache wrapper class that decomposes flat tensor list
        # into the structured cache the model expects.
        # Unlike Qwen3Next which uses compact linear_attn_mapping/full_attn_mapping,
        # Qwen3.5 stores conv_states/recurrent_states for ALL layers (None for attention layers).
        class Qwen3_5DynamicCacheWrap(CacheClass):
            def __init__(self, config, conv_states, recurrent_states, key_cache, value_cache):
                # Let the base cache set itself up from the config, then
                # overwrite its state lists with the decomposed input tensors.
                super().__init__(config=config)
                self.conv_states = conv_states
                self.recurrent_states = recurrent_states
                self.key_cache = key_cache
                self.value_cache = value_cache

            def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
                # Cast cached states to match incoming dtype to prevent SDPA dtype mismatch
                if self.key_cache[layer_idx] is not None:
                    self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.dtype)
                    self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.dtype)
                return super().update(key_states, value_states, layer_idx, cache_kwargs)

        def patched_forward(
            input_ids,
            attention_mask=None,
            cache_params=None,
        ):
            # Layer schedule: which layers are linear attention vs full attention.
            model_config = self.real_config._config
            layer_types = model_config.layer_types
            num_layers = len(layer_types)
            num_linear = sum(1 for t in layer_types if t == "linear_attention")
            num_full = sum(1 for t in layer_types if t == "full_attention")

            use_cache = False
            wrapped_cache_params = None
            if cache_params is not None:
                use_cache = True

                # Flat cache layout: [conv_0, recur_0, conv_1, recur_1, ..., key_0, val_0, key_1, val_1, ...]
                # First 2*num_linear entries are for linear attention layers
                # Next 2*num_full entries are for full attention layers
                linear_conv_states = []
                linear_recurrent_states = []
                for idx in range(num_linear):
                    linear_conv_states.append(cache_params[2 * idx])
                    linear_recurrent_states.append(cache_params[2 * idx + 1])

                full_keys = []
                full_values = []
                offset = 2 * num_linear
                for idx in range(num_full):
                    full_keys.append(cache_params[offset + 2 * idx])
                    full_values.append(cache_params[offset + 2 * idx + 1])

                # Expand to per-layer lists (None for non-matching layers)
                conv_states = [None] * num_layers
                recurrent_states = [None] * num_layers
                key_cache = [None] * num_layers
                value_cache = [None] * num_layers

                # Walk the layer schedule and slot each flat tensor into the
                # per-layer position its layer type dictates.
                linear_idx = 0
                full_idx = 0
                for i, lt in enumerate(layer_types):
                    if lt == "linear_attention":
                        conv_states[i] = linear_conv_states[linear_idx]
                        recurrent_states[i] = linear_recurrent_states[linear_idx]
                        linear_idx += 1
                    else:
                        key_cache[i] = full_keys[full_idx]
                        value_cache[i] = full_values[full_idx]
                        full_idx += 1

                wrapped_cache_params = Qwen3_5DynamicCacheWrap(
                    model_config, conv_states, recurrent_states, key_cache, value_cache
                )

            causal_lm_output = self.model_orig_forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=wrapped_cache_params,
                use_cache=use_cache,
            )
            outputs = {
                "logits": causal_lm_output.logits,
            }

            if use_cache:
                past_kv = causal_lm_output.past_key_values
                # Re-flatten to same layout: linear states first, then attention KV
                present_key_values = []
                for i, lt in enumerate(layer_types):
                    if lt == "linear_attention":
                        present_key_values.append(past_kv.conv_states[i])
                        present_key_values.append(past_kv.recurrent_states[i])

                for i, lt in enumerate(layer_types):
                    if lt != "linear_attention":
                        present_key_values.append(past_kv.key_cache[i])
                        present_key_values.append(past_kv.value_cache[i])

                outputs["present_key_values"] = present_key_values

            return outputs

        # Stash both forwards: patched_forward becomes the export entry point,
        # while model_orig_forward keeps the real model forward it delegates to.
        self.patched_forward = patched_forward
        self.model_orig_forward = self.orig_forward
        self.orig_forward = patched_forward

        # Reuse RecurrentAttentionCell and conversion extension from Qwen3Next
        self.module_extensions = {
            RecurrentAttentionCell: ModuleExtension(RecurrentAttentionCell, "RecurrentAttentionCellOp"),
        }
        self.conversion_extensions = [
            ConversionExtension("RecurrentAttentionCellOp", convert_recurrent_attention_cell),
        ]

    def __enter__(self):
        super().__enter__()
        # Install the cache-decomposing forward on the model itself.
        setattr(self._model, self.orig_forward_name, self.patched_forward)

        # Swap every linear-attention layer's forward for the traceable
        # GatedDeltaNet variant and attach the recurrent-rule helpers it uses.
        for idx, decoder_layer in enumerate(self._model.model.layers):
            layer_type = self._model.model.config.layer_types[idx]
            if layer_type == "linear_attention":
                linear_attn_layer = decoder_layer.linear_attn
                linear_attn_layer._orig_forward = linear_attn_layer.forward
                linear_attn_layer.forward = types.MethodType(qwen3_5_gated_delta_net_forward, linear_attn_layer)
                # Plain-function attribute (not bound) — the patched forward
                # passes `self` explicitly when calling it.
                linear_attn_layer.recurrent_gated_delta_rule = patched_recurrent_gated_delta_rule
                linear_attn_layer.recurrent_attention_cell = RecurrentAttentionCell()

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Restore the real forwards on the model and on each patched layer.
        setattr(self._model, self.orig_forward_name, self.model_orig_forward)
        for idx, decoder_layer in enumerate(self._model.model.layers):
            layer_type = self._model.model.config.layer_types[idx]
            if layer_type == "linear_attention":
                linear_attn_layer = decoder_layer.linear_attn
                linear_attn_layer.forward = linear_attn_layer._orig_forward



class Qwen3_5MoeTextModelPatcher(Qwen3_5TextModelPatcher):
    """
    Patcher for Qwen3.5 MoE hybrid models (35B-A3B).

    On top of the base hybrid patching, swaps the class-level
    ``Qwen3_5MoeExperts.forward`` for the traceable variant while active.
    """

    @staticmethod
    def _experts_cls():
        # Resolved lazily so the transformers MoE module is only imported
        # when this patcher is actually used.
        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeExperts

        return Qwen3_5MoeExperts

    def __enter__(self):
        super().__enter__()
        experts_cls = self._experts_cls()
        self._original_experts_forward = experts_cls.forward
        experts_cls.forward = qwen3_5_moe_experts_forward_patched

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        self._experts_cls().forward = self._original_experts_forward


# Legacy patcher kept for backward compatibility — only patches MoE experts, no GatedDeltaNet
class Qwen3_5MoeModelPatcher(OVDecoderModelPatcher):
    """Swap ``Qwen3_5MoeExperts.forward`` for the traceable variant while active."""

    def __enter__(self):
        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeExperts

        super().__enter__()
        # Save the class-level forward so __exit__ can restore it exactly.
        self.original_experts_forward = Qwen3_5MoeExperts.forward
        Qwen3_5MoeExperts.forward = qwen3_5_moe_experts_forward_patched

    def __exit__(self, exc_type, exc_value, traceback):
        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeExperts

        super().__exit__(exc_type, exc_value, traceback)
        Qwen3_5MoeExperts.forward = self.original_experts_forward

Loading