Skip to content
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 88 additions & 64 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4668,6 +4668,84 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
return dummy_inputs


class HybridCacheOpenVINOConfig(TextDecoderOnnxConfig):
    """
    Base config for hybrid models that use cache_params with both recurrent/conv and attention states.
    Handles attention_mask dynamic axis and padding for stateful KV-cache inference.

    Subclasses must define:
    _NON_KV_LAYER_TYPES: tuple of layer type names for non-KV-cache layers
    _KV_LAYER_TYPES: tuple of layer type names for KV-cache layers
    _NON_KV_ENTRIES_PER_LAYER: number of cache entries per non-KV layer (default: 2)
    """

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    _NON_KV_LAYER_TYPES = ()
    _KV_LAYER_TYPES = ()
    _NON_KV_ENTRIES_PER_LAYER = 2

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Dynamic-axis specification for the exported model's inputs.

        When past key/values are fed as inputs, the attention mask's sequence
        axis spans past + current tokens; otherwise it matches input_ids.
        """
        common_inputs = {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
        }
        if self.use_past_in_inputs:
            common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"}
            self.add_past_key_values(common_inputs, direction="inputs")
        else:
            common_inputs["attention_mask"] = {0: "batch_size", 1: "sequence_length"}
        return common_inputs

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        """Generate dummy inputs for export, grouping all cache tensors under one "cache_params" entry.

        Args:
            framework (`str`, defaults to `"pt"`): framework the dummy tensors are created for.
            **kwargs: shape overrides forwarded to the dummy input generators.

        Returns:
            `dict`: mapping of input name to dummy tensor (with "cache_params" holding
            the full list of cache entries when `use_past_in_inputs` is set).

        Raises:
            RuntimeError: if no registered dummy input generator supports a required input.
        """
        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)

        dummy_inputs = {}
        # Per-entry "cache_params.*" names from `inputs` are collapsed into a single
        # "cache_params" input handled by the PKV generator.
        input_names = [key for key in self.inputs if not key.startswith("cache_params")]
        if self.use_past_in_inputs:
            input_names.append("cache_params")

        for input_name in input_names:
            for dummy_input_gen in dummy_inputs_generators:
                if dummy_input_gen.supports_input(input_name):
                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
                        dummy_input_gen,
                        input_name,
                        framework,
                        input_shapes=kwargs,
                    )
                    break
            else:
                raise RuntimeError(
                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
                )

        if self.use_past_in_inputs and "attention_mask" in dummy_inputs and "cache_params" in dummy_inputs:
            # Pad attention_mask to cover past KV-cache sequence length + current input length.
            # This ensures the exported graph builds a rectangular causal mask
            # (seq_len x total_seq_len) instead of a square one (seq_len x seq_len),
            # which is required for stateful KV-cache inference where attention_mask
            # is longer than input_ids.
            layer_types = self._normalized_config.layer_types
            # Single pass over layer_types; also immune to accidental duplicates
            # in the subclass-declared type tuples.
            num_non_kv_layers = sum(1 for lt in layer_types if lt in self._NON_KV_LAYER_TYPES)
            num_kv_layers = sum(1 for lt in layer_types if lt in self._KV_LAYER_TYPES)
            if num_kv_layers > 0:
                # cache_params lays out all non-KV entries first, then key/value pairs;
                # the first KV entry therefore sits right after the non-KV block.
                kv_cache_offset = self._NON_KV_ENTRIES_PER_LAYER * num_non_kv_layers
                past_key = dummy_inputs["cache_params"][kv_cache_offset]
                # KV tensors use dim 2 as the sequence axis (see subclasses'
                # add_past_key_values: {0: "batch_size", 2: decoder_sequence_name}).
                past_seq_len = past_key.shape[2]
                if past_seq_len > 0:
                    past_present_length = dummy_inputs["input_ids"].shape[1] + past_seq_len
                    dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
                        dummy_inputs["attention_mask"],
                        desired_length=past_present_length,
                        dim=1,
                        dtype=dummy_inputs["attention_mask"].dtype,
                    )

        return dummy_inputs


@register_in_tasks_manager(
"gpt2",
*[
Expand Down Expand Up @@ -4879,7 +4957,10 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
],
library_name="transformers",
)
class LFM2OpenVINOConfig(MambaOpenVINOConfig):
class LFM2OpenVINOConfig(HybridCacheOpenVINOConfig):
_NON_KV_LAYER_TYPES = ("conv",)
_KV_LAYER_TYPES = ("full_attention",)
_NON_KV_ENTRIES_PER_LAYER = 1
MIN_TRANSFORMERS_VERSION = "4.54.0"
_MODEL_PATCHER = Lfm2ModelPatcher

Expand Down Expand Up @@ -4907,24 +4988,15 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name}

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
return common_inputs


@register_in_tasks_manager(
"granitemoehybrid", *["text-generation", "text-generation-with-past"], library_name="transformers"
)
class GraniteMoeHybridOpenVINOConfig(MambaOpenVINOConfig):
class GraniteMoeHybridOpenVINOConfig(HybridCacheOpenVINOConfig):
_NON_KV_LAYER_TYPES = ("mamba",)
_KV_LAYER_TYPES = ("attention",)
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Zamba2DummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = Zamba2DummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
MIN_TRANSFORMERS_VERSION = "4.53.0"
_MODEL_PATCHER = GraniteMoeHybridModelPatcher

Expand All @@ -4951,16 +5023,6 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name}

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
return common_inputs


@register_in_tasks_manager("audio-spectrogram-transformer", *["feature-extraction", "audio-classification"])
class ASTOpenVINOConfig(ASTOnnxConfig):
Expand Down Expand Up @@ -5384,10 +5446,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
*["text-generation", "text-generation-with-past"],
library_name="transformers",
)
class Qwen3NextOpenVINOConfig(Qwen3OpenVINOConfig):
class Qwen3NextOpenVINOConfig(HybridCacheOpenVINOConfig):
_NON_KV_LAYER_TYPES = ("linear_attention",)
_KV_LAYER_TYPES = ("full_attention",)
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Qwen3NextDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = Qwen3NextDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
MIN_TRANSFORMERS_VERSION = "4.57.0"
_MODEL_PATCHER = Qwen3NextModelPatcher

Expand All @@ -5412,42 +5475,3 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
for i in range(self.num_full_attn_layers):
inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name}

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
return common_inputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
# need to override `generate_dummy_inputs` since mamba model has other states: ssm_states and conv_states
# which we separate and call them as past_ssm_states and past_conv_states
dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)

dummy_inputs = {}
input_names = [key for key in self.inputs.keys() if not key.startswith("cache_params")]
if self.use_past_in_inputs:
input_names.extend(["cache_params"])

for input_name in input_names:
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
dummy_input_gen,
input_name,
framework,
input_shapes=kwargs,
)
input_was_inserted = True
break
if not input_was_inserted:
raise RuntimeError(
f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
)

return dummy_inputs
Loading