
Commit 03efe5e

Fixing empty messages bug reported by ghunkins

1 parent 48ac9e9 commit 03efe5e

File tree (5 files changed, +59 -55 lines changed):

- docs/source/en/model_doc/shieldgemma2.md
- src/transformers/models/shieldgemma2/modeling_shieldgemma2.py
- src/transformers/models/shieldgemma2/processing_shieldgemma2.py
- tests/models/shieldgemma2/test_modeling_shieldgemma2.py
- tests/models/shieldgemma2/test_processing_shieldgemma2.py

docs/source/en/model_doc/shieldgemma2.md

Lines changed: 11 additions & 2 deletions
@@ -85,7 +85,16 @@ output = model(**inputs)
 print(output.probabilities)
 ```
 
-## ShieldGemmaForImageClassification
 
-[[autodoc]] Shieldgemma2ForTokenClassification
+## ShieldGemma2Processor
+
+[[autodoc]] ShieldGemma2Processor
+
+## ShieldGemma2Config
+
+[[autodoc]] ShieldGemma2Config
+
+## ShieldGemma2ForImageClassification
+
+[[autodoc]] ShieldGemma2ForImageClassification
 - forward
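For orientation, the usage snippet that these new autodoc sections follow (its tail is visible as the hunk's context, output = model(**inputs) and print(output.probabilities)) exercises exactly the classes now being documented. A minimal end-to-end sketch of that flow; the checkpoint id, image path, and images-only call are placeholder assumptions, not values taken from this commit:

from PIL import Image
from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Processor

model_id = "google/shieldgemma-2-4b-it"  # assumed checkpoint id, not stated in this diff
processor = ShieldGemma2Processor.from_pretrained(model_id)
model = ShieldGemma2ForImageClassification.from_pretrained(model_id)

image = Image.open("example.jpg")  # placeholder image path
# Leaving `policies` unset lets the processor fall back to its default policy
# definitions (see the processing_shieldgemma2.py changes below).
inputs = processor(images=[image], return_tensors="pt")
output = model(**inputs)
print(output.probabilities)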

src/transformers/models/shieldgemma2/modeling_shieldgemma2.py

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             logits_to_keep=logits_to_keep,
-            **lm_kwargs
+            **lm_kwargs,
         )
         logits = outputs.logits
         selected_logits = logits[:, -1, [self.yes_token_index, self.no_token_index]]
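The surrounding context shows how the classifier head works: only the last-position logits for the yes/no tokens are kept. A rough illustration of that selection step; the softmax normalization is an assumption about how output.probabilities is derived, and the token indices and shapes below are placeholders:

import torch

batch, seq_len, vocab = 2, 8, 32          # placeholder shapes
yes_token_index, no_token_index = 10, 20  # placeholder token ids
logits = torch.randn(batch, seq_len, vocab)

# Keep the final position's logits for the "Yes" and "No" tokens, as in the diff above.
selected_logits = logits[:, -1, [yes_token_index, no_token_index]]
# Assumed normalization into the (yes, no) probabilities the model reports.
probabilities = torch.softmax(selected_logits, dim=-1)
print(probabilities.shape)  # torch.Size([2, 2])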

src/transformers/models/shieldgemma2/processing_shieldgemma2.py

Lines changed: 25 additions & 12 deletions
@@ -25,6 +25,26 @@
 
 logger = logging.get_logger(__name__)
 
+DEFAULT_SHIELDGEMMA2_POLICIES: Mapping[str, str] = {
+    "dangerous": (
+        "No Dangerous Content: The image shall not contain content that"
+        " facilitates or encourages activities that could cause real-world harm"
+        " (e.g., building firearms and explosive devices, promotion of"
+        " terrorism, instructions for suicide)."
+    ),
+    "sexual": (
+        "No Sexually Explicit content: The image shall not contain content that"
+        " depicts explicit or graphic sexual acts (e.g., pornography, erotic"
+        " nudity, depictions of rape or sexual assault)."
+    ),
+    "violence": (
+        "No Violence/Gore content: The image shall not contain content that"
+        " depicts shocking, sensational, or gratuitous violence (e.g.,"
+        " excessive blood and gore, gratuitous violence against animals,"
+        " extreme injury or moment of death)."
+    ),
+}
+
 
 class ShieldGemma2ProcessorKwargs(Gemma3ProcessorKwargs, total=False):
     policies: Optional[Sequence[str]]
@@ -40,15 +60,8 @@ class ShieldGemma2ProcessorKwargs(Gemma3ProcessorKwargs, total=False):
 
 
 class ShieldGemma2Processor(Gemma3Processor):
-
     def __init__(
-        self,
-        image_processor,
-        tokenizer,
-        chat_template = None,
-        image_seq_length = 256,
-        policy_definitions = None,
-        **kwargs
+        self, image_processor, tokenizer, chat_template=None, image_seq_length=256, policy_definitions=None, **kwargs
     ):
         """A processor for the ShieldGemma 2 model.
 
@@ -65,10 +78,10 @@ def __init__(
                 the base policies ShieldGemma was trained on.
         """
         super().__init__(image_processor, tokenizer, chat_template, image_seq_length, **kwargs)
-        if policy_definitions:
-            self.policy_definitions = policy_definitions
+        if policy_definitions is None:
+            self.policy_definitions = DEFAULT_SHIELDGEMMA2_POLICIES
         else:
-            self.policy_definitions = {}
+            self.policy_definitions = policy_definitions
 
     def __call__(
         self,
@@ -129,7 +142,6 @@ def __call__(
         text_kwargs["padding"] = kwargs.pop("padding", True)
         text_kwargs["padding_side"] = kwargs.pop("padding_side", "left")
 
-
         policy_definitions: Mapping[str, str] = {
             **self.policy_definitions,
             **kwargs.get("custom_policies", {}),
@@ -138,6 +150,7 @@ def __call__(
         if (policies := kwargs.get("policies")) is None:
            policies = list(policy_definitions.keys())
 
+        # TODO(ryanmullins): Support images from PIL or URLs.
         messages = []
         expanded_images = []
         for img in images:
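Behaviorally, this is the fix for the reported empty-messages bug: policy_definitions used to fall back to an empty dict, so a call without an explicit policies argument had nothing to build prompts from. A standalone sketch of the selection and per-policy expansion logic implied by __call__ and by the test batch sizes below; the prompt strings and helper name are illustrative, not part of the library:

DEFAULT_POLICY_NAMES = ["dangerous", "sexual", "violence"]  # keys added in this commit

def expand_requests(images, policies=None, custom_policies=None):
    # Mirror of the merge in __call__: custom policies extend/override the defaults.
    policy_definitions = {name: f"<{name} policy text>" for name in DEFAULT_POLICY_NAMES}
    policy_definitions.update(custom_policies or {})
    if policies is None:
        policies = list(policy_definitions.keys())
    # One classification prompt per (image, policy) pair, matching the batch
    # sizes asserted in the processor tests.
    return [(img, name, policy_definitions[name]) for img in images for name in policies]

print(len(expand_requests(images=["img0"])))                                    # 3
print(len(expand_requests(images=["img0"], custom_policies={"ip": "<text>"})))  # 4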

tests/models/shieldgemma2/test_modeling_shieldgemma2.py

Lines changed: 3 additions & 3 deletions
@@ -14,11 +14,11 @@
 # limitations under the License.
 """Testing suite for the PyTorch Gemma3 model."""
 
-from io import BytesIO
 import unittest
+from io import BytesIO
 
-from PIL import Image
 import requests
+from PIL import Image
 
 from transformers import is_torch_available
 from transformers.testing_utils import (
@@ -31,14 +31,14 @@
 
 if is_torch_available():
     import torch
+
     from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Processor
 
 
 @slow
 @require_torch_gpu
 # @require_read_token
 class ShieldGemma2IntegrationTest(unittest.TestCase):
-
     def tearDown(self):
         cleanup(torch_device, gc_collect=True)
 

tests/models/shieldgemma2/test_processing_shieldgemma2.py

Lines changed: 19 additions & 37 deletions
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections.abc import Mapping
 import json
 import os
 import shutil
 import tempfile
 import unittest
+from collections.abc import Mapping
 
 from parameterized import parameterized
+
 from transformers import GemmaTokenizer, ShieldGemma2Processor
 from transformers.testing_utils import get_tests_dir, require_vision
 from transformers.utils import is_vision_available
@@ -95,29 +96,6 @@ def prepare_processor_dict(self):
             "policy_definitions": _SHIELDGEMMA2_POLICIES,
         }
 
-    def test_unstructured_kwargs(self):
-        if "image_processor" not in self.processor_class.attributes:
-            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
-        processor_components = self.prepare_components()
-        processor_kwargs = self.prepare_processor_dict()
-        processor = self.processor_class(**processor_components, **processor_kwargs)
-        self.skip_processor_without_typed_kwargs(processor)
-
-        input_str = self.prepare_text_inputs()
-        image_input = self.prepare_image_inputs()
-        inputs = processor(
-            text=input_str,
-            images=image_input,
-            return_tensors="pt",
-            do_rescale=True,
-            rescale_factor=-1,
-            padding="max_length",
-            max_length=674,
-        )
-
-        self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
-        self.assertEqual(inputs[self.text_input_name].shape[-1], 674)
-
     def test_policy_definitions_saved_in_config(self):
         processor_config_path = os.path.join(self.tmpdirname, "processor_config.json")
 
@@ -128,11 +106,13 @@ def test_policy_definitions_saved_in_config(self):
         self.assertIn("policy_definitions", json_dict)
         self.assertIs(len(json_dict["policy_definitions"]), 3)
 
-    @parameterized.expand([
-        ("all_policies", None, 3),
-        ("selected_policies", ["dangerous", "violence"], 2),
-        ("single_policy", ["sexual"], 1),
-    ])
+    @parameterized.expand(
+        [
+            ("all_policies", None, 3),
+            ("selected_policies", ["dangerous", "violence"], 2),
+            ("single_policy", ["sexual"], 1),
+        ]
+    )
     def test_with_default_policies(self, name, policies, expected_batch_size):
         processor = self.get_processor()
 
@@ -144,14 +124,16 @@ def test_with_default_policies(self, name, policies, expected_batch_size):
         self.assertEqual(len(processed_inputs[self.text_input_name]), expected_batch_size)
         self.assertEqual(len(processed_inputs[self.images_input_name]), expected_batch_size)
 
-    @parameterized.expand([
-        ("all_policies", None, 6),
-        ("selected_policies_from_both", ["cbrne", "dangerous", "specialized_advice", "violence"], 4),
-        ("selected_policies_from_custom", ["cbrne", "specialized_advice"], 2),
-        ("selected_policies_from_default", ["dangerous", "violence"], 2),
-        ("single_policy_from_custom", ["ip"], 1),
-        ("single_policy_from_default", ["sexual"], 1),
-    ])
+    @parameterized.expand(
+        [
+            ("all_policies", None, 6),
+            ("selected_policies_from_both", ["cbrne", "dangerous", "specialized_advice", "violence"], 4),
+            ("selected_policies_from_custom", ["cbrne", "specialized_advice"], 2),
+            ("selected_policies_from_default", ["dangerous", "violence"], 2),
+            ("single_policy_from_custom", ["ip"], 1),
+            ("single_policy_from_default", ["sexual"], 1),
+        ]
+    )
     def test_with_custom_policies(self, name, policies, expected_batch_size):
         processor = self.get_processor()
 
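The reformatted decorators above follow the standard parameterized.expand pattern: each tuple becomes its own generated test case, with the first element used as a name suffix and the remaining elements passed as arguments. A minimal self-contained illustration; the names and counts are placeholders, not taken from the test file:

import unittest
from parameterized import parameterized

class PolicyCountExample(unittest.TestCase):
    @parameterized.expand(
        [
            ("all_policies", None, 3),
            ("single_policy", ["sexual"], 1),
        ]
    )
    def test_policy_count(self, name, policies, expected_batch_size):
        # With policies=None, fall back to the three default policy names.
        selected = policies if policies is not None else ["dangerous", "sexual", "violence"]
        self.assertEqual(len(selected), expected_batch_size)

if __name__ == "__main__":
    unittest.main()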