267 changes: 267 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -148,6 +148,7 @@
FluxTransfromerModelPatcher,
Gemma2ModelPatcher,
Gemma3LMModelPatcher,
Gemma3nLMModelPatcher,
GptJModelPatcher,
GptNeoModelPatcher,
GptNeoxModelPatcher,
@@ -201,6 +202,8 @@
SanaTextEncoderModelPatcher,
XverseModelPatcher,
Zamba2ModelPatcher,
Gemma3nPerLayerInputsGetterModelPatcher,
Gemma3nImageEmbeddingsModelPatcher,
)


@@ -261,6 +264,10 @@ def init_model_configs():
"transformers",
"Gemma3ForConditionalGeneration",
)
TasksManager._CUSTOM_CLASSES[("pt", "gemma3n", "image-text-to-text")] = (
"transformers",
"Gemma3nForConditionalGeneration",
)
TasksManager._CUSTOM_CLASSES[("pt", "idefics3", "image-text-to-text")] = (
"transformers",
"AutoModelForImageTextToText",
@@ -1480,6 +1487,105 @@ class Gemma3TextOpenVINOConfig(Gemma2OpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "4.50.0"


class Gemma3nDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
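"""Dummy KV-cache generator for Gemma3n: layers that reuse the shared KV-cache are
skipped, and sliding-attention layers get a fixed-size cache of length `sliding_window`."""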
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
)
self.num_key_value_heads = normalized_config.num_key_value_heads
self.head_dim = normalized_config.head_dim
self.layer_types = normalized_config.config.layer_types
self.num_kv_shared_layers = normalized_config.config.num_kv_shared_layers
self.sliding_window = normalized_config.config.sliding_window

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
# some layers do not produce their own KV-cache; they reuse the shared KV-cache
if self.num_kv_shared_layers > 0:
layer_types = self.layer_types[: -self.num_kv_shared_layers]
else:
layer_types = self.layer_types
past_kv_values = []
for layer_type in layer_types:
if layer_type == "sliding_attention":
shape = (
self.batch_size,
self.num_key_value_heads,
self.sliding_window,
self.head_dim,
)
else:
shape = (
self.batch_size,
self.num_key_value_heads,
self.sequence_length,
self.head_dim,
)
past_kv_value = (
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
)
past_kv_values.append(past_kv_value)

return past_kv_values


@register_in_tasks_manager(
"gemma3n_text",
*[
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
],
library_name="transformers",
)
class Gemma3nTextOpenVINOConfig(Gemma3TextOpenVINOConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Gemma3nDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = Gemma3nDummyPastKeyValuesGenerator
MIN_TRANSFORMERS_VERSION = "4.50.0"

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + sequence_length"
name = "present"

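# layers that reuse the shared KV-cache expose no past/present tensors of their own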
num_kv_shared_layers = self._normalized_config.config.num_kv_shared_layers
if num_kv_shared_layers > 0:
layer_types = self._normalized_config.config.layer_types[:-num_kv_shared_layers]
else:
layer_types = self._normalized_config.config.layer_types

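# sliding and global attention layers currently share the same dynamic axes;
# only shared-KV layers differ, and those are excluded above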
for i in range(len(layer_types)):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name}


class DeciDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
Expand Down Expand Up @@ -1721,6 +1827,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
dummy_inputs["token_type_ids"] = self.orig_export_config.DUMMY_INPUT_GENERATOR_CLASSES[
0
].random_int_tensor(token_type_ids_shape, min_value=0, max_value=2)
if "per_layer_inputs" in self.inputs:
per_layer_inputs_shape = (input_ids.shape[0], input_ids.shape[1], 30, 256)
dummy_inputs["per_layer_inputs"] = self.orig_export_config.DUMMY_INPUT_GENERATOR_CLASSES[
0
].random_float_tensor(per_layer_inputs_shape)
return dummy_inputs


@@ -4170,6 +4281,162 @@ def with_behavior(
return super().with_behavior(behavior)


class Gemma3nConfigBehavior(str, enum.Enum):
VISION_EMBEDDINGS = "vision_embeddings"
TEXT_EMBEDDINGS = "text_embeddings"
LANGUAGE = "language"
TEXT_EMBEDDINGS_PER_LAYER = "text_embeddings_per_layer"


@register_in_tasks_manager("gemma3n", *["image-text-to-text"], library_name="transformers")
class Gemma3nOpenVINOConfig(Gemma3OpenVINOConfig):
SUPPORTED_BEHAVIORS = [model_type.value for model_type in Gemma3nConfigBehavior]
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyTextInputGenerator)

def __init__(
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
behavior: Gemma3nConfigBehavior = Gemma3nConfigBehavior.VISION_EMBEDDINGS,
preprocessors: Optional[List[Any]] = None,
):
super().__init__(
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
behavior=behavior,
)
self._behavior = behavior

def with_behavior(self, behavior: Union[str, Gemma3nConfigBehavior]):
"""
Creates an export config for the requested Gemma3n-specific behavior.

For the LANGUAGE behavior, this explicitly uses the Gemma3n text model_type
instead of relying on the underlying text_config.model_type value.
"""
if isinstance(behavior, str) and not isinstance(behavior, Gemma3nConfigBehavior):
behavior = Gemma3nConfigBehavior(behavior)

if behavior == Gemma3nConfigBehavior.LANGUAGE:
# Force the Gemma3n-specific text model type to ensure proper behavior
model_type = "gemma3n_text"
return get_vlm_text_generation_config(
model_type,
self._orig_config.text_config,
self.int_dtype,
self.float_dtype,
model_patcher=Gemma3nLMModelPatcher,
inputs_update={"per_layer_inputs": {0: "batch_size", 1: "sequence_length", 2: "num_hidden_layers"}},
)
if behavior == Gemma3nConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER:
config = self.__class__(
self._orig_config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
behavior=behavior,
preprocessors=self._preprocessors,
)
return config
return super().with_behavior(behavior)

def get_model_for_behavior(self, model, behavior: Union[str, Gemma3nConfigBehavior]):
if behavior == Gemma3nConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER:
import torch

class PerLayerInputsModule(torch.nn.Module):
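"""Wrapper exposing `language_model.get_per_layer_inputs` as a standalone export graph."""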
def __init__(self, language_model, vocab_size_per_layer_input: int):
super().__init__()
self.language_model = language_model
self.vocab_size_per_layer_input = vocab_size_per_layer_input

def forward(self, input_ids: torch.Tensor):
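# ids below vocab_size_per_layer_input have dedicated per-layer embeddings;
# all other positions are mapped to token 0 before the lookup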
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0,
input_ids < self.vocab_size_per_layer_input
)

per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask,
input_ids,
torch.zeros_like(input_ids)
)

per_layer_inputs = self.language_model.get_per_layer_inputs(
per_layer_inputs_tokens
)

return per_layer_inputs

model = PerLayerInputsModule(model.language_model, model.config.text_config.vocab_size_per_layer_input)
return model

if behavior == Gemma3nConfigBehavior.VISION_EMBEDDINGS:
return model

if behavior == Gemma3nConfigBehavior.TEXT_EMBEDDINGS:
import torch

class TextEmbeddingsModule(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids: torch.Tensor):
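# replace embeddings of vision-vocabulary tokens with embeddings from embed_vision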
inputs_embeds = self.model.get_input_embeddings()(input_ids)
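# ids in [embed_vision.vocab_offset, embed_audio.vocab_offset) belong to the vision vocabulary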
vision_mask = torch.logical_and(
input_ids >= self.model.model.embed_vision.vocab_offset,
input_ids < self.model.model.embed_audio.vocab_offset,
)
dummy_vision_token_id = self.model.model.embed_vision.vocab_offset + self.model.model.embed_vision.vocab_size - 1
vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
vision_embeds = self.model.model.embed_vision(input_ids=vision_input_ids)
expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)

return inputs_embeds

text_embedding = TextEmbeddingsModule(model)
text_embedding.config = model.language_model.config
return text_embedding

return super().get_model_for_behavior(model, behavior)

def patch_model_for_export(self, model: Union["PreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None):
model_kwargs = model_kwargs or {}
if self._behavior == Gemma3nConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER:
return Gemma3nPerLayerInputsGetterModelPatcher(self, model, model_kwargs)
if self._behavior == Gemma3nConfigBehavior.VISION_EMBEDDINGS:
return Gemma3nImageEmbeddingsModelPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior == Gemma3nConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER:
return {
"input_ids": {0: "batch_size", 1: "sequence_length"},
}
return super().inputs

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior == Gemma3nConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER:
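# per-layer embeddings are (batch, seq, num_hidden_layers, hidden_size_per_layer_input);
# assuming only batch and sequence need to be dynamic axes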
return {"text_embeds_per_layer": {}}
return super().outputs


class DummyVisionPositionIdsInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("patch_attention_mask", "patch_position_ids")
