Skip to content

Commit 487dab1

Browse files
RyanMullinsArthurZuckerain-soph
authored
Shieldgemma2 (#36678)
* single commit * correct config * fixup * dummy pt * Use ShieldGemma2Config in conversion script * Update src/transformers/models/shieldgemma2/configuration_shieldgemma2.py * Adding shieldgemma2 to models.__init__.py * Adding ShieldGemma2 to main __init__.py * Update shieldgemma2.md * Update shieldgemma2.md * Adding tests. Addressing review feedback. * Minor docs update * Fixing code quality feedback from CI * Fixing empty messages bug reported by ghunkins --------- Co-authored-by: Arthur Zucker <[email protected]> Co-authored-by: Ren Pang <[email protected]>
1 parent a63e92e commit 487dab1

19 files changed

+1459
-0
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
2+
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
5+
the License. You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
10+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
11+
specific language governing permissions and limitations under the License.
12+
13+
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
14+
rendered properly in your Markdown viewer.
15+
16+
-->
17+
18+
# ShieldGemma 2
19+
20+
## Overview
21+
22+
The ShieldGemma 2 model was proposed in a forthcoming technical report by Google. ShieldGemma 2 is built on [Gemma 3](https://ai.google.dev/gemma/docs/core/model_card_3), is a 4 billion (4B) parameter model that checks the safety of both synthetic and natural images against key categories to help you build robust datasets and models. With this addition to the Gemma family of models, researchers and developers can now easily minimize the risk of harmful content in their models across key areas of harm as defined below:
23+
24+
- No Sexually Explicit content: The image shall not contain content that depicts explicit or graphic sexual acts (e.g., pornography, erotic nudity, depictions of rape or sexual assault).
25+
- No Dangerous Content: The image shall not contain content that facilitates or encourages activities that could cause real-world harm (e.g., building firearms and explosive devices, promotion of terrorism, instructions for suicide).
26+
- No Violence/Gore content: The image shall not contain content that depicts shocking, sensational, or gratuitous violence (e.g., excessive blood and gore, gratuitous violence against animals, extreme injury or moment of death).
27+
28+
We recommend using ShieldGemma 2 as an input filter to vision language models, or as an output filter of image generation systems. To train a robust image safety model, we curated training datasets of natural and synthetic images and instruction-tuned Gemma 3 to demonstrate strong performance.
29+
30+
This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins).
31+
32+
## Usage Example
33+
34+
- ShieldGemma 2 provides a Processor that accepts a list of `images` and an optional list of `policies` as input, and constructs a batch of prompts as the product of these two lists using the provided chat template.
35+
- You can extend ShieldGemma's built-in in policies with the `custom_policies` argument to the Processor. Using the same key as one of the built-in policies will overwrite that policy with your custom defintion.
36+
- ShieldGemma 2 does not support the image cropping capabilities used by Gemma 3.
37+
38+
### Classification against Built-in Policies
39+
40+
```python
41+
from PIL import Image
42+
import requests
43+
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
44+
45+
model_id = "google/shieldgemma-2-4b-it"
46+
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
47+
processor = AutoProcessor.from_pretrained(model_id)
48+
49+
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
50+
image = Image.open(requests.get(url, stream=True).raw)
51+
52+
inputs = processor(images=[image], return_tensors="pt").to(model.device)
53+
54+
output = model(**inputs)
55+
print(output.probabilities)
56+
```
57+
58+
### Classification against Custom Policies
59+
60+
```python
61+
from PIL import Image
62+
import requests
63+
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
64+
65+
model_id = "google/shieldgemma-2-4b-it"
66+
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
67+
processor = AutoProcessor.from_pretrained(model_id)
68+
69+
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
70+
image = Image.open(requests.get(url, stream=True).raw)
71+
72+
custom_policies = {
73+
"key_a": "descrition_a",
74+
"key_b": "descrition_b",
75+
}
76+
77+
inputs = processor(
78+
images=[image],
79+
custom_policies=custom_policies,
80+
policies=["dangerous", "key_a", "key_b"],
81+
return_tensors="pt",
82+
).to(model.device)
83+
84+
output = model(**inputs)
85+
print(output.probabilities)
86+
```
87+
88+
89+
## ShieldGemma2Processor
90+
91+
[[autodoc]] ShieldGemma2Processor
92+
93+
## ShieldGemma2Config
94+
95+
[[autodoc]] ShieldGemma2Config
96+
97+
## ShieldGemma2ForImageClassification
98+
99+
[[autodoc]] ShieldGemma2ForImageClassification
100+
- forward

src/transformers/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,10 @@
774774
"models.seggpt": ["SegGptConfig"],
775775
"models.sew": ["SEWConfig"],
776776
"models.sew_d": ["SEWDConfig"],
777+
"models.shieldgemma2": [
778+
"ShieldGemma2Config",
779+
"ShieldGemma2Processor",
780+
],
777781
"models.siglip": [
778782
"SiglipConfig",
779783
"SiglipProcessor",
@@ -3581,6 +3585,7 @@
35813585
"SEWDPreTrainedModel",
35823586
]
35833587
)
3588+
_import_structure["models.shieldgemma2"].append("ShieldGemma2ForImageClassification")
35843589
_import_structure["models.siglip"].extend(
35853590
[
35863591
"SiglipForImageClassification",
@@ -5982,6 +5987,10 @@
59825987
from .models.seggpt import SegGptConfig
59835988
from .models.sew import SEWConfig
59845989
from .models.sew_d import SEWDConfig
5990+
from .models.shieldgemma2 import (
5991+
ShieldGemma2Config,
5992+
ShieldGemma2Processor,
5993+
)
59855994
from .models.siglip import (
59865995
SiglipConfig,
59875996
SiglipProcessor,
@@ -8350,6 +8359,9 @@
83508359
SEWDModel,
83518360
SEWDPreTrainedModel,
83528361
)
8362+
from .models.shieldgemma2 import (
8363+
ShieldGemma2ForImageClassification,
8364+
)
83538365
from .models.siglip import (
83548366
SiglipForImageClassification,
83558367
SiglipModel,

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@
247247
seggpt,
248248
sew,
249249
sew_d,
250+
shieldgemma2,
250251
siglip,
251252
siglip2,
252253
smolvlm,

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@
274274
("seggpt", "SegGptConfig"),
275275
("sew", "SEWConfig"),
276276
("sew-d", "SEWDConfig"),
277+
("shieldgemma2", "ShieldGemma2Config"),
277278
("siglip", "SiglipConfig"),
278279
("siglip2", "Siglip2Config"),
279280
("siglip_vision_model", "SiglipVisionConfig"),
@@ -625,6 +626,7 @@
625626
("seggpt", "SegGPT"),
626627
("sew", "SEW"),
627628
("sew-d", "SEW-D"),
629+
("shieldgemma2", "Shieldgemma2"),
628630
("siglip", "SigLIP"),
629631
("siglip2", "SigLIP2"),
630632
("siglip2_vision_model", "Siglip2VisionModel"),

src/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
("sam", ("SamImageProcessor",)),
138138
("segformer", ("SegformerImageProcessor",)),
139139
("seggpt", ("SegGptImageProcessor",)),
140+
("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
140141
("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
141142
("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
142143
("superglue", "SuperGlueImageProcessor"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,7 @@
727727
("regnet", "RegNetForImageClassification"),
728728
("resnet", "ResNetForImageClassification"),
729729
("segformer", "SegformerForImageClassification"),
730+
("shieldgemma2", "ShieldGemma2ForImageClassification"),
730731
("siglip", "SiglipForImageClassification"),
731732
("siglip2", "Siglip2ForImageClassification"),
732733
("swiftformer", "SwiftFormerForImageClassification"),
@@ -849,6 +850,7 @@
849850
("pixtral", "LlavaForConditionalGeneration"),
850851
("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
851852
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
853+
("shieldgemma2", "Gemma3ForConditionalGeneration"),
852854
("smolvlm", "SmolVLMForConditionalGeneration"),
853855
("udop", "UdopForConditionalGeneration"),
854856
("vipllava", "VipLlavaForConditionalGeneration"),

src/transformers/models/auto/processing_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
("seamless_m4t", "SeamlessM4TProcessor"),
102102
("sew", "Wav2Vec2Processor"),
103103
("sew-d", "Wav2Vec2Processor"),
104+
("shieldgemma2", "ShieldGemma2Processor"),
104105
("siglip", "SiglipProcessor"),
105106
("siglip2", "Siglip2Processor"),
106107
("speech_to_text", "Speech2TextProcessor"),

src/transformers/models/auto/tokenization_auto.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,13 @@
493493
"SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
494494
),
495495
),
496+
(
497+
"shieldgemma2",
498+
(
499+
"GemmaTokenizer" if is_sentencepiece_available() else None,
500+
"GemmaTokenizerFast" if is_tokenizers_available() else None,
501+
),
502+
),
496503
("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
497504
(
498505
"siglip2",
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import TYPE_CHECKING
15+
16+
from ...utils import _LazyModule
17+
from ...utils.import_utils import define_import_structure
18+
19+
20+
if TYPE_CHECKING:
21+
from .configuration_shieldgemma2 import *
22+
from .modeling_shieldgemma2 import *
23+
from .processing_shieldgemma2 import *
24+
else:
25+
import sys
26+
27+
_file = globals()["__file__"]
28+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# coding=utf-8
2+
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
3+
#
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from ...configuration_utils import PretrainedConfig
18+
from ...utils import logging
19+
from ..auto import CONFIG_MAPPING, AutoConfig
20+
21+
22+
logger = logging.get_logger(__name__)
23+
24+
25+
class ShieldGemma2Config(PretrainedConfig):
26+
r"""
27+
This is the configuration class to store the configuration of a [`ShieldGemma2ForImageClassification`]. It is used to instantiate an
28+
ShieldGemma2ForImageClassification according to the specified arguments, defining the model architecture. Instantiating a configuration
29+
with the defaults will yield a similar configuration to that of the shieldgemma-2-4b-it.
30+
31+
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
32+
33+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34+
documentation from [`PretrainedConfig`] for more information.
35+
36+
Args:
37+
text_config (`Union[ShieldGemma2TextConfig, dict]`, *optional*):
38+
The config object of the text backbone.
39+
vision_config (`Union[AutoConfig, dict]`, *optional*):
40+
Custom vision config or dict.
41+
mm_tokens_per_image (`int`, *optional*, defaults to 256):
42+
The number of tokens per image embedding.
43+
boi_token_index (`int`, *optional*, defaults to 255999):
44+
The begin-of-image token index to wrap the image prompt.
45+
eoi_token_index (`int`, *optional*, defaults to 256000):
46+
The end-of-image token index to wrap the image prompt.
47+
image_token_index (`int`, *optional*, defaults to 262144):
48+
The image token index to encode the image prompt.
49+
initializer_range (`float`, *optional*, defaults to 0.02):
50+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
51+
52+
53+
Example:
54+
55+
```python
56+
>>> from transformers import ShieldGemma2ForConditionalGeneration, ShieldGemma2Config, SiglipVisionConfig, ShieldGemma2TextConfig
57+
58+
>>> # Initializing a Siglip-like vision config
59+
>>> vision_config = SiglipVisionConfig()
60+
61+
>>> # Initializing a ShieldGemma2 Text config
62+
>>> text_config = ShieldGemma2TextConfig()
63+
64+
>>> # Initializing a ShieldGemma2 gemma-3-4b style configuration
65+
>>> configuration = ShieldGemma2Config(vision_config, text_config)
66+
67+
>>> # Initializing a model from the gemma-3-4b style configuration
68+
>>> model = ShieldGemma2TextConfig(configuration)
69+
70+
>>> # Accessing the model configuration
71+
>>> configuration = model.config
72+
```"""
73+
74+
model_type = "shieldgemma2"
75+
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
76+
77+
def __init__(
78+
self,
79+
text_config=None,
80+
vision_config=None,
81+
mm_tokens_per_image: int = 256,
82+
boi_token_index: int = 255_999,
83+
eoi_token_index: int = 256_000,
84+
image_token_index: int = 262_144,
85+
initializer_range: float = 0.02,
86+
**kwargs,
87+
):
88+
if isinstance(vision_config, dict):
89+
vision_config["model_type"] = (
90+
vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
91+
)
92+
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
93+
elif vision_config is None:
94+
vision_config = CONFIG_MAPPING["siglip_vision_model"]()
95+
96+
self.vision_config = vision_config
97+
98+
if isinstance(text_config, dict):
99+
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma3_text"
100+
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
101+
elif text_config is None:
102+
text_config = CONFIG_MAPPING["gemma3_text"]()
103+
104+
self.text_config = text_config
105+
self.vision_config = vision_config
106+
self.mm_tokens_per_image = mm_tokens_per_image
107+
self.boi_token_index = boi_token_index
108+
self.eoi_token_index = eoi_token_index
109+
self.image_token_index = image_token_index
110+
self.initializer_range = initializer_range
111+
112+
super().__init__(**kwargs)
113+
114+
115+
__all__ = ["ShieldGemma2Config"]

0 commit comments

Comments
 (0)