diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 7ffe158396..b0c5dd16f7 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -148,6 +148,7 @@ FluxTransfromerModelPatcher, Gemma2ModelPatcher, Gemma3LMModelPatcher, + Gemma3nLMModelPatcher, GptJModelPatcher, GptNeoModelPatcher, GptNeoxModelPatcher, @@ -201,6 +202,8 @@ SanaTextEncoderModelPatcher, XverseModelPatcher, Zamba2ModelPatcher, + Gemma3nPerLayerInputsGetterModelPatcher, + Gemma3nImageEmbeddingsModelPatcher, ) @@ -261,6 +264,10 @@ def init_model_configs(): "transformers", "Gemma3ForConditionalGeneration", ) + TasksManager._CUSTOM_CLASSES[("pt", "gemma3n", "image-text-to-text")] = ( + "transformers", + "Gemma3nForConditionalGeneration", + ) TasksManager._CUSTOM_CLASSES[("pt", "idefics3", "image-text-to-text")] = ( "transformers", "AutoModelForImageTextToText", @@ -1480,6 +1487,105 @@ class Gemma3TextOpenVINOConfig(Gemma2OpenVINOConfig): MIN_TRANSFORMERS_VERSION = "4.50.0" +class Gemma3nDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + ) + self.num_key_value_heads = normalized_config.num_key_value_heads + self.head_dim = normalized_config.head_dim + self.layer_types = normalized_config.config.layer_types + self.num_kv_shared_layers = normalized_config.config.num_kv_shared_layers + self.sliding_window = 
normalized_config.config.sliding_window + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + # some layers do not produce their own KV-cache, they use the shared KV-cache + if self.num_kv_shared_layers > 0: + layer_types = self.layer_types[: -self.num_kv_shared_layers] + else: + layer_types = self.layer_types + past_kv_values = [] + for layer_type in layer_types: + if layer_type == "sliding_attention": + shape = ( + self.batch_size, + self.num_key_value_heads, + self.sliding_window, + self.head_dim, + ) + else: + shape = ( + self.batch_size, + self.num_key_value_heads, + self.sequence_length, + self.head_dim, + ) + past_kv_value = ( + self.random_float_tensor(shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(shape, framework=framework, dtype=float_dtype), + ) + past_kv_values.append(past_kv_value) + + return past_kv_values + + +@register_in_tasks_manager( + "gemma3n_text", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + ], + library_name="transformers", +) +class Gemma3nTextOpenVINOConfig(Gemma3TextOpenVINOConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Gemma3nDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = Gemma3nDummyPastKeyValuesGenerator + MIN_TRANSFORMERS_VERSION = "4.50.0" + + def add_past_key_values(self, inputs_or_outputs: dict[str, dict[int, str]], direction: str): + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + sequence_length" + name = "present" + + num_kv_shared_layers = self._normalized_config.config.num_kv_shared_layers + if 
num_kv_shared_layers > 0: + layer_types = self._normalized_config.config.layer_types[:-num_kv_shared_layers] + else: + layer_types = self._normalized_config.config.layer_types + + for i, layer_type in enumerate(layer_types): + if layer_type == "sliding_attention": + inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 2: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name} + else: + inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 2: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name} + + class DeciDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def __init__( self, @@ -1721,6 +1827,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): dummy_inputs["token_type_ids"] = self.orig_export_config.DUMMY_INPUT_GENERATOR_CLASSES[ 0 ].random_int_tensor(token_type_ids_shape, min_value=0, max_value=2) + if "per_layer_inputs" in self.inputs: + per_layer_inputs_shape = (input_ids.shape[0], input_ids.shape[1], 30, 256) + dummy_inputs["per_layer_inputs"] = self.orig_export_config.DUMMY_INPUT_GENERATOR_CLASSES[ + 0 + ].random_int_tensor(per_layer_inputs_shape, min_value=0, max_value=2) return dummy_inputs @@ -4170,6 +4281,162 @@ def with_behavior( return super().with_behavior(behavior) +class Gemma3nConfigBehavior(str, enum.Enum): + VISION_EMBEDDINGS = "vision_embeddings" + TEXT_EMBEDDINGS = "text_embeddings" + LANGUAGE = "language" + TEXT_EMBEDDINGS_PER_LAYER = "text_embeddings_per_layer" + + +@register_in_tasks_manager("gemma3n", *["image-text-to-text"], library_name="transformers") +class Gemma3nOpenVINOConfig(Gemma3OpenVINOConfig): + SUPPORTED_BEHAVIORS = [model_type.value for model_type in Gemma3nConfigBehavior] + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyTextInputGenerator) + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + 
float_dtype: str = "fp32", + behavior: Gemma3nConfigBehavior = Gemma3nConfigBehavior.VISION_EMBEDDINGS, + preprocessors: Optional[List[Any]] = None, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + behavior=behavior, + ) + self._behavior = behavior + + + def with_behavior(self, behavior: Union[str, Gemma3nConfigBehavior]): + """ + Creates a config for different behaviour specific to Gemma3n. + + For LANGUAGE behavior, this explicitly uses the Gemma3n text model_type + instead of relying on the underlying text_config.model_type value. + """ + if isinstance(behavior, str) and not isinstance(behavior, Gemma3nConfigBehavior): + behavior = Gemma3nConfigBehavior(behavior) + + if behavior == Gemma3nConfigBehavior.LANGUAGE: + # Force the Gemma3n-specific text model type to ensure proper behavior + model_type = "gemma3n_text" + return get_vlm_text_generation_config( + model_type, + self._orig_config.text_config, + self.int_dtype, + self.float_dtype, + model_patcher=Gemma3nLMModelPatcher, + inputs_update={"per_layer_inputs": {0: "batch_size", 1: "sequence_length", 2: "num_hidden_layers"}}, + ) + if behavior == Gemma3nConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER: + config = self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + return config + return super().with_behavior(behavior) + + def get_model_for_behavior(self, model, behavior: Union[str, VLMConfigBehavior]): + if behavior == Gemma3nConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER: + import torch + class PerLayerInputsModule(torch.nn.Module): + def __init__(self, language_model, vocab_size_per_layer_input: int): + super().__init__() + self.language_model = language_model + self.vocab_size_per_layer_input = vocab_size_per_layer_input + + def forward(self, input_ids: torch.Tensor): + per_layer_inputs_mask = 
torch.logical_and( + input_ids >= 0, + input_ids < self.vocab_size_per_layer_input + ) + + per_layer_inputs_tokens = torch.where( + per_layer_inputs_mask, + input_ids, + torch.zeros_like(input_ids) + ) + + per_layer_inputs = self.language_model.get_per_layer_inputs( + per_layer_inputs_tokens + ) + + return per_layer_inputs + + model = PerLayerInputsModule(model.language_model, model.config.text_config.vocab_size_per_layer_input) + return model + # + # if behavior == VLMConfigBehavior.LANGUAGE: + # return model.language_model + if behavior == VLMConfigBehavior.VISION_EMBEDDINGS: + return model + + if behavior == VLMConfigBehavior.TEXT_EMBEDDINGS: + import torch + class TextEmbeddingsModule(torch.nn.Module): + def __init__(self, model): #, vocab_size_per_layer_input: int): + super().__init__() + self.model = model + # self.vocab_size_per_layer_input = vocab_size_per_layer_input + + def forward(self, input_ids: torch.Tensor): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + vision_mask = torch.logical_and( + input_ids >= self.model.model.embed_vision.vocab_offset, input_ids < self.model.model.embed_audio.vocab_offset + ) + dummy_vision_token_id = self.model.model.embed_vision.vocab_offset + self.model.model.embed_vision.vocab_size - 1 + vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to( + inputs_embeds.device) + vision_embeds = self.model.model.embed_vision(input_ids=vision_input_ids) + expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds) + + return inputs_embeds + + text_embedding = TextEmbeddingsModule(model) + text_embedding.config = model.language_model.config + return text_embedding + + return super().get_model_for_behavior(model, behavior) + + def patch_model_for_export(self, model: Union["PreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None): + model_kwargs = model_kwargs or {} + if self._behavior == 
Gemma3nConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER: + return ModelPatcher(self, model, model_kwargs) + if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS: + return Gemma3nImageEmbeddingsModelPatcher(self, model, model_kwargs) + return super().patch_model_for_export(model, model_kwargs) + + + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + + if self._behavior == Gemma3nConfigBehavior.LANGUAGE: + inputs = super().inputs + return inputs + if self._behavior == Gemma3nConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER: + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + } + return super().inputs + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior == Gemma3nConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER: + return {"text_embeds_per_layer": {}} + return super().outputs + + class DummyVisionPositionIdsInputGenerator(DummyVisionInputGenerator): SUPPORTED_INPUT_NAMES = ("patch_attention_mask", "patch_position_ids") diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 557cd1f8d1..f21e45692d 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -4745,6 +4745,378 @@ def __exit__(self, exc_type, exc_value, traceback): del self._model.model._orig_update_causual_mask +def _project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, +) -> torch.Tensor: + per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds) + per_layer_projection = ( + self.per_layer_projection_scale.to(dtype=inputs_embeds.dtype, device=per_layer_projection.device) + * per_layer_projection + ) + + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + + if per_layer_inputs is None: + return 
per_layer_projection + + if per_layer_projection.shape != per_layer_inputs.shape: + # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings. + per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] + + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to( + dtype=inputs_embeds.dtype, device=per_layer_projection.device + ) + + +def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor: + target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device) + + def normal_icdf_approx(p): + p = torch.clamp(p, 1e-7, 1 - 1e-7) + a1 = -3.969683028665376e01 + a2 = 2.209460984245205e02 + a3 = -2.759285104469687e02 + a4 = 1.383577518672690e02 + a5 = -3.066479806614716e01 + a6 = 2.506628277459239e00 + b1 = -5.447609879822406e01 + b2 = 1.615858368580409e02 + b3 = -1.556989798598866e02 + b4 = 6.680131188771972e01 + b5 = -1.328068155288572e01 + q = p - 0.5 + r = q * q + num = (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q + den = ((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1.0 + return num / den + + std_multiplier = normal_icdf_approx(target_sparsity_tensor) + std_multiplier = std_multiplier.type(inputs.dtype) + inputs_mean = torch.mean(inputs, dim=-1, keepdim=True) + inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False) + cutoff_x = inputs_mean + inputs_std * std_multiplier + return nn.functional.relu(inputs - cutoff_x) + + +def gemma3n_language_model_forward( + self, + input_ids: Optional[torch.LongTensor] = None, # text inputs + pixel_values: Optional[torch.FloatTensor] = None, # vision inputs + input_features: Optional[torch.FloatTensor] = None, # audio inputs + attention_mask: Optional[torch.Tensor] = None, + input_features_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + token_type_ids: Optional[torch.LongTensor] = None, + 
cache_position: Optional[torch.LongTensor] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    per_layer_inputs = None,
+    **lm_kwargs,
+):
+    from transformers.models.gemma3n.modeling_gemma3n import Gemma3nModelOutputWithPast
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+
+    # Merge text and images
+    if pixel_values is not None:
+        image_features = self.get_image_features(pixel_values)
+        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        special_image_mask, _ = self.get_placeholder_mask(
+            input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+        )
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+    # Merge text and audio
+    if input_features is not None and input_features_mask is not None:
+        audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask)
+
+        # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
+        # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
+        # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
+        # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
+        # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+ audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device) + audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks) + audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features) + + audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape + extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len + extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim) + + audio_features = torch.cat((audio_features, extra_padding_features), dim=1) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + _, special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) + + outputs = self.language_model( + input_ids=None, + per_layer_inputs=per_layer_inputs, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **lm_kwargs, + ) + + return Gemma3nModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values if use_cache else None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + audio_hidden_states=audio_features if input_features is not None else None, + ) + + +def gemma3n_lm_forward( + self, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + per_layer_inputs = None, + input_ids: Optional[torch.LongTensor] = None, # text inputs + 
pixel_values: Optional[torch.FloatTensor] = None, # vision inputs + input_features: Optional[torch.FloatTensor] = None, # audio inputs + input_features_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, +): + from transformers.models.gemma3n.modeling_gemma3n import Gemma3nCausalLMOutputWithPast + from optimum.exporters.onnx.model_patcher import preprocess_past_key_values, postprocess_past_key_values + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = False + + if past_key_values is not None: + use_cache = True + num_atten_layers = len(past_key_values) + past_key_values = preprocess_past_key_values(past_key_values) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + input_features=input_features, + attention_mask=attention_mask, + input_features_mask=input_features_mask, + position_ids=position_ids, + past_key_values=past_key_values, + token_type_ids=token_type_ids, + cache_position=cache_position, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + per_layer_inputs=per_layer_inputs, + **lm_kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + tmp_logits = self.lm_head(hidden_states[:, 
slice_indices, :]) + if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None: + tmp_logits = tmp_logits / final_logit_softcapping + tmp_logits = torch.tanh(tmp_logits) + tmp_logits = tmp_logits * final_logit_softcapping + + outputs_dict = { + "logits": tmp_logits, + } + + if use_cache: + key_values = outputs.past_key_values + present_key_values = postprocess_past_key_values(key_values, ['logits', 'present.0.key', 'present.0.value', 'present.1.key', 'present.1.value', 'present.2.key', 'present.2.value', 'present.3.key', 'present.3.value', 'present.4.key', 'present.4.value', 'present.5.key', 'present.5.value', 'present.6.key', 'present.6.value', 'present.7.key', 'present.7.value', 'present.8.key', 'present.8.value', 'present.9.key', 'present.9.value', 'present.10.key', 'present.10.value', 'present.11.key', 'present.11.value', 'present.12.key', 'present.12.value', 'present.13.key', 'present.13.value', 'present.14.key', 'present.14.value', 'present.15.key', 'present.15.value', 'present.16.key', 'present.16.value', 'present.17.key', 'present.17.value', 'present.18.key', 'present.18.value', 'present.19.key', 'present.19.value']) + outputs_dict["past_key_values"] = present_key_values + return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs_dict.values()]) + + +def gemma3n_eager_attention_forward_patched( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap + 
attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + eps = 0.01 + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states + eps) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def gemma3n_text_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + from collections.abc import Callable + from transformers.models.gemma3n.modeling_gemma3n import apply_rotary_pos_emb as apply_rotary_pos_emb_gemma3n + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.config.head_dim) + + cos, sin = position_embeddings + + query_states = self.q_proj(hidden_states).view(hidden_shape) + query_states = self.q_norm(query_states) + query_states = apply_rotary_pos_emb_gemma3n(query_states, cos, sin, unsqueeze_dim=2) + query_states = query_states.transpose(1, 2) + + # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer + if self.is_kv_shared_layer and past_key_values is not None: + key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index] + # Device of past layer may be different from current one + key_states = key_states.to(query_states.device) + value_states = value_states.to(query_states.device) + else: + key_states = 
self.k_proj(hidden_states).view(hidden_shape) + key_states = self.k_norm(key_states) + key_states = apply_rotary_pos_emb_gemma3n(key_states, cos, sin, unsqueeze_dim=2) + key_states = key_states.transpose(1, 2) + + value_states = self.v_proj(hidden_states).view(hidden_shape) + value_states = self.v_norm(value_states) + value_states = value_states.transpose(1, 2) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } + if not self.is_kv_shared_layer: + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + if self.store_full_length_kv: + if not hasattr(past_key_values, "shared_layers"): + past_key_values.shared_layers = {} + past_key_values.shared_layers[self.layer_idx] = key_states, value_states + + attention_interface: Callable = gemma3n_eager_attention_forward_patched + # if self.config._attn_implementation != "eager": + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=1.0, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Gemma3nLMModelPatcher(Gemma3LMModelPatcher): + def __init__(self, config, model, model_kwargs): + super().__init__(config, model, model_kwargs) + + self.patched_forward = gemma3n_lm_forward + self.model_orig_forward = self.orig_forward + self.orig_forward = gemma3n_lm_forward + + self.model_orig_language_model_forward = self._model.model.forward + + def __enter__(self): + super().__enter__() + + 
setattr(self._model, self.orig_forward_name, types.MethodType(gemma3n_lm_forward, self._model)) + setattr(self._model.model, "forward", types.MethodType(gemma3n_language_model_forward, self._model)) + + self._model.model.language_model._orig_project_per_layer_inputs = ( + self._model.model.language_model.project_per_layer_inputs + ) + self._model.model.language_model.project_per_layer_inputs = types.MethodType( + _project_per_layer_inputs, self._model.model.language_model + ) + + for decoder_layer in self._model.model.language_model.layers: + decoder_layer.mlp._orig_gaussian_topk = decoder_layer.mlp._gaussian_topk + decoder_layer.mlp._gaussian_topk = types.MethodType(_gaussian_topk, decoder_layer.mlp) + decoder_layer.self_attn.orig_forward = decoder_layer.self_attn.forward + decoder_layer.self_attn.forward = types.MethodType(gemma3n_text_forward, decoder_layer.self_attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.model.language_model.project_per_layer_inputs = ( + self._model.model.language_model._orig_project_per_layer_inputs + ) + + for decoder_layer in self._model.model.language_model.layers: + decoder_layer.mlp._gaussian_topk = decoder_layer.mlp._orig_gaussian_topk + decoder_layer.self_attn.forward = decoder_layer.self_attn.orig_forward + + setattr(self._model, self.orig_forward_name, self.model_orig_forward) + setattr(self._model.model, "forward", self.model_orig_language_model_forward) + + class Idefics3ImageEmbeddingsModelPatcher(ModelPatcher): def __init__( self, @@ -7899,3 +8271,78 @@ def forward( hidden_states=outputs.hidden_states, d2t=d2t_out, ) + + +class Gemma3nPerLayerInputsGetterModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel"], + model_kwargs: Dict[str, Any] = None, + ): + model.__orig_forward = model.forward + + def per_layer_inputs_forward( + self, input_ids: torch.Tensor) -> torch.Tensor: + 
per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input) + per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)) + per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens) + return per_layer_inputs + + model.forward = types.MethodType(per_layer_inputs_forward, model) + super().__init__(config, model, model_kwargs) + + def __enter__(self): + super().__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + + +# Patched forward for MobileNetV5MultiScaleFusionAdapter (MSFA) used by the Gemma3n vision tower. +# The original MSFA forward has data-dependent control flow that branches on tensor spatial +# dimensions to choose between F.interpolate and F.avg_pool2d for resizing to output_resolution. +# torch.jit.trace bakes only the path taken with dummy inputs, causing incorrect results when +# actual inference images have different spatial dimensions. 
+# This patch replaces the conditional resize with F.adaptive_avg_pool2d which: +# - accepts a constant output_size making it trace-friendly +# - is equivalent to avg_pool2d when input dims are evenly divisible by output dims +# - handles identity (no-op) when input dims == output dims +# Adopted from timm.models.mobilenetv5.MobileNetV5MultiScaleFusionAdapter.forward +def _gemma3n_msfa_forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + high_resolution = inputs[0].shape[-2:] + resized_inputs = [] + for img in inputs: + feat_size = img.shape[-2:] + if feat_size[0] < high_resolution[0] or feat_size[1] < high_resolution[1]: + img = F.interpolate(img, size=high_resolution, mode=self.interpolation_mode) + resized_inputs.append(img) + + channel_cat_imgs = torch.cat(resized_inputs, dim=1) + img = self.ffn(channel_cat_imgs) + + # Use adaptive_avg_pool2d instead of conditional avg_pool2d / interpolate. + # output_resolution is a constant tuple, so this is trace-friendly. + img = F.adaptive_avg_pool2d(img, self.output_resolution) + + img = self.norm(img) + return img + + +class Gemma3nImageEmbeddingsModelPatcher(CommonImageEmbeddingsModelPatcher): + def __enter__(self): + super().__enter__() + # Patch MSFA forward to be trace-friendly + vision_tower = self._model.model.vision_tower + timm_model = vision_tower.timm_model + if hasattr(timm_model, "msfa") and timm_model.msfa is not None: + timm_model.msfa._orig_forward = timm_model.msfa.forward + timm_model.msfa.forward = types.MethodType(_gemma3n_msfa_forward, timm_model.msfa) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + vision_tower = self._model.model.vision_tower + timm_model = vision_tower.timm_model + if hasattr(timm_model, "msfa") and timm_model.msfa is not None and hasattr(timm_model.msfa, "_orig_forward"): + timm_model.msfa.forward = timm_model.msfa._orig_forward \ No newline at end of file diff --git a/optimum/exporters/openvino/utils.py 
b/optimum/exporters/openvino/utils.py index 3d9a854e39..5d47d4b1d5 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -297,6 +297,7 @@ def get_submodels(model): "qwen3_vl", "got_ocr2", "gemma3", + "gemma3n", "idefics3", "smolvlm", "phi4mm", diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index 2fe8cb0ea0..a8fcf06fbe 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -218,6 +218,18 @@ def prepare_inputs( inputs["beam_idx"] = ( self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int) ) + + if "per_layer_inputs" in self.input_names: + per_layer_inputs = kwargs.pop("per_layer_inputs", None) + assert per_layer_inputs is not None, "Expected 'per_layer_inputs', but it was not passed" + inputs["per_layer_inputs"] = torch.Tensor(per_layer_inputs) + # if deepstack_visual_embeds is not None: + # inputs["deepstack_visual_embeds"] = torch.Tensor(deepstack_visual_embeds) + # else: + # num_layers = len(self.config.vision_config.deepstack_visual_indexes) + # emd_dim = self.config.text_config.hidden_size + # inputs["deepstack_visual_embeds"] = torch.zeros((num_layers, 1, emd_dim), dtype=torch.float32) + return inputs def forward( @@ -347,6 +359,7 @@ def forward(self, audio_feature, audio_mask): MODEL_PARTS_CLS_MAPPING = { "resampler": OVResampler, "language_model": OVModelWithEmbedForCausalLM, + "text_embeddings_per_layer": OVVisionProjection, "vision_embeddings": OVVisionEmbedding, "vision_projection": OVVisionProjection, "vision_resampler": OVVisionResampler, @@ -785,8 +798,11 @@ def forward( additional_kwargs["visual_pos_masks"] = extra_outputs[0] additional_kwargs["deepstack_visual_embeds"] = extra_outputs[1] - return self.language_model.forward( - input_ids=None, + if self.config.model_type in ("gemma3n",) and extra_outputs: + 
additional_kwargs["per_layer_inputs"] = extra_outputs[0] + + out = self.language_model.forward( + input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, @@ -796,6 +812,10 @@ def forward( **kwargs, ) + + + return out + def _reorder_cache(self, past_key_values, beam_idx): return self.language_model._reorder_cache(past_key_values, beam_idx) @@ -3868,7 +3888,10 @@ def merge_vision_text_embeddings( self.get_text_embeddings(torch.tensor([[self.config.image_token_index]], dtype=torch.long))[0] ) else: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + if self.config.model_type == "gemma3n": + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + else: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds) image_features = image_features.to(inputs_embeds.dtype) @@ -3937,6 +3960,32 @@ def _update_model_kwargs_for_generation( return model_kwargs +class _OVGemma3NForCausalLM(_OVGemma3ForCausalLM): + additional_parts = ["text_embeddings_per_layer"] + + def get_multimodal_embeddings( + self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, **kwargs + ): + embeds_from_args = kwargs.pop("inputs_embeds", None) + inputs_embeds = ( + embeds_from_args if embeds_from_args is not None else self.get_text_embeddings(input_ids, **kwargs) + ) + per_layer_inputs = self.text_embeddings_per_layer(input_ids) + if pixel_values is not None: + vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, **kwargs) + + if vision_embeds is not None: + inputs_embeds, attention_mask, position_ids = self.merge_vision_text_embeddings( + vision_embeds, + inputs_embeds, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **kwargs, + ) + return inputs_embeds, attention_mask, position_ids, per_layer_inputs + + class 
_OVGotOCR2ForCausalLM(OVModelForVisualCausalLM): def get_vision_embeddings(self, pixel_values, input_ids, **kwargs): if input_ids is not None and input_ids.shape[1] == 1 and kwargs.get("past_key_values") is not None: @@ -4817,6 +4866,7 @@ def preprocess_inputs( "qwen2_5_vl_text": _OVQwen2_5_VLForCausalLM, "got_ocr2": _OVGotOCR2ForCausalLM, "gemma3": _OVGemma3ForCausalLM, + "gemma3n": _OVGemma3NForCausalLM, "idefics3": _OVIdefics3ForCausalLM, "smolvlm": _OVSmolVLForCasualLM, "phi4mm": _OVPhi4MMForCausalLM,