diff --git a/docs/source/en/model_doc/gemma3.md b/docs/source/en/model_doc/gemma3.md new file mode 100644 index 000000000000..0f466bb55bf2 --- /dev/null +++ b/docs/source/en/model_doc/gemma3.md @@ -0,0 +1,203 @@ + + + +# Gemma3 + +## Overview + +The Gemma 3 model was proposed in the [Gemma 3 Technical Report](https://goo.gle/Gemma3Report) by Google. It is a vision-language model composed of a [SigLIP](siglip) vision encoder and a [Gemma 2](gemma_2) language decoder, linked by a multimodal linear projection. The model encodes an image into a fixed number of tokens, in the same way as SigLIP, as long as the image does not exceed a certain aspect ratio. For images that exceed this aspect ratio, it crops the image into multiple smaller patches and concatenates them with the base image embedding. One particularity is that the model uses bidirectional attention on all the image tokens. In addition, the model interleaves sliding window local attention with full causal attention in the language backbone, where every sixth layer is a full causal attention layer. + +This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins), [Raushan Turganbay](https://huggingface.co/RaushanTurganbay), [Arthur Zucker](https://huggingface.co/ArthurZ), and [Pedro Cuenca](https://huggingface.co/pcuenq). + + +## Usage tips + + +- For image+text and image-only inputs use `Gemma3ForConditionalGeneration`. +- For text-only inputs use `Gemma3ForCausalLM` for generation to avoid loading the vision tower. +- Each sample can contain multiple images, and the number of images can vary between samples. However, make sure to pass correctly batched images to the processor, where each batch is a list of one or more images; see the batched sketch right after this list. +- The text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted. +- The processor has its own `apply_chat_template` method to convert chat messages to model inputs. See the examples below for more details on how to use it.
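Because padding is applied on the left, prompts of different lengths and with different numbers of images can be generated together in one batch. The sketch below reuses the instruction-tuned checkpoint and image URLs from the examples further down; it assumes your installed `transformers` version accepts a list of conversations in `processor.apply_chat_template` and forwards `padding=True` to the tokenizer. If it does not, tokenize each conversation separately and pad the batch yourself.

```python
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")

url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg"

# One conversation per sample: the first sample contains a single image,
# the second sample contains two images.
messages = [
    [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": url_cow},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ],
    [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": url_cow},
                {"type": "image", "url": url_stop},
                {"type": "text", "text": "Are these two images identical?"},
            ],
        },
    ],
]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    padding=True,
).to(model.device)

output = model.generate(**inputs, max_new_tokens=50)
# Strip the (left-padded) prompt tokens before decoding so only the generations remain.
print(processor.batch_decode(output[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```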
+ + +### Image cropping for high resolution images + +The model supports cropping images into smaller patches when the image aspect ratio exceeds a certain value. By default the images are not cropped and only the base image is forwarded to the model. Users can set `do_pan_and_scan=True` to obtain several crops per image along with the base image to improve the quality in DocVQA or similar tasks requiring higher resolution images. + +Pan and scan is an inference-time optimization to handle images with skewed aspect ratios. When enabled, it improves performance on tasks related to document understanding, infographics, OCR, etc. 
+ +```python + +from transformers import AutoProcessor, Gemma3ForConditionalGeneration + +model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it", device_map="auto") +processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it", padding_side="left") + +url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4=" +messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."} + ] + }, + { + "role": "user", "content": [ + {"type": "image", "url": url}, + {"type": "text", "text": "What is shown in this image?"}, + ] + }, +] +inputs = processor.apply_chat_template( + messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + add_generation_prompt=True, + do_pan_and_scan=True, +).to(model.device) + +``` + + +## Usage Example + +### Single-image Inference + +```python +from transformers import AutoProcessor, Gemma3ForConditionalGeneration + +model_id = "google/gemma-3-4b-it" +model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") +processor = AutoProcessor.from_pretrained(model_id, padding_side="left") + +url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4=" +messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."} + ] + }, + { + "role": "user", "content": [ + {"type": "image", "url": url}, + {"type": "text", "text": "What is shown in this image?"}, + ] + }, +] +inputs = processor.apply_chat_template( + messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + add_generation_prompt=True, +).to(model.device) + +output = model.generate(**inputs, max_new_tokens=50) +print(processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)) +``` + +### Multi-image Inference + +```python +model_id = "google/gemma-3-4b-it" +model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") +processor = AutoProcessor.from_pretrained(model_id, padding_side="left") + +url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4=" +url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg" +messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."} + ] + }, + { + "role": "user", "content": [ + {"type": "image", "url": url_cow}, + {"type": "image", "url": url_stop}, + {"type": "text", "text": "Are these two images identical?"}, + ] + }, +] +inputs = processor.apply_chat_template( + messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + add_generation_prompt=True, +).to(model.device) + +output = model.generate(**inputs, max_new_tokens=50) +print(processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)) + +``` + +### Text-only inference + +You can use the VLM checkpoints for text-only generation by simply omitting images from your input. However, you can also load the model in text-only mode as shown below, which skips loading the vision tower and saves resources when you only need the LLM capabilities. 
+```python +from transformers import AutoTokenizer, Gemma3ForCausalLM + +model_id = "google/gemma-3-1b-it" + +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto") + +input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to(model.device) + +outputs = model.generate(**input_ids, max_new_tokens=100) +text = tokenizer.batch_decode(outputs, skip_special_tokens=True) + +print(text) + +``` + + +## Gemma3ImageProcessor + +[[autodoc]] Gemma3ImageProcessor + +## Gemma3ImageProcessorFast + +[[autodoc]] Gemma3ImageProcessorFast + +## Gemma3Processor + +[[autodoc]] Gemma3Processor + +## Gemma3TextConfig + +[[autodoc]] Gemma3TextConfig + +## Gemma3Config + +[[autodoc]] Gemma3Config + +## Gemma3TextModel + +[[autodoc]] Gemma3TextModel + - forward + +## Gemma3ForCausalLM + +[[autodoc]] Gemma3ForCausalLM + - forward + +## Gemma3ForConditionalGeneration + +[[autodoc]] Gemma3ForConditionalGeneration + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 12ac2c60d75f..9ce1fe1378bf 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -474,6 +474,7 @@ "models.fuyu": ["FuyuConfig"], "models.gemma": ["GemmaConfig"], "models.gemma2": ["Gemma2Config"], + "models.gemma3": ["Gemma3Config", "Gemma3Processor", "Gemma3TextConfig"], "models.git": [ "GitConfig", "GitProcessor", @@ -1259,6 +1260,7 @@ _import_structure["models.emu3"].append("Emu3ImageProcessor") _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"]) _import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"]) + _import_structure["models.gemma3"].append("Gemma3ImageProcessor") _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"]) _import_structure["models.got_ocr2"].extend(["GotOcr2ImageProcessor"]) _import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"]) @@ -1332,6 +1334,7 @@ _import_structure["models.deit"].append("DeiTImageProcessorFast") _import_structure["models.depth_pro"].append("DepthProImageProcessorFast") _import_structure["models.detr"].append("DetrImageProcessorFast") + _import_structure["models.gemma3"].append("Gemma3ImageProcessorFast") _import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast") _import_structure["models.llava"].append("LlavaImageProcessorFast") _import_structure["models.llava_next"].append("LlavaNextImageProcessorFast") @@ -2452,6 +2455,14 @@ "Gemma2PreTrainedModel", ] ) + _import_structure["models.gemma3"].extend( + [ + "Gemma3ForCausalLM", + "Gemma3ForConditionalGeneration", + "Gemma3PreTrainedModel", + "Gemma3TextModel", + ] + ) _import_structure["models.git"].extend( [ "GitForCausalLM", @@ -2554,6 +2565,7 @@ "GraniteMoePreTrainedModel", ] ) + _import_structure["models.granitemoeshared"].extend( [ "GraniteMoeSharedForCausalLM", @@ -2561,7 +2573,6 @@ "GraniteMoeSharedPreTrainedModel", ] ) - _import_structure["models.grounding_dino"].extend( [ "GroundingDinoForObjectDetection", @@ -5629,6 +5640,7 @@ from .models.fuyu import FuyuConfig from .models.gemma import GemmaConfig from .models.gemma2 import Gemma2Config + from .models.gemma3 import Gemma3Config, Gemma3Processor, Gemma3TextConfig from .models.git import ( GitConfig, GitProcessor, @@ -6450,6 +6462,7 @@ FlavaProcessor, ) from .models.fuyu import FuyuImageProcessor, FuyuProcessor + from .models.gemma3 import Gemma3ImageProcessor from .models.glpn import 
GLPNFeatureExtractor, GLPNImageProcessor from .models.got_ocr2 import GotOcr2ImageProcessor from .models.grounding_dino import GroundingDinoImageProcessor @@ -6535,6 +6548,7 @@ from .models.deit import DeiTImageProcessorFast from .models.depth_pro import DepthProImageProcessorFast from .models.detr import DetrImageProcessorFast + from .models.gemma3 import Gemma3ImageProcessorFast from .models.got_ocr2 import GotOcr2ImageProcessorFast from .models.llava import LlavaImageProcessorFast from .models.llava_next import LlavaNextImageProcessorFast @@ -7461,6 +7475,12 @@ Gemma2Model, Gemma2PreTrainedModel, ) + from .models.gemma3 import ( + Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, + Gemma3PreTrainedModel, + Gemma3TextModel, + ) from .models.git import ( GitForCausalLM, GitModel, diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 33c87bb35bea..93b083d7bed5 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -113,10 +113,10 @@ def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]: sp = self.sp vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())} - # there is a missing token in the vocab. We have to do this to support merges + # If "\t" is missing in the vocab, we have to do this to support merges # "<0x09>" is the bytefallback for `\t` - vocab["\t"] = vocab.get("<0x09>") - + if "\t" not in vocab: + vocab["\t"] = vocab.get("<0x09>") merges = generate_merges(vocab, vocab_scores) return vocab, merges @@ -1296,12 +1296,14 @@ def vocab(self, proto): (self.original_tokenizer.eos_token, 0.0), (self.original_tokenizer.bos_token, 0.0), ] - for piece in proto.pieces[3:]: - if piece.piece == "<0x09>": - vocab += [("\t", piece.score)] - else: - vocab += [(piece.piece, piece.score)] - # vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + + # Older gemma tokenizers had a missing tab token, so we fix that here + if not any(x[0] == "\t" for x in vocab): + override_index = next((i for i, x in enumerate(vocab) if x[0] == "<0x09>"), None) + if override_index is not None: + vocab[override_index] = ("\t", 0.0) + return vocab def pre_tokenizer(self, replacement, add_prefix_space): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 29e55e01be83..d172cdd22ea9 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -821,13 +821,13 @@ def _load_state_dict_into_meta_model( is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") for serialized_param_name, empty_param in state_dict.items(): + if serialized_param_name not in expected_keys: + continue + # serialized_param_name is the raw, serialized name # fixed_param_name is the model's equivalent fixed_param_name, _ = model.rename_key(serialized_param_name) - if fixed_param_name not in expected_keys: - continue - # we need to use serialized_param_name as file pointer is untouched param = ( file_pointer.get_slice(serialized_param_name) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 3884daabd973..c30b97ade727 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -106,6 +106,7 @@ fuyu, gemma, gemma2, + gemma3, git, glm, glpn, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 
fa4de1955430..3c6b849d8c40 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -124,6 +124,8 @@ ("fuyu", "FuyuConfig"), ("gemma", "GemmaConfig"), ("gemma2", "Gemma2Config"), + ("gemma3", "Gemma3Config"), + ("gemma3_text", "Gemma3TextConfig"), ("git", "GitConfig"), ("glm", "GlmConfig"), ("glpn", "GLPNConfig"), @@ -459,6 +461,8 @@ ("fuyu", "Fuyu"), ("gemma", "Gemma"), ("gemma2", "Gemma2"), + ("gemma3", "Gemma3ForConditionalGeneration"), + ("gemma3_text", "Gemma3ForCausalLM"), ("git", "GIT"), ("glm", "GLM"), ("glpn", "GLPN"), @@ -748,6 +752,7 @@ ("qwen2_audio_encoder", "qwen2_audio"), ("clip_text_model", "clip"), ("aria_text", "aria"), + ("gemma3_text", "gemma3"), ("idefics3_vision", "idefics3"), ("siglip_vision_model", "siglip"), ("smolvlm_vision", "smolvlm"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 180d156359b2..fedf1070e046 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -86,6 +86,7 @@ ("flava", ("FlavaImageProcessor",)), ("focalnet", ("BitImageProcessor",)), ("fuyu", ("FuyuImageProcessor",)), + ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("glpn", ("GLPNImageProcessor",)), ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d9fd502c1fae..aa0d120b7f57 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -118,6 +118,7 @@ ("funnel", ("FunnelModel", "FunnelBaseModel")), ("gemma", "GemmaModel"), ("gemma2", "Gemma2Model"), + ("gemma3_text", "Gemma3TextModel"), ("git", "GitModel"), ("glm", "GlmModel"), ("glpn", "GLPNModel"), @@ -338,6 +339,7 @@ ("fnet", "FNetForPreTraining"), ("fsmt", "FSMTForConditionalGeneration"), ("funnel", "FunnelForPreTraining"), + ("gemma3", "Gemma3ForConditionalGeneration"), ("gpt-sw3", "GPT2LMHeadModel"), ("gpt2", "GPT2LMHeadModel"), ("gpt_bigcode", "GPTBigCodeForCausalLM"), @@ -518,6 +520,8 @@ ("fuyu", "FuyuForCausalLM"), ("gemma", "GemmaForCausalLM"), ("gemma2", "Gemma2ForCausalLM"), + ("gemma3", "Gemma3ForCausalLM"), + ("gemma3_text", "Gemma3ForCausalLM"), ("git", "GitForCausalLM"), ("glm", "GlmForCausalLM"), ("got_ocr2", "GotOcr2ForConditionalGeneration"), @@ -824,6 +828,7 @@ ("chameleon", "ChameleonForConditionalGeneration"), ("emu3", "Emu3ForConditionalGeneration"), ("fuyu", "FuyuForCausalLM"), + ("gemma3", "Gemma3ForConditionalGeneration"), ("git", "GitForCausalLM"), ("got_ocr2", "GotOcr2ForConditionalGeneration"), ("idefics", "IdeficsForVisionText2Text"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 2d6da5ac13b4..d29d3f8d1a1b 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -63,6 +63,7 @@ ("emu3", "Emu3Processor"), ("flava", "FlavaProcessor"), ("fuyu", "FuyuProcessor"), + ("gemma3", "Gemma3Processor"), ("git", "GitProcessor"), ("got_ocr2", "GotOcr2Processor"), ("grounding-dino", "GroundingDinoProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 57bcd31296cc..cb3e921f8ea2 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ 
b/src/transformers/models/auto/tokenization_auto.py @@ -215,6 +215,13 @@ "GemmaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "gemma3", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)), diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 07cfc30f4ac8..34a0deb13bdf 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -41,7 +41,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -936,42 +935,23 @@ def prepare_inputs_for_generation( ): # Overwritten: has a special cache type, `HybridCache` - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - if past_key_values is not None: - if ( - inputs_embeds is not None # Exception 1 - or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s - # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride - # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the - # batch size = 1 case, `position_ids` is already contiguous but with varying stride - # which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. 
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing # (retrieving the same value from `cache_position` later on would crash dynamo) model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 + if logits_to_keep is None: + _ = model_inputs.pop("logits_to_keep", None) if ( isinstance(past_key_values, HybridCache) @@ -994,19 +974,8 @@ def prepare_inputs_for_generation( cache_position=cache_position, batch_size=batch_size, ) + model_inputs["attention_mask"] = attention_mask - if logits_to_keep is not None: - model_inputs["logits_to_keep"] = logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - } - ) return model_inputs diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index cd3ae3ed0efc..0f32c00287e7 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -29,7 +29,7 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import is_torchdynamo_compiling, logging +from ...utils import logging from ..gemma.modeling_gemma import ( GemmaAttention, GemmaForCausalLM, @@ -686,42 +686,23 @@ def prepare_inputs_for_generation( ): # Overwritten: has a special cache type, `HybridCache` - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. - # (we can't check exception 3 while compiling) - if past_key_values is not None: - if ( - inputs_embeds is not None # Exception 1 - or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 - ): - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s - # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride - # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the - # batch size = 1 case, `position_ids` is already contiguous but with varying stride - # which retriggers a capture. 
- position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. - model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing # (retrieving the same value from `cache_position` later on would crash dynamo) model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 + if logits_to_keep is None: + _ = model_inputs.pop("logits_to_keep", None) if ( isinstance(past_key_values, HybridCache) @@ -744,19 +725,8 @@ def prepare_inputs_for_generation( cache_position=cache_position, batch_size=batch_size, ) + model_inputs["attention_mask"] = attention_mask - if logits_to_keep is not None: - model_inputs["logits_to_keep"] = logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - } - ) return model_inputs diff --git a/src/transformers/models/gemma3/__init__.py b/src/transformers/models/gemma3/__init__.py new file mode 100644 index 000000000000..37ec82f91037 --- /dev/null +++ b/src/transformers/models/gemma3/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_gemma3 import * + from .image_processing_gemma3 import * + from .image_processing_gemma3_fast import * + from .modeling_gemma3 import * + from .processing_gemma3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/gemma3/configuration_gemma3.py b/src/transformers/models/gemma3/configuration_gemma3.py new file mode 100644 index 000000000000..c19a05ba60c4 --- /dev/null +++ b/src/transformers/models/gemma3/configuration_gemma3.py @@ -0,0 +1,330 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the +# modular_gemma3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging +from ..siglip import SiglipVisionConfig + + +logger = logging.get_logger(__name__) + + +class Gemma3TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma3Text-7B. + e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 262208): + Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Gemma3TextModel`] + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. 
+ max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + Scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the + size of the sliding window. + final_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the attention scores. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. 
+ `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + rope_local_base_freq (float, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings for local attention. + sliding_window_pattern (`int`, *optional*, defaults to 6): + Pattern for the sliding window attention. + + ```python + >>> from transformers import Gemma3TextModel, Gemma3TextConfig + >>> # Initializing a Gemma3Text gemma3_text-7b style configuration + >>> configuration = Gemma3TextConfig() + >>> # Initializing a model from the gemma3_text-7b style configuration + >>> model = Gemma3TextModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + rope_local_base_freq (float, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings for local attention. + sliding_window_pattern (`int`, *optional*, defaults to 6): + Pattern for the sliding window attention. + """ + + model_type = "gemma3_text" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=262_208, + hidden_size=2304, + intermediate_size=9216, + num_hidden_layers=26, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=131_072, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=1_000_000.0, + attention_bias=False, + attention_dropout=0.0, + query_pre_attn_scalar=256, + sliding_window=4096, + final_logit_softcapping=None, + attn_logit_softcapping=None, + cache_implementation="hybrid", + rope_scaling=None, + rope_local_base_freq=10_000.0, + sliding_window_pattern=6, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + 
self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping + self.cache_implementation = cache_implementation + + self.rope_local_base_freq = rope_local_base_freq + # For configuring HybridCache to work with 5:1 attention pattern + self.sliding_window_pattern = sliding_window_pattern + self.rope_scaling = rope_scaling + rope_config_validation(self) + + +class Gemma3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an + Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the PaliGemma-2B. + + e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`Union[Gemma3TextConfig, dict]`, *optional*): + The config object of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. + mm_tokens_per_image (`int`, *optional*, defaults to 256): + The number of tokens per image embedding. + boi_token_index (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 256000): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 262144): + The image token index to encode the image prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + + Example: + + ```python + >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig + + >>> # Initializing a Siglip-like vision config + >>> vision_config = SiglipVisionConfig() + + >>> # Initializing a Gemma3 Text config + >>> text_config = Gemma3TextConfig() + + >>> # Initializing a Gemma3 gemma-3-4b style configuration + >>> configuration = Gemma3Config(vision_config=vision_config, text_config=text_config) + + >>> # Initializing a model from the gemma-3-4b style configuration + >>> model = Gemma3ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma3" + sub_configs = { + "text_config": Gemma3TextConfig, + "vision_config": SiglipVisionConfig, + } + + def __init__( + self, + text_config: Optional[Gemma3TextConfig] = None, + vision_config: Optional[SiglipVisionConfig] = None, + mm_tokens_per_image: int = 256, + boi_token_index: int = 255_999, + eoi_token_index: int = 256_000, + image_token_index: int = 262_144, + initializer_range: float = 0.02, + **kwargs, + ): + if text_config is None: + text_config = Gemma3TextConfig() + logger.info("text_config is None, using the default Gemma3TextConfig.") + elif isinstance(text_config, dict): + text_config = Gemma3TextConfig(**text_config) + + if isinstance(vision_config, dict): + vision_config = SiglipVisionConfig(**vision_config) + else: + vision_config = SiglipVisionConfig() + logger.info( + "vision_config is None or incompatible with SiglipVisionConfig initialization. Gemma3 will be limited " + "to text tasks." + ) + + self.text_config = text_config + self.vision_config = vision_config + self.mm_tokens_per_image = mm_tokens_per_image + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + +__all__ = ["Gemma3Config", "Gemma3TextConfig"] diff --git a/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py b/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py new file mode 100644 index 000000000000..28c0192cc78c --- /dev/null +++ b/src/transformers/models/gemma3/convert_gemma3_weights_orbax_to_hf.py @@ -0,0 +1,592 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint. 
+ +python -m transformers.models.gemma3.convert_gemma3_weights_orbax_to_hf \ + --variant='gemma3_4b' \ + --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \ + --checkpoint_path="$HOME/gemma3/gemma3_4b_pt_orbax/" \ + --output_path="$HOME/gemma3/gemma3_4b_pt_safetensors/" \ + --precision='bfloat16' +""" + +import dataclasses +from collections.abc import Iterator, Sequence +from typing import Any + +import accelerate +import numpy as np +import torch +import tree +from absl import app, flags, logging +from orbax import checkpoint as obc + +from ...image_utils import PILImageResampling +from ..gemma import GemmaTokenizerFast +from . import ( + Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, + Gemma3ImageProcessor, + Gemma3Processor, +) +from .configuration_gemma3 import ( + Gemma3Config, + Gemma3TextConfig, + SiglipVisionConfig, +) + + +# ==== Internal Constants and Classes ==== + + +_CHAT_TEMPLATE = """{{ bos_token }} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{%- for message in loop_messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {{ '' + role + '\n' + (first_user_prefix if loop.first else "") }} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + {{ '\n' }} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{'model\n'}} +{%- endif -%} +""" + +_DTYPES = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +} + +_SIGLIP_BASE = "SigLiPFromPatches_0/siglip_encoder" +_SIGLIP_EMBEDDING = "SigLiPFromPatches_0/siglip_encoder/embedding" +_SIGLIP_TRANSFORMER_ENCODER_BLOCK = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_" +_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK) +_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm" + +_TRANSFORMER_DECODER_BLOCK = "transformer/layer_" +_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK) +_TRANSFORMER_EMBEDDER = "transformer/embedder" +_TRANSFORMER_FINAL_NORM = "transformer/final_norm" +_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/" +_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX) + +_VISION_CONFIG = { + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "num_channels": 3, + "image_size": 896, + "patch_size": 14, + "hidden_act": "gelu_pytorch_tanh", + "layer_norm_eps": 1e-6, + "attention_dropout": 0.0, + "vision_use_head": False, +} + +_VARIANT_GEMMA_3_1B = "gemma3_1b" +_VARIANT_GEMMA_3_4B = "gemma3_4b" 
+_VARIANT_GEMMA_3_12B = "gemma3_12b" +_VARIANT_GEMMA_3_27B = "gemma3_27b" +_VARIANTS = { + _VARIANT_GEMMA_3_1B: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_144, + hidden_size=1152, + intermediate_size=6 * 1152, + num_attention_heads=4, + num_hidden_layers=26, + num_key_value_heads=1, + head_dim=256, + sliding_window=512, + rope_theta=1_000_000, # used for global RoPE only + rope_local_base_freq=10_000, + attn_logit_softcapping=None, + query_pre_attn_scalar=256, + max_position_embeddings=32_768, + ), + vision_config=None, + ), + _VARIANT_GEMMA_3_4B: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_208, + hidden_size=2560, + intermediate_size=2560 * 8 // 2, + num_attention_heads=8, + head_dim=256, + num_hidden_layers=34, + num_key_value_heads=4, + sliding_window=1024, + rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only + rope_theta=1_000_000, + rope_local_base_freq=10_000, + attn_logit_softcapping=None, + query_pre_attn_scalar=256, + ), + vision_config=_VISION_CONFIG, + ), + _VARIANT_GEMMA_3_12B: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_208, + hidden_size=30 * 128, + intermediate_size=30 * 128 * 8 // 2, + num_attention_heads=16, + head_dim=256, + num_hidden_layers=48, + num_key_value_heads=8, + sliding_window=1024, + rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only + rope_theta=1_000_000, + rope_local_base_freq=10_000, + attn_logit_softcapping=None, + query_pre_attn_scalar=256, + ), + vision_config=_VISION_CONFIG, + ), + _VARIANT_GEMMA_3_27B: Gemma3Config( + text_config=Gemma3TextConfig( + vocab_size=262_208, + hidden_size=42 * 128, + intermediate_size=42 * 128 * 8 // 2, + num_attention_heads=32, + num_hidden_layers=62, + num_key_value_heads=16, + head_dim=128, + sliding_window=1024, + rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only + rope_theta=1_000_000, + rope_local_base_freq=10_000, + attn_logit_softcapping=None, + query_pre_attn_scalar=(42 * 128 // 32), # 1 / sqrt(hidden_size // num_attention_heads) + ), + vision_config=_VISION_CONFIG, + ), +} + +# ==== Flags ==== + +CHECKPOINT_PATH = flags.DEFINE_string( + name="checkpoint_path", + default=None, + help="Path to the Orbax checkpoint.", + required=True, +) + +INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool( + name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer" +) + +OUTPUT_PATH = flags.DEFINE_string( + name="output_path", + default=None, + help="Path to store the HF checkpoint.", + required=True, +) + +PRECISION = flags.DEFINE_enum( + name="precision", + default=None, + help="The floating point precision (aka dtype) of the model.", + enum_values=set(_DTYPES.keys()), + required=True, +) + +_TEXT_ONLY = flags.DEFINE_bool( + name="text_only", + default=False, + help=( + "If True, the model is loaded and saved as a Gemma3ForCausalLM, " + "otherwise model saed as Gemma3ForConditionalGeneration." 
+ ), +) + +TOKENIZER_PATH = flags.DEFINE_string( + name="tokenizer_path", + default=None, + help="Path to the SentencePiece model file.", + required=True, +) + +_VARIANT = flags.DEFINE_enum( + name="variant", + default=_VARIANT_GEMMA_3_4B, + help="The model variant to convert.", + enum_values=set(_VARIANTS.keys()), +) + + +def convert_siglip_weight( + config: SiglipVisionConfig, + paths: Sequence[str], + weights: np.ndarray, +) -> tuple[str, np.ndarray]: + path, prop = paths + normalized_path: str = "" + updated_weights: np.ndarray = None + + if path == _SIGLIP_BASE: + normalized_path = "vision_tower.vision_model.embeddings.position_embedding.weight" + updated_weights = weights.reshape(-1, config.hidden_size) + elif path == _SIGLIP_EMBEDDING: + if prop == "kernel": + normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.weight" + updated_weights = weights.transpose(3, 2, 0, 1) + elif prop == "bias": + normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") + elif path.startswith(_SIGLIP_TRANSFORMER_ENCODER_BLOCK): + encoder_block_path = path[_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN:] + next_path_seperator_idx = encoder_block_path.find("/") + layer_idx = encoder_block_path[:next_path_seperator_idx] + encoder_block_path = encoder_block_path[next_path_seperator_idx:] + normalized_path = f"vision_tower.vision_model.encoder.layers.{layer_idx}" + + if encoder_block_path.startswith("/LayerNorm"): + normalized_path += ".layer_norm1" if path.endswith("_0") else ".layer_norm2" + + if prop == "scale": + normalized_path += ".weight" + updated_weights = weights.transpose() + elif prop == "bias": + normalized_path += ".bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.") + elif encoder_block_path.startswith("/MlpBlock_0"): + normalized_path += ".mlp.fc1" if "/Dense_0" in encoder_block_path else ".mlp.fc2" + + if prop == "kernel": + normalized_path += ".weight" + updated_weights = weights.transpose() + elif prop == "bias": + normalized_path += ".bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") + elif encoder_block_path.startswith("/MultiHeadDotProductAttention_0"): + if encoder_block_path.endswith("/key"): + normalized_path += ".self_attn.k_proj" + elif encoder_block_path.endswith("/out"): + normalized_path += ".self_attn.out_proj" + elif encoder_block_path.endswith("/query"): + normalized_path += ".self_attn.q_proj" + elif encoder_block_path.endswith("/value"): + normalized_path += ".self_attn.v_proj" + else: + raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer MultiHeadDotProductAttention_0.") + + if prop == "bias": + normalized_path += ".bias" + updated_weights = weights.reshape(-1, config.hidden_size).reshape(-1) + elif prop == "kernel": + normalized_path += ".weight" + updated_weights = weights.reshape(-1, config.hidden_size).transpose() + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. 
Should be `bias` or `kernel`.") + else: + raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer Encoder Block.") + elif path == _SIGLIP_TRANSFORMER_ENCODER_NORM: + if prop == "scale": + normalized_path = "vision_tower.vision_model.post_layernorm.weight" + updated_weights = weights.transpose() + elif prop == "bias": + normalized_path = "vision_tower.vision_model.post_layernorm.bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.") + else: + raise ValueError(f"Unexpected path `{path}`.") + + if "vision" in normalized_path: + print(normalized_path) + return normalized_path, updated_weights + + +def convert_transformer_weights( + config: Gemma3TextConfig, + paths: Sequence[str], + weights: np.ndarray, +) -> Iterator[tuple[str, np.ndarray]]: + path, prop = paths + + if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX): + path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:] + + converted_paths: list[str] = [] + converted_weights: list[Any] = [] + + attn_head_dim = config.num_attention_heads * config.head_dim + kv_head_dim = config.num_key_value_heads * config.head_dim + + if path == _TRANSFORMER_EMBEDDER: + if prop == "input_embedding": + # Tied to language_model.lm_head.weight, assigned at the end. + converted_paths = ["language_model.model.embed_tokens.weight"] + + if not _TEXT_ONLY.value: + # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama + pre_expansion_embeddings = weights + mu = np.mean(pre_expansion_embeddings, axis=0) + sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True) + new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64) + weights = np.vstack([pre_expansion_embeddings, new_embeddings]) + + converted_weights = [weights] + elif _TEXT_ONLY.value or prop in ("mm_output_embedding", "mm_input_embedding_extra"): + return zip([], []) + else: + raise ValueError(f"Unexpected member, {prop}, in Embedder.") + elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"): + if _TEXT_ONLY.value: + return zip([], []) + + if path.endswith("/mm_input_projection"): + converted_paths = ["multi_modal_projector.mm_input_projection_weight"] + converted_weights = [weights] + elif path.endswith("/mm_soft_embedding_norm"): + converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"] + converted_weights = [weights] + else: + raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.") + elif path == _TRANSFORMER_FINAL_NORM: + converted_paths = ["language_model.model.norm.weight"] + converted_weights = [weights] + elif path.startswith(_TRANSFORMER_DECODER_BLOCK): + decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:] + next_path_seperator_idx = decoder_block_path.find("/") + layer_idx = decoder_block_path[:next_path_seperator_idx] + decoder_block_path = decoder_block_path[next_path_seperator_idx:] + + base_path = f"language_model.model.layers.{layer_idx}" + + if path.endswith("attn/attn_vec_einsum"): + converted_paths = [f"{base_path}.self_attn.o_proj.weight"] + converted_weights = [weights.transpose(2, 0, 1).reshape(config.hidden_size, attn_head_dim)] + elif path.endswith("attn/_key_norm"): + converted_paths = [f"{base_path}.self_attn.k_norm.weight"] + converted_weights = [weights] + elif path.endswith("attn/kv_einsum"): + converted_paths = [ + f"{base_path}.self_attn.k_proj.weight", + f"{base_path}.self_attn.v_proj.weight", + ] + k_proj_weights, v_proj_weights = weights + converted_weights = [ + 
k_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size), + v_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size), + ] + elif path.endswith("attn/q_einsum"): + converted_paths = [f"{base_path}.self_attn.q_proj.weight"] + converted_weights = [weights.transpose(0, 2, 1).reshape(attn_head_dim, config.hidden_size)] + elif path.endswith("attn/_query_norm"): + converted_paths = [f"{base_path}.self_attn.q_norm.weight"] + converted_weights = [weights] + elif path.endswith("mlp/gating_einsum"): + converted_paths = [ + f"{base_path}.mlp.gate_proj.weight", + f"{base_path}.mlp.up_proj.weight", + ] + gate_proj_weight, up_proj_weight = weights + converted_weights = [gate_proj_weight, up_proj_weight] + elif path.endswith("mlp/linear"): + converted_paths = [f"{base_path}.mlp.down_proj.weight"] + converted_weights = [weights.transpose()] + elif path.endswith("post_attention_norm"): + converted_paths = [f"{base_path}.post_attention_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("post_ffw_norm"): + converted_paths = [f"{base_path}.post_feedforward_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("pre_attention_norm"): + converted_paths = [f"{base_path}.input_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("pre_ffw_norm"): + converted_paths = [f"{base_path}.pre_feedforward_layernorm.weight"] + converted_weights = [weights] + else: + raise ValueError(f"Unexpected path `{path}` in Decoder Block.") + else: + raise ValueError(f"Unexpected path `{path}`.") + + if (cpl := len(converted_paths)) != (cwl := len(converted_weights)): + raise ValueError( + "The `converted_paths` and `converted_weights` should be the same " + f"length. Got {cpl} and {cwl}, respectively, for {path}." 
+ ) + + return zip(converted_paths, converted_weights) + + +@dataclasses.dataclass(frozen=True) +class ConversionResult: + state_tree: dict[str, torch.Tensor] + config: Gemma3Config + + +def convert( + checkpoint_path: str, + config: Gemma3Config, + target_dtype: torch.dtype, +) -> ConversionResult: + """Loads Orbax checkpoint from `input_path` and converts it to HF tree.""" + checkpointer = obc.PyTreeCheckpointer() + ckpt = checkpointer.restore(checkpoint_path) + hf_tree: dict[str, torch.Tensor] = {} + + def update_tree(path: str, weights: np.ndarray) -> None: + torch_tensor = torch.from_numpy(weights.astype("float32")).type(target_dtype) + logging.info( + "%s converted shape=%s with dtype=%s", + path, + weights.shape, + torch_tensor.dtype, + ) + hf_tree[path] = torch_tensor + + for paths, value in tree.flatten_with_path(ckpt): + if paths[0].startswith("SigLiPFromPatches_"): + if config.vision_config is None: + continue + + path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value) + update_tree(path, weights) + else: + for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value): + if config.vision_config is None: + path = path[len("language_model.") :] + + update_tree(path, weights) + + if config.vision_config is None: + hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"] + else: + hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"] + + return ConversionResult(state_tree=hf_tree, config=config) + + +def main(*args): + del args + + variant = _VARIANT.value + dtype = getattr(torch, PRECISION.value) + config = _VARIANTS[variant] + output_path = OUTPUT_PATH.value + + if variant == _VARIANT_GEMMA_3_1B: + flags.FLAGS.set_default(_TEXT_ONLY.name, True) + + tokenizer = GemmaTokenizerFast( + TOKENIZER_PATH.value, + add_bos_token=True, + extra_special_tokens={ + "image_token": "", # Should be ID=262_144 + "boi_token": "", # Should be ID=255_999 + "eoi_token": "", # Should be ID=256_000 + }, + ) + + if INCLUDE_CHAT_TEMPLATE.value: + # Include chat template for CausalLM models + tokenizer.chat_template = _CHAT_TEMPLATE + config.eos_token_id = [1, 106] + + if _TEXT_ONLY.value: + config.vision_config = None + tokenizer.save_pretrained(output_path) + logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path) + del tokenizer + else: + image_processor = Gemma3ImageProcessor( + image_seq_length=256, + image_mean=(0.5,) * 3, + image_std=(0.5,) * 3, + size={"height": 896, "width": 896}, + resample=PILImageResampling.BILINEAR, + ) + processor = Gemma3Processor( + image_processor=image_processor, + tokenizer=tokenizer, + ) + if INCLUDE_CHAT_TEMPLATE.value: + # Duplicate so multimodal instruct models can also be used for CausalLM + processor.chat_template = tokenizer.chat_template + + processor.save_pretrained(output_path) + logging.info("Saved Gemma3Processor for %s to %s", variant, output_path) + del processor + del tokenizer + + logging.info("Gemma 3 (%s) configured as: %s", variant, config) + logging.info("Converting Gemma 3 (%s) @ %s", variant, dtype) + result = convert(CHECKPOINT_PATH.value, config, dtype) + logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant) + + with accelerate.init_empty_weights(): + if config.vision_config is None: + model = Gemma3ForCausalLM(config=config.text_config) + else: + model = Gemma3ForConditionalGeneration(config) + + model.load_state_dict(result.state_tree, assign=True, strict=True) + 
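+    # Note: `assign=True` matters here because the model was created under
+    # `accelerate.init_empty_weights()`, so its parameters are meta tensors without storage;
+    # assigning the converted tensors from `result.state_tree` swaps them in directly instead
+    # of trying to copy into unallocated weights.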
model.config.torch_dtype = dtype + logging.info("Loaded Gemma 3 (%s) in Hugging Face Transformers as a %s instance.", variant, type(model).__name__) + model.save_pretrained(output_path, safe_serialization=True) + logging.info( + "Saved Gemma 3 (%s) to SafeTensors in %s using %s", + variant, + output_path, + type(model).__name__, + ) + del model + del result + + +if __name__ == "__main__": + app.run(main) diff --git a/src/transformers/models/gemma3/image_processing_gemma3.py b/src/transformers/models/gemma3/image_processing_gemma3.py new file mode 100644 index 000000000000..f985a9a9dd80 --- /dev/null +++ b/src/transformers/models/gemma3/image_processing_gemma3.py @@ -0,0 +1,413 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Gemma3.""" + +import itertools +import math +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_nested_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class Gemma3ImageProcessor(BaseImageProcessor): + r""" + Constructs a SigLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image by the specified mean and standard deviation. Can be overridden by + `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Mean to use if normalizing the image. 
This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_pan_and_scan (`bool`, *optional*): + Whether to apply `pan_and_scan` to images. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. + """ + + model_input_names = ["pixel_values", "num_crops"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + do_pan_and_scan: bool = None, + pan_and_scan_min_crop_size: int = None, + pan_and_scan_max_num_crops: int = None, + pan_and_scan_min_ratio_to_activate: float = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size, default_to_square=True) + image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.do_pan_and_scan = do_pan_and_scan + self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size + self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops + self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate + + def pan_and_scan( + self, + image: np.ndarray, + pan_and_scan_min_crop_size: int, + pan_and_scan_max_num_crops: int, + pan_and_scan_min_ratio_to_activate: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pan and Scan and image, by cropping into smaller images when the aspect ratio exceeds + minumum allowed ratio. + + Args: + image (`np.ndarray`): + Image to resize. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. 
+ input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + height, width = get_image_size(image) + + # Square or landscape image. + if width >= height: + # Only apply PaS if the image is sufficiently exaggerated + if width / height < pan_and_scan_min_ratio_to_activate: + return [] + + # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding. + num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w) + + # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_w = max(2, num_crops_w) + num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) + num_crops_h = 1 + + # Portrait image. + else: + # Only apply PaS if the image is sufficiently exaggerated + if height / width < pan_and_scan_min_ratio_to_activate: + return [] + + # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + num_crops_h = int(math.floor(height / width + 0.5)) + num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h) + + # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_h = max(2, num_crops_h) + num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) + num_crops_w = 1 + + crop_size_w = int(math.ceil(width / num_crops_w)) + crop_size_h = int(math.ceil(height / num_crops_h)) + + # Don't apply PaS if crop size is too small. + if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: + return [] + + crop_positions_w = [crop_size_w * i for i in range(num_crops_w)] + crop_positions_h = [crop_size_h * i for i in range(num_crops_h)] + + if input_data_format == ChannelDimension.LAST: + image_crops = [ + image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] + for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) + ] + else: + image_crops = [ + image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] + for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) + ] + + return image_crops + + def _process_images_for_pan_and_scan( + self, + images: List[np.ndarray], + do_pan_and_scan: bool, + pan_and_scan_min_crop_size: int, + pan_and_scan_max_num_crops: int, + pan_and_scan_min_ratio_to_activate: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + pas_images_list = [] + num_crops = [] + for image in images: + pas_images = self.pan_and_scan( + image=image, + pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, + data_format=data_format, + input_data_format=input_data_format, + ) + pas_images_list.extend([image] + pas_images) + num_crops.append(len(pas_images)) + return pas_images_list, num_crops + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: 
Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: bool = None, + do_pan_and_scan: bool = None, + pan_and_scan_min_crop_size: int = None, + pan_and_scan_max_num_crops: int = None, + pan_and_scan_min_ratio_to_activate: float = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_pan_and_scan (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to apply `pan_and_scan` to images. + pan_and_scan_min_crop_size (`int`, *optional*, defaults to `self.pan_and_scan_min_crop_size`): + Minimum size of each crop in pan and scan. 
+ pan_and_scan_max_num_crops (`int`, *optional*, defaults to `self.pan_and_scan_max_num_crops`): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*, defaults to `self.pan_and_scan_min_ratio_to_activate`): + Minimum aspect ratio to activate pan and scan. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pan_and_scan = do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan + pan_and_scan_min_crop_size = ( + pan_and_scan_min_crop_size if pan_and_scan_min_crop_size is not None else self.pan_and_scan_min_crop_size + ) + pan_and_scan_max_num_crops = ( + pan_and_scan_max_num_crops if pan_and_scan_max_num_crops is not None else self.pan_and_scan_max_num_crops + ) + pan_and_scan_min_ratio_to_activate = ( + pan_and_scan_min_ratio_to_activate + if pan_and_scan_min_ratio_to_activate is not None + else self.pan_and_scan_min_ratio_to_activate + ) + + images_list = make_nested_list_of_images(images) + + if not valid_images(images_list[0]): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + if do_convert_rgb: + images_list = [[convert_to_rgb(image) for image in images] for images in images_list] + + # All transformations expect numpy arrays. + images_list = [[to_numpy_array(image) for image in images] for images in images_list] + + if do_rescale and is_scaled_image(images_list[0][0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
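+            # For typical PIL inputs converted via `to_numpy_array` above, this resolves to
+            # channels-last, i.e. (height, width, channels).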
+ input_data_format = infer_channel_dimension_format(images_list[0][0]) + + if do_pan_and_scan: + images_list_and_num_crops = [ + self._process_images_for_pan_and_scan( + images=images, + do_pan_and_scan=do_pan_and_scan, + pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, + data_format=data_format, + input_data_format=input_data_format, + ) + for images in images_list + ] + images_list = [images for images, _ in images_list_and_num_crops] + num_crops = [num_crops for _, num_crops in images_list_and_num_crops] + else: + num_crops = [[0] for images in images_list] + + processed_images = [] + for images in images_list: + for image in images: + if do_resize: + height, width = size["height"], size["width"] + image = resize( + image=image, size=(height, width), resample=resample, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + data = {"pixel_values": processed_images, "num_crops": num_crops} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Gemma3ImageProcessor"] diff --git a/src/transformers/models/gemma3/image_processing_gemma3_fast.py b/src/transformers/models/gemma3/image_processing_gemma3_fast.py new file mode 100644 index 000000000000..0a26f25231c2 --- /dev/null +++ b/src/transformers/models/gemma3/image_processing_gemma3_fast.py @@ -0,0 +1,387 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast Image processor class for SigLIP.""" + +import itertools +import math +from functools import partial +from typing import List, Optional, Union + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorInitKwargs, + DefaultFastImageProcessorPreprocessKwargs, + get_size_dict, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + SizeDict, + get_image_size, + make_nested_list_of_images, + validate_kwargs, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, + logging, +) + + +if is_vision_available(): + from ...image_utils import PILImageResampling + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + +logger = logging.get_logger(__name__) + + +class Gemma3FastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs): + do_pan_and_scan: Optional[bool] + pan_and_scan_min_crop_size: Optional[int] + pan_and_scan_max_num_crops: Optional[int] + pan_and_scan_min_ratio_to_activate: Optional[float] + + +class Gemma3FastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs): + do_pan_and_scan: Optional[bool] + pan_and_scan_min_crop_size: Optional[int] + pan_and_scan_max_num_crops: Optional[int] + pan_and_scan_min_ratio_to_activate: Optional[float] + + +@add_start_docstrings( + "Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of Pan adn Scan cropping method.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ + do_pan_and_scan (`bool`, *optional*): + Whether to apply `pan_and_scan` to images. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. + """, +) +class Gemma3ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 224, "width": 224} + default_to_square = True + do_resize = True + do_rescale = True + do_normalize = True + do_pan_and_scan = None + pan_and_scan_min_crop_size = None + pan_and_scan_max_num_crops = None + pan_and_scan_min_ratio_to_activate = None + valid_init_kwargs = Gemma3FastImageProcessorInitKwargs + valid_preprocess_kwargs = Gemma3FastImageProcessorPreprocessKwargs + + def __init__(self, **kwargs: Unpack[Gemma3FastImageProcessorInitKwargs]): + super().__init__(**kwargs) + + def _prepare_images_structure( + self, + images: ImageInput, + ) -> ImageInput: + """ + Prepare the images structure for processing. + + Args: + images (`ImageInput`): + The input images to process. + + Returns: + `ImageInput`: The images with a valid nesting. 
+ """ + return make_nested_list_of_images(images) + + def _prepare_input_images( + self, + images: ImageInput, + do_convert_rgb: bool = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + device: Optional["torch.device"] = None, + ) -> List["torch.Tensor"]: + """ + Prepare the input images for processing. + """ + batch_images = self._prepare_images_structure(images) + process_image_fn = partial( + self._process_image, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + # todo: yoni - check if we can parallelize this efficiently + batch_processed_images = [] + for image_list in batch_images: + processed_images = [] + for image in image_list: + processed_images.append(process_image_fn(image)) + batch_processed_images.append(processed_images) + + return batch_processed_images + + def pan_and_scan( + self, + image: "torch.Tensor", + pan_and_scan_min_crop_size: int, + pan_and_scan_max_num_crops: int, + pan_and_scan_min_ratio_to_activate: float, + ): + """ + Pan and Scan an image, by cropping into smaller images when the aspect ratio exceeds + minumum allowed ratio. + + Args: + image (`torch.Tensor`): + Image to resize. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. + """ + height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST) + + # Square or landscape image. + if width >= height: + # Only apply PaS if the image is sufficiently exaggerated + if width / height < pan_and_scan_min_ratio_to_activate: + return [] + + # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + num_crops_w = int(math.floor(width / height + 0.5)) # Half round up rounding. + num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w) + + # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_w = max(2, num_crops_w) + num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) + num_crops_h = 1 + + # Portrait image. + else: + # Only apply PaS if the image is sufficiently exaggerated + if height / width < pan_and_scan_min_ratio_to_activate: + return [] + + # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + num_crops_h = int(math.floor(height / width + 0.5)) + num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h) + + # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_h = max(2, num_crops_h) + num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) + num_crops_w = 1 + + crop_size_w = int(math.ceil(width / num_crops_w)) + crop_size_h = int(math.ceil(height / num_crops_h)) + + # Don't apply PaS if crop size is too small. 
+ if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: + return [] + + crop_positions_w = [crop_size_w * i for i in range(num_crops_w)] + crop_positions_h = [crop_size_h * i for i in range(num_crops_h)] + + return [ + image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w] + for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w) + ] + + def _process_images_for_pan_and_scan( + self, + images: List["torch.Tensor"], + do_pan_and_scan: bool, + pan_and_scan_min_crop_size: int, + pan_and_scan_max_num_crops: int, + pan_and_scan_min_ratio_to_activate: float, + ): + pas_images_list = [] + num_crops = [] + for image in images: + pas_images = self.pan_and_scan( + image=image, + pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, + ) + pas_images_list.extend([image] + pas_images) + num_crops.append(len(pas_images)) + return pas_images_list, num_crops + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + do_pan_and_scan (`bool`, *optional*): + Whether to apply `pan_and_scan` to images. + pan_and_scan_min_crop_size (`int`, *optional*): + Minimum size of each crop in pan and scan. + pan_and_scan_max_num_crops (`int`, *optional*): + Maximum number of crops per image in pan and scan. + pan_and_scan_min_ratio_to_activate (`float`, *optional*): + Minimum aspect ratio to activate pan and scan. + """, + ) + def preprocess( + self, + images: ImageInput, + **kwargs: Unpack[Gemma3FastImageProcessorPreprocessKwargs], + ) -> BatchFeature: + validate_kwargs( + captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_preprocess_kwargs.__annotations__.keys() + ) + # Set default kwargs from self. This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. 
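+        # Concretely, any kwarg not passed to `preprocess` falls back to the matching attribute
+        # on `self` (e.g. the class-level `do_pan_and_scan = None` default above), so pan and scan
+        # only runs when it is enabled at init time or per call.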
+ for kwarg_name in self.valid_preprocess_kwargs.__annotations__: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + # Extract parameters that are only used for preparing the input images + do_convert_rgb = kwargs.pop("do_convert_rgb") + input_data_format = kwargs.pop("input_data_format") + device = kwargs.pop("device") + + images = self._prepare_input_images( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + + # Pop kwargs that need further processing or won't be used in _preprocess + default_to_square = kwargs.pop("default_to_square") + size = kwargs.pop("size") + crop_size = kwargs.pop("crop_size") + image_mean = kwargs.pop("image_mean") + image_std = kwargs.pop("image_std") + data_format = kwargs.pop("data_format") + resample = kwargs.pop("resample") + + # Make hashable for cache + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None + crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None + image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean + image_std = tuple(image_std) if isinstance(image_std, list) else image_std + + image_mean, image_std, interpolation = self._prepare_process_arguments( + size=size, + crop_size=crop_size, + resample=resample, + image_mean=image_mean, + image_std=image_std, + data_format=data_format if data_format is not None else ChannelDimension.FIRST, + device=images[0][0].device, + do_resize=kwargs.get("do_resize"), + do_center_crop=kwargs.get("do_center_crop"), + do_rescale=kwargs.get("do_rescale"), + rescale_factor=kwargs.get("rescale_factor"), + do_normalize=kwargs.get("do_normalize"), + return_tensors=kwargs.get("return_tensors"), + ) + + return self._preprocess( + images=images, + size=size, + crop_size=crop_size, + interpolation=interpolation, + image_mean=image_mean, + image_std=image_std, + **kwargs, + ) + + def _preprocess( + self, + images: List[List["torch.Tensor"]], + do_resize: bool, + size: SizeDict, + do_pan_and_scan: Optional[bool], + pan_and_scan_min_crop_size: Optional[int], + pan_and_scan_max_num_crops: Optional[int], + pan_and_scan_min_ratio_to_activate: Optional[float], + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + return_tensors: Optional[Union[str, TensorType]], + ) -> BatchFeature: + processed_images = [] + batch_num_crops = [] + + for image_list in images: + if do_pan_and_scan: + images_list, num_crops = self._process_images_for_pan_and_scan( + images=image_list, + do_pan_and_scan=do_pan_and_scan, + pan_and_scan_min_crop_size=pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops=pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate, + ) + else: + num_crops = [[0] for images in images_list] + + # Group images by size for batched processing + processed_image_patches_grouped = {} + grouped_image_patches, grouped_image_patches_index = group_images_by_shape(images_list) + for shape, stacked_image_patches in grouped_image_patches.items(): + if do_resize: + stacked_image_patches = self.resize( + image=stacked_image_patches, + size=size, + interpolation=interpolation, + ) + # Fused rescale and normalize + stacked_image_patches = self.rescale_and_normalize( + 
stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_image_patches_grouped[shape] = stacked_image_patches + processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index) + processed_image_patches = ( + torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches + ) + processed_images.extend(processed_image_patches) + batch_num_crops.extend(num_crops) + + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return BatchFeature( + data={"pixel_values": processed_images, "num_crops": batch_num_crops}, tensor_type=return_tensors + ) + + +__all__ = ["Gemma3ImageProcessorFast"] diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py new file mode 100644 index 000000000000..d5498c8615db --- /dev/null +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -0,0 +1,1451 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from collections.abc import Callable +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, HybridCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "Gemma3Config" + + +@dataclass +class Gemma3CausalLMOutputWithPast(ModelOutput): + """ + Base class for Gemma3 causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class Gemma3TextScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
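+    In `Gemma3TextModel` the scale is set to `hidden_size**0.5` (e.g. sqrt(3072) ≈ 55.43).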
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +class Gemma3MLP(nn.Module): + def __init__(self, config: Gemma3TextConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Gemma3RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class Gemma3RotaryEmbedding(nn.Module): + def __init__(self, config: Gemma3TextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + 
self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class Gemma3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Gemma3TextConfig, layer_idx: int): + super().__init__() + self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if self.is_sliding else None + + self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = 
(*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Here we need to slice as we use a static cache by default, but FA2 does not support it + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + seq_len = attention_mask.shape[-1] + key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask.to(query_states), + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Gemma3DecoderLayer(nn.Module): + def __init__(self, config: Gemma3TextConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) + self.mlp = Gemma3MLP(config) + self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.is_sliding = self.self_attn.is_sliding + self.sliding_window = config.sliding_window + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding + # In 
prefill, we may be larger than sliding window + effective_seq_len = max(cache_position.shape[0], self.sliding_window) + # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), + # thus we must slice from the right (at most `effective_seq_len` elements) + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask[:, -effective_seq_len:] + # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice + # from the left, with an offset if we are beyond the sliding window + else: + min_dtype = torch.finfo(attention_mask.dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + # In case we are beyond the sliding window, we need to correctly offset the mask slicing + # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo + offset = last_cache_position - effective_seq_len + # Should only be used when beyond the sliding window (i.e. offset > 0) + offset = max(0, offset) + attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # apply global RoPE to non-sliding layer only + if self.self_attn.is_sliding: + position_embeddings = position_embeddings_local + else: + position_embeddings = position_embeddings_global + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +GEMMA3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Gemma3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare Gemma3 Model outputting raw hidden-states without any specific head on top.", + GEMMA3_START_DOCSTRING, +) +class Gemma3PreTrainedModel(PreTrainedModel): + config_class = Gemma3Config + base_model_prefix = "language_model" + supports_gradient_checkpointing = True + _no_split_modules = [ + "Gemma3DecoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + # important: this ported version of Gemma2 isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +GEMMA3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 
+ + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Gemma3Text Model outputting raw hidden-states without any specific head on top.", + GEMMA3_START_DOCSTRING, +) +class Gemma3TextModel(Gemma3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma3TextDecoderLayer`] + + Args: + config: Gemma3TextConfig + """ + + config_class = Gemma3TextConfig + + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 + self.embed_tokens = Gemma3TextScaledWordEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 + ) + self.layers = nn.ModuleList( + [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Gemma3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas + # when we want to create a local RoPE layer. 
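The `Gemma3TextScaledWordEmbedding` above multiplies looked-up embeddings by `embed_scale = hidden_size**0.5`. A minimal sketch of that scaling, with 3072 standing in for `hidden_size` purely as an example:

```python
import torch

hidden_size = 3072              # illustrative value only
embed_scale = hidden_size**0.5  # ~55.4256, the factor applied to every embedding vector

embedding = torch.nn.Embedding(8, hidden_size)
token_ids = torch.tensor([[1, 2, 3]])
scaled = embedding(token_ids) * embed_scale
print(scaled.shape)  # torch.Size([1, 3, 3072])
```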
Config defaults should hold values for global RoPE + config = copy.deepcopy(config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default"} + self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: Optional[int] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None and not self.training: + batch_size, seq_len, _ = inputs_embeds.shape + past_key_values = HybridCache( + self.config, + max_batch_size=batch_size, + max_cache_len=seq_len, + dtype=inputs_embeds.dtype, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + if last_cache_position is None: + last_cache_position = 0 + if attention_mask is not None: + # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position + # It will break dynamo tracing but there are no way around it (and it should never happen in practice) + last_cache_position = ( + attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings_global = self.rotary_emb(hidden_states, position_ids) + position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + position_embeddings_global, + position_embeddings_local, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + last_cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + last_cache_position=last_cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + @torch.no_grad() + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: HybridCache, + output_attentions: bool, + ): + # Flash Attention currently doesn't support static cache but Gemma3Text work only with static cache. 
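As a side note on the cache bookkeeping above, `cache_position` is simply a running index into the cache, and `position_ids` defaults to the same values with a batch dimension. A toy trace of a prefill followed by one decode step (sizes made up for illustration):

```python
import torch

# Prefill: nothing cached yet, a 5-token prompt.
past_seen_tokens, new_tokens = 0, 5
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
print(cache_position)               # tensor([0, 1, 2, 3, 4])
print(cache_position.unsqueeze(0))  # default position_ids for this step

# One decode step later: 5 tokens already in the cache, 1 new token.
past_seen_tokens, new_tokens = 5, 1
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
print(cache_position)  # tensor([5])
```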
+ # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape + # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible + # as it doesn't cause dynamic control issues. + if self.config._attn_implementation == "flash_attention_2": + return attention_mask + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if isinstance(past_key_values, (HybridCache, StaticCache)): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
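To see what the 4D mask described in this docstring looks like, here is a simplified, standalone re-implementation of the same steps on toy inputs (one batch, a 4-token prefill, an 8-slot static cache, one left-padding token; all of these numbers are invented for the example and this is not the library function itself):

```python
import torch

sequence_length, target_length = 4, 8
cache_position = torch.arange(sequence_length)
padding_mask_2d = torch.tensor([[0, 1, 1, 1]])  # 0 = padding, 1 = real token
min_dtype = torch.finfo(torch.float32).min

# Fully masked grid, then open positions at or before each query's cache position.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].clone()  # -> [bs=1, 1, query_len, target_length]

# Fold the 2D padding mask into the first `mask_length` key columns.
mask_length = padding_mask_2d.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + padding_mask_2d[:, None, None, :]
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask == 0, min_dtype
)
print(causal_mask[0, 0])  # 0.0 where attention is allowed, a large negative value elsewhere
```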
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config_class = Gemma3TextConfig + base_model_prefix = "language_model" + + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + self.model = Gemma3TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
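The `logits_to_keep` behaviour described above boils down to one slicing expression over the hidden states. A small sketch with toy shapes:

```python
import torch

hidden_states = torch.randn(2, 10, 16)  # [batch, seq_len, hidden], toy sizes

# As an int: keep only the last N positions (0 is the special case that keeps everything).
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 1, 16])

# As a 1D tensor: keep explicit positions, e.g. for packed sequences.
logits_to_keep = torch.tensor([3, 9])
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 2, 16])
```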
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Gemma3ForCausalLM + + >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma3 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **loss_kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten: has a special cache type, `HybridCache` + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from 
`cache_position` later on would crash dynamo) + model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 + if logits_to_keep is None: + _ = model_inputs.pop("logits_to_keep", None) + + if ( + isinstance(past_key_values, HybridCache) + and attention_mask.ndim == 2 + and not self.config._attn_implementation == "flash_attention_2" + ): + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + ) + model_inputs["attention_mask"] = attention_mask + + return model_inputs + + +class Gemma3MultiModalProjector(nn.Module): + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) + ) + + self.mm_soft_emb_norm = Gemma3RMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +@add_start_docstrings( + """The GEMMA3 model which consists of a vision backbone and a language model.""", + GEMMA3_START_DOCSTRING, +) +class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): + def __init__(self, config: Gemma3Config): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = Gemma3MultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + + language_model = AutoModelForCausalLM.from_config(config=config.text_config) + + if language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + self.language_model = language_model + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + 
self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def _update_causal_mask( + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_tensor, + is_training: bool = False, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + return attention_mask + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted + # form and requires no inversion or slicing. + return attention_mask + + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = torch.finfo(self.dtype).min + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device + ) + + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + + # Apply bidirectional mask on images if token type ids are provided + if token_type_ids is not None and sequence_length != 1: + token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) + token_type_mask[token_type_ids == 0] = False # if text token do not change anything + token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) + causal_mask = causal_mask.clone() + causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( + token_type_mask, 0.0 + ) + + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def get_image_features(self, pixel_values: torch.Tensor): + """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. 
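A toy illustration of the bidirectional image-token masking performed in `_update_causal_mask` above. This is a simplified sketch: cache offsets and padding are omitted, and the token layout (one text token, four image tokens, two text tokens) is invented for the example:

```python
import torch

token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 0, 0]])  # 0 = text, 1 = image
sequence_length = token_type_ids.shape[1]
min_dtype = torch.finfo(torch.float32).min

# Plain causal mask: 0.0 where attention is allowed, min_dtype above the diagonal.
causal_mask = torch.triu(
    torch.full((sequence_length, sequence_length), min_dtype), diagonal=1
)[None, None, :, :].clone()

# Image tokens may attend to each other in both directions; text rows stay causal.
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
    token_type_mask.unsqueeze(1), 0.0
)
print(causal_mask[0, 0] == 0)  # True where attention is allowed
```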
+ Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.multi_modal_projector(vision_outputs) + return image_features + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration + + >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") + >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") + + >>> prompt = "answer en Where is the cow standing?" 
+ >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "answer en Where is the cow standing?\nbeach" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + # Replace image id woth PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_index >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_index + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # mask out pad-token-ids in labels for BC + if labels is not None and self.pad_token_id in labels: + logger.warning_once( + "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. 
" + "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", + ) + labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + logits = outputs.logits + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Gemma3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + attention_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # position_ids in Gemma3 are 1-indexed + if model_inputs.get("position_ids") is not None: + model_inputs["position_ids"] += 1 + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model. 
NOTE: use_cache=False needs pixel_values always + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + is_training = token_type_ids is not None and labels is not None + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training + ) + model_inputs["attention_mask"] = causal_mask + + return model_inputs + + def tie_weights(self): + return self.language_model.tie_weights() + + +__all__ = ["Gemma3PreTrainedModel", "Gemma3TextModel", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"] diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py new file mode 100644 index 000000000000..2626e958326e --- /dev/null +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -0,0 +1,848 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from collections.abc import Callable +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...cache_utils import Cache, HybridCache, StaticCache +from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, + ModelOutput, +) +from ...modeling_rope_utils import rope_config_validation +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import ( + logging, +) +from ..bart.modeling_bart import BartScaledWordEmbedding +from ..gemma2.configuration_gemma2 import Gemma2Config +from ..gemma2.modeling_gemma2 import ( + Gemma2Attention, + Gemma2ForCausalLM, + Gemma2MLP, + Gemma2Model, + Gemma2PreTrainedModel, + Gemma2RMSNorm, + Gemma2RotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from ..paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +from ..siglip import SiglipVisionConfig + + +_CHECKPOINT_FOR_DOC = "google/gemma-3-4b" +_CONFIG_FOR_DOC = "Gemma3Config" + +logger = logging.get_logger(__name__) + +GEMMA3_INPUTS_DOCSTRING = "" + + +class Gemma3TextConfig(Gemma2Config): + r""" + This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma3Text-7B. + e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ Args: + vocab_size (`int`, *optional*, defaults to 262208): + Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Gemma3TextModel`] + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 131072): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + query_pre_attn_scalar (`float`, *optional*, defaults to 256): + Scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): in Gemma3Text, every other layer uses sliding window attention. This is the + size of the sliding window. + final_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the logits. 
+ attn_logit_softcapping (`float`, *optional*): + Scaling factor when applying tanh softcapping on the attention scores. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + rope_local_base_freq (float, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings for local attention. + sliding_window_pattern (`int`, *optional*, defaults to 6): + Pattern for the sliding window attention. + + ```python + >>> from transformers import Gemma3TextModel, Gemma3TextConfig + >>> # Initializing a Gemma3Text gemma3_text-7b style configuration + >>> configuration = Gemma3TextConfig() + >>> # Initializing a model from the gemma3_text-7b style configuration + >>> model = Gemma3TextModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + rope_local_base_freq (float, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings for local attention. + sliding_window_pattern (`int`, *optional*, defaults to 6): + Pattern for the sliding window attention. 
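The `sliding_window_pattern` above determines which layers use local attention: `Gemma3Attention` computes `is_sliding = bool((layer_idx + 1) % sliding_window_pattern)`, so the default of 6 yields five sliding-window layers followed by one global layer. A quick check of that interleaving (the 12-layer count is just for illustration):

```python
sliding_window_pattern = 6
num_hidden_layers = 12

layer_types = [
    "sliding" if bool((layer_idx + 1) % sliding_window_pattern) else "global"
    for layer_idx in range(num_hidden_layers)
]
print(layer_types)
# ['sliding', 'sliding', 'sliding', 'sliding', 'sliding', 'global',
#  'sliding', 'sliding', 'sliding', 'sliding', 'sliding', 'global']
```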
+ """ + + model_type = "gemma3_text" + + def __init__( + self, + vocab_size=262_208, + rope_theta=1_000_000.0, + rope_scaling=None, + rope_local_base_freq=10_000.0, + sliding_window_pattern=6, + max_position_embeddings=131_072, + final_logit_softcapping=None, + attn_logit_softcapping=None, + **super_kwargs, + ): + super().__init__(self, **super_kwargs) + + self.rope_local_base_freq = rope_local_base_freq + # For configuring HybridCache to work with 5:1 attention pattern + self.sliding_window_pattern = sliding_window_pattern + self.rope_scaling = rope_scaling + rope_config_validation(self) + + +class Gemma3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an + Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the PaliGemma-2B. + + e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`Union[Gemma3TextConfig, dict]`, *optional*): + The config object of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. + mm_tokens_per_image (`int`, *optional*, defaults to 256): + The number of tokens per image embedding. + boi_token_index (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 256000): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 262144): + The image token index to encode the image prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + + Example: + + ```python + >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig + + >>> # Initializing a Siglip-like vision config + >>> vision_config = SiglipVisionConfig() + + >>> # Initializing a Gemma3 Text config + >>> text_config = Gemma3TextConfig() + + >>> # Initializing a Gemma3 gemma-3-4b style configuration + >>> configuration = Gemma3Config(vision_config, text_config) + + >>> # Initializing a model from the gemma-3-4b style configuration + >>> model = Gemma3TextConfig(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma3" + sub_configs = { + "text_config": Gemma3TextConfig, + "vision_config": SiglipVisionConfig, + } + + def __init__( + self, + text_config: Optional[Gemma3TextConfig] = None, + vision_config: Optional[SiglipVisionConfig] = None, + mm_tokens_per_image: int = 256, + boi_token_index: int = 255_999, + eoi_token_index: int = 256_000, + image_token_index: int = 262_144, + initializer_range: float = 0.02, + **kwargs, + ): + if text_config is None: + text_config = Gemma3TextConfig() + logger.info("text_config is None, using default Gemma3TextConfig vision config.") + elif isinstance(text_config, dict): + text_config = Gemma3TextConfig(**text_config) + + if isinstance(vision_config, dict): + vision_config = SiglipVisionConfig(**vision_config) + else: + vision_config = SiglipVisionConfig() + logger.info( + "vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited " + "to text tasks." + ) + + self.text_config = text_config + self.vision_config = vision_config + self.mm_tokens_per_image = mm_tokens_per_image + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + +@dataclass +class Gemma3CausalLMOutputWithPast(ModelOutput): + """ + Base class for Gemma3 causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
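Because `mm_tokens_per_image` defaults to 256, `Gemma3MultiModalProjector` average-pools the SigLIP patch grid down to a 16x16 grid of image tokens. The arithmetic below mirrors the projector's `__init__`; the 896-pixel image size and 14-pixel patch size are assumptions chosen for illustration, not values read from a released config:

```python
image_size, patch_size = 896, 14   # assumed vision settings, illustration only
mm_tokens_per_image = 256

patches_per_image = image_size // patch_size        # 64 patches per side
tokens_per_side = int(mm_tokens_per_image**0.5)     # 16
kernel_size = patches_per_image // tokens_per_side  # 4x4 average-pooling window
print(patches_per_image, tokens_per_side, kernel_size)  # 64 16 4
print((patches_per_image // kernel_size) ** 2)          # 256 image tokens per image
```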
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class Gemma3TextScaledWordEmbedding(BartScaledWordEmbedding): + pass + + +class Gemma3MLP(Gemma2MLP): + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + + +class Gemma3RMSNorm(Gemma2RMSNorm): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + + +class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding): + def __init__(self, config: Gemma3TextConfig, device=None): + super().__init__(config) + + +# Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding` +class Gemma3Attention(Gemma2Attention): + def __init__(self, config: Gemma3TextConfig, layer_idx: int): + self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + + super().__init__() + self.sliding_window = config.sliding_window if self.is_sliding else None + + self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Here we need to slice as we use a static cache by default, but FA2 does not support it + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + seq_len = attention_mask.shape[-1] + key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] + + attention_interface: Callable = 
eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask.to(query_states), + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Gemma3DecoderLayer(nn.Module): + def __init__(self, config: Gemma3TextConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) + self.mlp = Gemma3MLP(config) + self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.is_sliding = self.self_attn.is_sliding + self.sliding_window = config.sliding_window + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: int = 0, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding + # In prefill, we may be larger than sliding window + effective_seq_len = max(cache_position.shape[0], self.sliding_window) + # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), + # thus we must slice from the right (at most `effective_seq_len` elements) + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask[:, -effective_seq_len:] + # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice + # from the left, with an offset if we are beyond the sliding window + else: + min_dtype = torch.finfo(attention_mask.dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window + ) + attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) + # In case we are beyond the sliding window, we need to correctly offset the mask slicing + # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo + offset = last_cache_position - effective_seq_len + # Should only be used when beyond the sliding window (i.e. 
offset > 0) + offset = max(0, offset) + attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # apply global RoPE to non-sliding layer only + if self.self_attn.is_sliding: + position_embeddings = position_embeddings_local + else: + position_embeddings = position_embeddings_global + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +GEMMA3_START_DOCSTRING = None + + +class Gemma3PreTrainedModel(Gemma2PreTrainedModel): + base_model_prefix = "language_model" + _no_split_modules = [ + "Gemma3DecoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + + def _init_weights(self, module): + # important: this ported version of Gemma2 isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Gemma3TextModel(Gemma2Model): + config_class = Gemma3TextConfig + + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + + # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 + self.embed_tokens = Gemma3TextScaledWordEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 + ) + + # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas + # when we want to create a local RoPE layer. 
Config defaults should hold values for global RoPE + config = copy.deepcopy(config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default"} + self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + last_cache_position: Optional[int] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None and not self.training: + batch_size, seq_len, _ = inputs_embeds.shape + past_key_values = HybridCache( + self.config, + max_batch_size=batch_size, + max_cache_len=seq_len, + dtype=inputs_embeds.dtype, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing + # (retrieving the same value from `cache_position` later on would crash dynamo) + if last_cache_position is None: + last_cache_position = 0 + if attention_mask is not None: + # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position + # It will break dynamo tracing but there are no way around it (and it should never happen in practice) + last_cache_position = ( + attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings_global = self.rotary_emb(hidden_states, position_ids) + position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + 
all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + position_embeddings_global, + position_embeddings_local, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + last_cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + last_cache_position=last_cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class Gemma3ForCausalLM(Gemma2ForCausalLM): + config_class = Gemma3TextConfig + base_model_prefix = "language_model" + + def __init__(self, config: Gemma3TextConfig): + super().__init__(config) + self.model = Gemma3TextModel(config) + + +class Gemma3MultiModalProjector(nn.Module): + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) + ) + + self.mm_soft_emb_norm = Gemma3RMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): + def tie_weights(self): + return self.language_model.tie_weights() + + def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
+ """ + vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state + image_features = self.multi_modal_projector(vision_outputs) + return image_features + + def _update_causal_mask( + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_tensor, + is_training: bool = False, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + return attention_mask + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted + # form and requires no inversion or slicing. + return attention_mask + + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = torch.finfo(self.dtype).min + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device + ) + + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + + # Apply bidirectional mask on images if token type ids are provided + if token_type_ids is not None and sequence_length != 1: + token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) + token_type_mask[token_type_ids == 0] = False # if text token do not change anything + token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) + causal_mask = causal_mask.clone() + causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( + token_type_mask, 0.0 + ) + + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = [ + "Gemma3Config", + "Gemma3TextConfig", + "Gemma3PreTrainedModel", # noqa: F822 + "Gemma3TextModel", + "Gemma3ForCausalLM", + "Gemma3ForConditionalGeneration", +] diff --git a/src/transformers/models/gemma3/processing_gemma3.py b/src/transformers/models/gemma3/processing_gemma3.py new file mode 100644 index 000000000000..e82b609bdb10 --- /dev/null +++ b/src/transformers/models/gemma3/processing_gemma3.py @@ -0,0 +1,172 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. 
+# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +from typing import List, Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, make_nested_list_of_images +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import to_py_obj + + +class Gemma3ImagesKwargs(ImagesKwargs): + do_pan_and_scan: Optional[bool] + pan_and_scan_min_crop_size: Optional[int] + pan_and_scan_max_num_crops: Optional[int] + pan_and_scan_min_ratio_to_activate: Optional[float] + do_convert_rgb: Optional[bool] + + +class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Gemma3ImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + "do_pan_and_scan": False, + "pan_and_scan_min_crop_size": 256, + "pan_and_scan_max_num_crops": 4, + "pan_and_scan_min_ratio_to_activate": 1.2, + }, + } + + +class Gemma3Processor(ProcessorMixin): + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template", "image_seq_length"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor, + tokenizer, + chat_template=None, + image_seq_length: int = 256, + **kwargs, + ): + self.image_seq_length = image_seq_length + self.image_token_id = tokenizer.image_token_id + self.boi_token = tokenizer.boi_token + image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length) + self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n" + + super().__init__( + image_processor=image_processor, + tokenizer=tokenizer, + chat_template=chat_template, + **kwargs, + ) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos=None, + audio=None, + **kwargs: Unpack[Gemma3ProcessorKwargs], + ) -> BatchFeature: + if text is None and images is None: + raise ValueError("Provide at least one of `text` or `images`.") + + output_kwargs = self._merge_kwargs( + Gemma3ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + image_inputs = {} + if images is not None: + batched_images = make_nested_list_of_images(images) + image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"]) + + # Create empty text to be replaced with placeholders + if not text: + text = [" ".join([self.boi_token] * len(images)) for images in batched_images] + + if len(batched_images) != len(text): + raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})." 
+ ) + + # Replace image tokens by the full expanded sequence + batch_num_crops = to_py_obj(image_inputs.pop("num_crops")) + text_with_crops = text + for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)): + image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)] + + if len(images) != len(image_indexes): + raise ValueError( + f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images." + ) + + # Insert additional image tokens for Pan-and-Scan crops + for num, idx in reversed(list(zip(num_crops, image_indexes))): + if num: + formatted_image_text = ( + f"Here is the original image {self.boi_token} and here are some crops to help you see better " + + " ".join([self.boi_token] * num) + ) + prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :] + text_with_crops[batch_idx] = prompt + + # Expand placeholder image tokens to the full image token sequence + text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text] + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np") + + # Add token type ids manually, as tokenizer can't do arbitrary position token types + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs + text_inputs["token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"] + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["Gemma3Processor"] diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 2bf456047d9f..a9aebb8aefe7 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -477,11 +477,6 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -490,8 +485,16 @@ def forward( is_training = token_type_ids is not None and labels is not None + # Replace image id woth PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_index >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_index + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds = self.get_input_embeddings()(llm_input_ids) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -506,10 +509,16 @@ def forward( if pixel_values is not None: image_features = self.get_image_features(pixel_values) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index) + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] raise ValueError( f"Number of images does not match number of special image tokens in the input text. 
" f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5b43469abe5f..76bcf6fc1c30 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4609,6 +4609,34 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Gemma3ForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Gemma3ForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Gemma3PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Gemma3TextModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class GitForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index de62c4ae7cb1..23a55f33b045 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -58,6 +58,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torchvision"]) +class Gemma3ImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + class GotOcr2ImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 92906e005f90..f3594e1ed087 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -289,6 +289,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class Gemma3ImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class GLPNFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 3a4171161f50..00ac78a94d3f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -124,6 +124,7 @@ "qwen2vl", "qwen2_5_vl", "ayavision", + "gemma3", ] diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index ffadf3377e0a..0b4abb85e051 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -353,7 +353,6 @@ def test_model_various_embeddings(self): def test_Gemma_sequence_classification_model(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - print(config) config.num_labels = 3 input_ids = input_dict["input_ids"] attention_mask = input_ids.ne(1).to(torch_device) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index c881ecaea559..e384db8423a7 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -153,6 +153,13 @@ def test_generate_continue_from_inputs_embeds(self): def test_sdpa_equivalence(self): pass + @unittest.skip( + 
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`" + " as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting" + ) + def test_multi_gpu_data_parallel_forward(self): + pass + @slow @require_torch_gpu diff --git a/tests/models/gemma3/__init__.py b/tests/models/gemma3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/gemma3/test_image_processing_gemma3.py b/tests/models/gemma3/test_image_processing_gemma3.py new file mode 100644 index 000000000000..4a26ff87adb5 --- /dev/null +++ b/tests/models/gemma3/test_image_processing_gemma3.py @@ -0,0 +1,229 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import Gemma3ImageProcessor + + if is_torchvision_available(): + from transformers import Gemma3ImageProcessorFast + + +class Gemma3ImageProcessingTester: + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + do_normalize=True, + image_mean=IMAGENET_STANDARD_MEAN, + image_std=IMAGENET_STANDARD_STD, + do_convert_rgb=True, + do_pan_and_scan=True, + pan_and_scan_min_crop_size=10, + pan_and_scan_max_num_crops=2, + pan_and_scan_min_ratio_to_activate=1.2, + ): + super().__init__() + size = size if size is not None else {"height": 18, "width": 18} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.do_pan_and_scan = do_pan_and_scan + self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size + self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops + self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + "do_pan_and_scan": self.do_pan_and_scan, + "pan_and_scan_min_crop_size": self.pan_and_scan_min_crop_size, + "pan_and_scan_max_num_crops": self.pan_and_scan_max_num_crops, + 
"pan_and_scan_min_ratio_to_activate": self.pan_and_scan_min_ratio_to_activate, + } + + def expected_output_image_shape(self, images): + return self.num_channels, self.size["height"], self.size["width"] + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = Gemma3ImageProcessor if is_vision_available() else None + fast_image_processing_class = Gemma3ImageProcessorFast if is_torchvision_available() else None + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Gemma3 + def setUp(self): + super().setUp() + self.image_processor_tester = Gemma3ImageProcessingTester(self) + + @property + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + self.assertTrue(hasattr(image_processing, "do_pan_and_scan")) + self.assertTrue(hasattr(image_processing, "pan_and_scan_min_crop_size")) + self.assertTrue(hasattr(image_processing, "pan_and_scan_max_num_crops")) + self.assertTrue(hasattr(image_processing, "pan_and_scan_min_ratio_to_activate")) + + def test_image_processor_from_dict_with_kwargs(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 18, "width": 18}) + + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=84) + self.assertEqual(image_processor.size, {"height": 84, "width": 84}) + + def test_pan_and_scan(self): + """ + Enables Pan and Scan path by choosing the correct input image resolution. If you are changing + image processor attributes for PaS, please update this test. 
+ """ + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + """This function prepares a list of PIL images""" + image_inputs = [np.random.randint(255, size=(3, 300, 600), dtype=np.uint8)] * 3 + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + # Test not batched input, 3 images because we have base image + 2 crops + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (3, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched, 9 images because we have base image + 2 crops per each item + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (9, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + def test_call_pil(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + def test_call_numpy(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + def test_call_pytorch(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + 
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + @unittest.skip("Gemma3 doesn't work with 4 channels due to pan and scan method") + def test_call_numpy_4_channels(self): + pass diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py new file mode 100644 index 000000000000..6787f5ad5856 --- /dev/null +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -0,0 +1,520 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Gemma3 model.""" + +import unittest + +from parameterized import parameterized + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + Gemma3Config, + Gemma3TextConfig, + is_torch_available, +) +from transformers.testing_utils import ( + cleanup, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...models.gemma.test_modeling_gemma import GemmaModelTester +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + Gemma3ForCausalLM, + Gemma3ForConditionalGeneration, + Gemma3Processor, + Gemma3TextModel, + ) + + +class Gemma3ModelTester(GemmaModelTester): + if is_torch_available(): + config_class = Gemma3TextConfig + model_class = Gemma3TextModel + for_causal_lm_class = Gemma3ForCausalLM + + +@require_torch +class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (Gemma3TextModel, Gemma3ForCausalLM) if is_torch_available() else () + all_generative_model_classes = (Gemma3ForCausalLM,) if is_torch_available() else () + test_headmasking = False + test_pruning = False + _is_stateful = True + model_split_percents = [0.5, 0.6] + + def setUp(self): + self.model_tester = Gemma3ModelTester(self) + self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37) + + @unittest.skip("Failing because of unique cache (HybridCache)") + def test_model_outputs_equivalence(self, **kwargs): + pass + + @parameterized.expand([("random",), ("same",)]) + @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding") + def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip("Gemma3 has HybridCache which is not compatible with dola decoding") + def test_dola_decoding_sample(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support continue from past kv") + def 
test_generate_continue_from_past_key_values(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support low_memory generation") + def test_beam_search_low_memory(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_low_memory(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_with_static_cache(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_continue_from_inputs_embeds(self): + pass + + @unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.") + def test_eager_matches_fa2_generate(self): + pass + + @unittest.skip( + reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`" + " as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting" + ) + def test_multi_gpu_data_parallel_forward(self): + pass + + +class Gemma3Vision2TextModelTester: + def __init__( + self, + parent, + mm_tokens_per_image=2, + image_token_index=1, + boi_token_index=2, + eoi_token_index=3, + seq_length=25, + is_training=True, + vision_config={ + "use_labels": True, + "image_size": 20, + "patch_size": 5, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "num_key_value_heads": 1, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + use_cache=False, + ): + self.parent = parent + # `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify + self.mm_tokens_per_image = mm_tokens_per_image + self.image_token_index = image_token_index + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.llm_tester = Gemma3ModelTester(self.parent) + self.text_config = self.llm_tester.get_config() + self.vision_config = vision_config + self.seq_length = seq_length + self.pad_token_id = self.text_config.pad_token_id + + self.num_hidden_layers = self.text_config.num_hidden_layers + self.vocab_size = self.text_config.vocab_size + self.hidden_size = self.text_config.hidden_size + self.num_attention_heads = self.text_config.num_attention_heads + self.is_training = is_training + + self.batch_size = 3 + self.num_channels = vision_config["num_channels"] + self.image_size = vision_config["image_size"] + self.encoder_seq_length = seq_length + self.use_cache = use_cache + + def get_config(self): + return Gemma3Config( + text_config=self.text_config, + vision_config=self.vision_config, + image_token_index=self.image_token_index, + boi_token_index=self.boi_token_index, + eoi_token_index=self.eoi_token_index, + mm_tokens_per_image=self.mm_tokens_per_image, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, 
+ self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + attention_mask = input_ids.ne(self.pad_token_id).to(torch_device) + + # set the 3 first tokens to be image, and ensure that no other tokens are image tokens + # do not change this unless you modified image size or patch size + input_ids[input_ids == config.image_token_index] = self.pad_token_id + input_ids[:, :1] = config.image_token_index + + token_type_ids = torch.zeros_like(input_ids) + token_type_ids[input_ids == config.image_token_index] = 1 + + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + return config, inputs_dict + + +@require_torch +class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else () + all_generative_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else () + test_headmasking = False + test_pruning = False + test_missing_keys = False + _is_stateful = True + model_split_percents = [0.5, 0.6] + + # MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded + # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"] + # in the dispatch_model function + test_cpu_offload = False + test_disk_offload_safetensors = False + test_disk_offload_bin = False + + def setUp(self): + self.model_tester = Gemma3Vision2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37) + + @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip( + reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`" + " as in Dynamic Cache doesnt work. 
NOTE: @gante all cache objects would need better compatibility with multi gpu setting" + ) + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip("Failing because of unique cache (HybridCache)") + def test_model_outputs_equivalence(self, **kwargs): + pass + + @parameterized.expand([("random",), ("same",)]) + @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding") + def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip("Gemma3 has HybridCache which is not compatible with dola decoding") + def test_dola_decoding_sample(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support continue from past kv") + def test_generate_continue_from_past_key_values(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support low_memory generation") + def test_beam_search_low_memory(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_low_memory(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_with_static_cache(self): + pass + + @unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip( + reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation" + ) + def test_initialization(self): + pass + + @unittest.skip( + reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. 
TODO @raushan" + ) + def test_flex_attention_with_grads(self): + pass + + +@slow +@require_torch_gpu +# @require_read_token +class Gemma3IntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = Gemma3Processor.from_pretrained("gg-hf-g/gemma-3-4b-it", padding_side="left") + + url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" + self.messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + {"type": "image", "url": url}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def test_model_4b_bf16(self): + model_id = "gg-hf-g/gemma-3-4b-it" + + model = Gemma3ForConditionalGeneration.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 + ).to(torch_device) + + inputs = self.processor.apply_chat_template( + self.messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + add_generation_prompt=True, + ).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip + self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_4b_batch(self): + model_id = "gg-hf-g/gemma-3-4b-it" + + model = Gemma3ForConditionalGeneration.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 + ).to(torch_device) + + messages_2 = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png", + }, + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "Are these images identical?"}, + ], + }, + ] + + inputs = self.processor.apply_chat_template( + [self.messages, messages_2], + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + add_generation_prompt=True, + ).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = [ + 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like', + "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. 
\n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow" + ] # fmt: skip + self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_4b_multiimage(self): + model_id = "gg-hf-g/gemma-3-4b-it" + + model = Gemma3ForConditionalGeneration.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 + ).to(torch_device) + + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "What do you see here?"}, + ], + }, + ] + + inputs = self.processor.apply_chat_template( + messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + add_generation_prompt=True, + ).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"] # fmt: skip + self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_1b_text_only(self): + model_id = "gg-hf-g/gemma-3-1b-it" + + model = Gemma3ForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to( + torch_device + ) + tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") + inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'] # fmt: skip + self.assertEqual(output_text, EXPECTED_TEXTS) + + # TODO: raushan FA2 generates gibberish for no reason, check later + # @require_flash_attn + # @require_torch_gpu + # @mark.flash_attn_test + # def test_model_4b_flash_attn(self): + # model_id = "gg-hf-g/gemma-3-4b-it" + # + # model = Gemma3ForConditionalGeneration.from_pretrained( + # model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + # ).to(torch_device) + # + # inputs = self.processor.apply_chat_template( + # self.messages, + # tokenize=True, + # return_dict=True, + # return_tensors="pt", + # add_generation_prompt=True, + # ).to(torch_device) + # + # output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + # output_text = self.processor.batch_decode(output, skip_special_tokens=True) + # + # EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nPlease look out that you are what Grammy and Vi- ||.xfairesr--ith alerts themselves are||ِّ\n\n**General Note:**'] # fmt: skip + # self.assertEqual(output_text, EXPECTED_TEXTS) + + @parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)]) + def test_generation_beyond_sliding_window(self, attn_implementation: str): + """Test that we can correctly generate beyond the sliding window. This is non trivial as + we need to correctly slice the attention mask in all cases (because we use a HybridCache). + Outputs for every attention functions should be coherent and identical. 
+ """ + model_id = "gg-hf-g/gemma-3-1b-it" + + input_text = [ + "This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens + "A list of colors: red, blue", # This will almost all be padding tokens + ] + tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") + inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device) + + model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16 + ).to(torch_device) + + # Make sure prefill is larger than sliding window + input_size = inputs.input_ids.shape[-1] + self.assertTrue(input_size > model.config.sliding_window) + + out = model.generate(**inputs, max_new_tokens=20)[:, input_size:] + output_text = tokenizer.batch_decode(out) + + EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip + self.assertEqual(output_text, EXPECTED_COMPLETIONS) diff --git a/tests/models/gemma3/test_processing_gemma3.py b/tests/models/gemma3/test_processing_gemma3.py new file mode 100644 index 000000000000..e72ca9c2bf19 --- /dev/null +++ b/tests/models/gemma3/test_processing_gemma3.py @@ -0,0 +1,136 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import shutil +import tempfile +import unittest +from typing import Optional + +from transformers import Gemma3Processor, GemmaTokenizer +from transformers.testing_utils import get_tests_dir, require_vision +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import Gemma3ImageProcessor + +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + + +@require_vision +class Gemma3ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = Gemma3Processor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + gemma3_image_processor_kwargs = { + "do_pan_and_scan": True, + "pan_and_scan_min_crop_size": 256, + "pan_and_scan_max_num_crops": 4, + "pan_and_scan_min_ratio_to_activate": 1.2, + } + image_processor = Gemma3ImageProcessor.from_pretrained( + "google/siglip-so400m-patch14-384", **gemma3_image_processor_kwargs + ) + + extra_special_tokens = { + "image_token": "", + "boi_token": "", + "eoi_token": "", + } + tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True, extra_special_tokens=extra_special_tokens) + processor_kwargs = self.prepare_processor_dict() + processor = Gemma3Processor(image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs) + processor.save_pretrained(self.tmpdirname) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + # TODO: raushan or arthur: add the real chat template + def prepare_processor_dict(self): + return { + "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'model\n'}}\n{%- endif -%}\n", "image_seq_length": 3, + } # fmt: skip + + # Override as VLMs need image tokens in prompts + def prepare_text_inputs(self, batch_size: Optional[int] = None): + if batch_size is None: + return "lower newer " + + if batch_size < 1: + raise ValueError("batch_size must be greater than 0") + + if batch_size == 1: + return ["lower newer "] + return ["lower newer ", " upper older longer string"] + [ + " lower newer" + ] * (batch_size - 2) + + # Override as Gemma3 needs images to be an explicitly nested batch + def prepare_image_inputs(self, batch_size: Optional[int] = None): + """This function prepares a list of PIL images for testing""" + images = super().prepare_image_inputs(batch_size) + if isinstance(images, (list, tuple)): + images = [[image] for image in images] + return images + + 
+    def test_text_with_image_tokens(self):
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        text_multi_images = f"{processor.boi_token}{processor.boi_token}Dummy text!"
+        text_single_image = f"{processor.boi_token}Dummy text!"
+        text_no_image = "Dummy text!"
+
+        image = self.prepare_image_inputs()
+
+        # If the text has no image tokens, `images` should be `None`
+        with self.assertRaises(ValueError):
+            _ = processor(text=text_no_image, images=image, return_tensors="np")
+
+        # We can't be sure of the user's intention: one image per text, OR two images for the first text and none for the second
+        with self.assertRaises(ValueError):
+            _ = processor(text=[text_single_image, text_single_image], images=[image, image], return_tensors="np")
+
+        # The user is expected to be explicit about which images belong to which text by nesting the images list
+        out_multiimages = processor(text=text_multi_images, images=[image, image], return_tensors="np")
+        out_batch_oneimage = processor(
+            text=[text_single_image, text_single_image], images=[[image], [image]], return_tensors="np"
+        )
+        self.assertListEqual(
+            out_batch_oneimage[self.images_input_name].tolist(), out_multiimages[self.images_input_name].tolist()
+        )
+
+    def test_pan_and_scan(self):
+        processor_components = self.prepare_components()
+        processor_kwargs = self.prepare_processor_dict()
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+
+        input_str = self.prepare_text_inputs()
+        image_input = self.prepare_image_inputs()
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="np",
+            do_pan_and_scan=True,
+            image_seq_length=2,
+            pan_and_scan_min_crop_size=10,
+        )
+
+        # base image + 4 crops
+        self.assertEqual(len(inputs[self.images_input_name]), 5)
+        self.assertEqual(len(inputs[self.text_input_name][0]), 67)
diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py
index 3552439aeaa2..8d5124b0d8a2 100644
--- a/tests/test_processing_common.py
+++ b/tests/test_processing_common.py
@@ -783,7 +783,7 @@ def test_chat_template_single(self):
         self.assertListEqual(expected_output, formatted_prompt_tokenized)
 
         out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
-        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+        self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
 
         # Now test the ability to return dict
         messages[0][0]["content"].append(
@@ -845,7 +845,7 @@ def test_chat_template_batched(self):
             return_dict=True,
             padding=True,
         )
-        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+        self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
 
         # Now test the ability to return dict
         batched_messages[0][0]["content"].append(
@@ -885,6 +885,7 @@ def test_chat_template_accepts_processing_kwargs(self):
             add_generation_prompt=True,
             tokenize=True,
             padding="max_length",
+            truncation=True,
             max_length=50,
         )
         self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
@@ -982,7 +983,7 @@ def test_chat_template_video(self):
         self.assertListEqual(expected_output, formatted_prompt_tokenized)
 
         out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
-        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+        self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
 
         # Add video URL for return dict and load with `num_frames` arg
         messages[0][0]["content"][0] = {
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 87b7e8be0adf..68fd6434ee8f 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -226,6 +226,8 @@
         "giou_loss_coefficient",
     ],
     "GPTNeoXConfig": ["rotary_emb_base"],
+    "Gemma3Config": ["boi_token_index", "eoi_token_index"],
+    "Gemma3TextConfig": ["cache_implementation", "tie_word_embeddings"],
 }