Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@
Qwen2MoEPatcher,
Qwen2VLLanguageModelPatcher,
Qwen2VLVisionEmbMergerPatcher,
Qwen3_5Patcher,
Qwen3MoeModelPatcher,
Qwen3VLLanguageModelPatcher,
Qwen3VLVisionEmbMergerPatcher,
Expand Down Expand Up @@ -4961,6 +4962,195 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
return common_inputs


class Qwen3_5DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
    """
    Generates dummy cache_params inputs for Qwen3.5 hybrid GatedDeltaNet + Attention architectures.

    Linear attention (GatedDeltaNet) layers produce conv_states and recurrent_states (fixed size).
    Full attention layers produce a standard KV cache (variable size along the sequence axis).
    The flat tensor order emitted by :meth:`generate` — all linear-layer state pairs first, then
    all attention-layer KV pairs — must stay in sync with ``Qwen3_5OpenVINOConfig.add_past_key_values``.
    """

    SUPPORTED_INPUT_NAMES = ("cache_params",)

    def __init__(
        self,
        task: str,
        normalized_config,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        **kwargs,
    ):
        super().__init__(
            task=task,
            normalized_config=normalized_config,
            batch_size=batch_size,
            sequence_length=sequence_length,
            **kwargs,
        )

        config = normalized_config.config
        self._model_config = config
        # Derive attention layer indices from the per-layer type list.
        layer_types = config.layer_types
        self.attention_layers_indices = {i for i, lt in enumerate(layer_types) if lt == "full_attention"}
        self.num_hidden_layers = config.num_hidden_layers
        self.num_attention_layers = len(self.attention_layers_indices)
        self.num_linear_layers = self.num_hidden_layers - self.num_attention_layers

        # Linear attention (GatedDeltaNet) state dimensions.
        self.linear_num_key_heads = config.linear_num_key_heads
        self.linear_key_head_dim = config.linear_key_head_dim
        self.linear_value_head_dim = config.linear_value_head_dim
        self.linear_num_value_heads = config.linear_num_value_heads
        self.linear_conv_kernel_dim = config.linear_conv_kernel_dim
        # conv_dim = key_dim * 2 + value_dim
        self.conv_dim = (
            self.linear_num_key_heads * self.linear_key_head_dim * 2
            + self.linear_num_value_heads * self.linear_value_head_dim
        )

        # Full attention KV cache dimensions.
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Force an empty past for the KV cache tensors: the sequence axis is dynamic and the
        # export traces the decode path (see Qwen3_5OpenVINOConfig.overwrite_shape_and_generate_input),
        # so the dummy past length is 0 regardless of the requested sequence_length.
        self.sequence_length = 0

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Return the flat list of dummy cache tensors.

        Order: for each linear attention layer a (conv_state, recurrent_state) pair,
        then for each full attention layer a (key, value) pair. This must match
        ``Qwen3_5OpenVINOConfig.add_past_key_values`` and the patcher's cache
        unpacking/repacking.
        """
        past_key_values = []

        # Shapes are loop-invariant; compute them once outside the loops.
        conv_state_shape = (self.batch_size, self.conv_dim, self.linear_conv_kernel_dim)
        recurrent_state_shape = (
            self.batch_size,
            self.linear_num_key_heads,
            self.linear_key_head_dim,
            self.linear_value_head_dim,
        )
        kv_shape = (self.batch_size, self.num_key_value_heads, self.sequence_length, self.head_dim)

        # Linear attention layers: conv_states + recurrent_states.
        for _ in range(self.num_linear_layers):
            past_key_values.append(self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype))
            past_key_values.append(
                self.random_float_tensor(recurrent_state_shape, framework=framework, dtype=float_dtype)
            )

        # Full attention layers: key + value cache (empty past, see __init__).
        for _ in range(self.num_attention_layers):
            past_key_values.append(self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype))
            past_key_values.append(self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype))

        return past_key_values


@register_in_tasks_manager("qwen3_5", "text-generation", "text-generation-with-past", library_name="transformers")
@register_in_tasks_manager("qwen3_5_text", "text-generation", "text-generation-with-past", library_name="transformers")
class Qwen3_5OpenVINOConfig(MambaOpenVINOConfig):
    """OpenVINO export config for Qwen3.5 hybrid GatedDeltaNet + full-attention decoders."""

    PAD_ATTENTION_MASK_TO_PAST = False
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Qwen3_5DummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = Qwen3_5DummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    # No attention_mask input: patched_forward passes None to avoid hardcoded reshapes.
    NO_ATTENTION_MASK = True
    MIN_TRANSFORMERS_VERSION = "5.3.0"
    _MODEL_PATCHER = Qwen3_5Patcher

    def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[Dict[str, Any]] = None):
        """Wrap *model* in the Qwen3.5-specific patcher used during export tracing."""
        return Qwen3_5Patcher(self, model, model_kwargs=model_kwargs or {})

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        """Register the cache tensor names and dynamic axes for the given *direction*."""
        if direction == "inputs":
            cache_name_prefix = "cache_params.past"
            decoder_sequence_name = "past_sequence_length"
        elif direction == "outputs":
            cache_name_prefix = "cache_params.present"
            decoder_sequence_name = "past_sequence_length + sequence_length"
        else:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        config = self._normalized_config.config
        layer_types = config.layer_types
        num_hidden_layers = config.num_hidden_layers

        # Grouped order: all linear attention layers first, then all full attention layers.
        # This must match the order in Qwen3_5DummyPastKeyValuesGenerator.generate()
        # and Qwen3_5Patcher.patched_forward() cache unpacking/repacking.
        num_linear = sum(1 for i in range(num_hidden_layers) if layer_types[i] == "linear_attention")
        for slot in range(num_linear):
            inputs_or_outputs[f"{cache_name_prefix}.conv.{slot}"] = {0: "batch_size"}
            inputs_or_outputs[f"{cache_name_prefix}.recurrent.{slot}"] = {0: "batch_size"}

        num_attention = sum(1 for i in range(num_hidden_layers) if layer_types[i] == "full_attention")
        for slot in range(num_attention):
            inputs_or_outputs[f"{cache_name_prefix}.key.{slot}"] = {0: "batch_size", 2: decoder_sequence_name}
            inputs_or_outputs[f"{cache_name_prefix}.value.{slot}"] = {0: "batch_size", 2: decoder_sequence_name}

    def overwrite_shape_and_generate_input(self, dummy_input_gen, input_name, framework, input_shapes):
        """Shape overrides needed to trace the decode path of the hybrid cache model."""
        tracing_with_past = self.use_past and self.use_past_in_inputs

        # Qwen3.5's GatedDeltaNet has separate prefill (seq_len > 1) and decode (seq_len == 1) paths.
        # The stateful model must trace the decode path so that conv/recurrent cache inputs are consumed.
        # Force seq_len=1 for input_ids when past states are present.
        if tracing_with_past and input_name in ("input_ids", "position_ids"):
            previous_length = dummy_input_gen.sequence_length
            dummy_input_gen.sequence_length = 1
            generated = dummy_input_gen.generate(
                input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
            )
            dummy_input_gen.sequence_length = previous_length
            return generated

        # attention_mask must be LONGER than input_ids (length > 1) during tracing so that
        # torch.jit.trace captures the padding_mask slicing branch in sdpa_mask():
        # if padding_mask.shape[-1] > kv_length: padding_mask = padding_mask[:, -kv_length:]
        # This makes the graph correctly adapt to growing attention_mask at runtime.
        # apply_mask_to_padding_states is patched to no-op by Qwen3_5Patcher to avoid
        # the broadcast issue with hidden_states * attention_mask[:, :, None].
        if tracing_with_past and input_name == "attention_mask":
            import torch

            return torch.ones(dummy_input_gen.batch_size, 2, dtype=torch.int64)

        return super().overwrite_shape_and_generate_input(dummy_input_gen, input_name, framework, input_shapes)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # attention_mask is NOT included: patched_forward passes attention_mask=None
        # to avoid hardcoded reshapes from torch.jit.trace's causal mask computation.
        # The model creates a pure causal mask depending only on KV cache shape (dynamic).
        # position_ids IS included: needed for correct RoPE in full_attention layers.
        # Without it, cache_position (baked to [0]) would give every token position 0.
        model_inputs: Dict[str, Dict[int, str]] = {
            name: {0: "batch_size", 1: "sequence_length"} for name in ("input_ids", "position_ids")
        }
        if self.use_past_in_inputs:
            self.add_past_key_values(model_inputs, direction="inputs")
        return model_inputs


@register_in_tasks_manager("audio-spectrogram-transformer", *["feature-extraction", "audio-classification"])
class ASTOpenVINOConfig(ASTOnnxConfig):
    """OpenVINO export config for Audio Spectrogram Transformer; reuses the ONNX config unchanged."""

    pass
Expand Down
Loading