docs/source/en/model_doc/shieldgemma2.md

This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins).
- You can extend ShieldGemma's built-in policies with the `custom_policies` argument to the Processor. Using the same key as one of the built-in policies overwrites that policy with your custom definition.
- ShieldGemma 2 does not support the image cropping capabilities used by Gemma 3.
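
As a minimal sketch of the first point (the policy wording below is invented for illustration), a `custom_policies` dict maps a policy key to its description, and reusing a built-in key such as `dangerous` replaces that policy's definition rather than adding a new one:

```python
# Hypothetical policy descriptions, invented for illustration.
# Reusing the built-in key "dangerous" overwrites that policy's
# definition; a fresh key such as "spam" adds a new policy.
custom_policies = {
    "dangerous": "No weapons or instructions that facilitate harm.",
    "spam": "No unsolicited promotional content.",
}
```

The dict is then passed to the Processor via its `custom_policies` argument, as shown in the custom-policies example in this document.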

### Classification against Built-in Policies

```python
from PIL import Image
import requests
from transformers import AutoProcessor, ShieldGemma2ForImageClassification

model_id = "google/shieldgemma-2-4b-it"
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=[image], return_tensors="pt").to(model.device)

output = model(**inputs)
print(output.probabilities)
```
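
As a rough sketch of reading the result (an assumption, not a documented contract: `output.probabilities` is treated here as one `[yes, no]` score pair per policy checked, and the score values below are fabricated), you could pair score rows with policy names like this:

```python
# Assumed structure: one [yes, no] score pair per policy checked.
# The numbers are fabricated for illustration; in practice they would
# come from output.probabilities in the snippet above.
policies = ["dangerous", "sexually_explicit", "violence"]  # assumed built-in keys
probabilities = [
    [0.03, 0.97],
    [0.01, 0.99],
    [0.90, 0.10],
]

flagged = {}
for policy, (yes_score, no_score) in zip(policies, probabilities):
    # A policy is flagged when the "yes" (violation) score dominates.
    flagged[policy] = yes_score > no_score
    print(f"{policy}: violation={flagged[policy]} (yes={yes_score:.2f})")
```

This loop is only a post-processing sketch; consult the model card for the authoritative output format and the exact set of built-in policy keys.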

### Classification against Custom Policies

```python
from PIL import Image
import requests
from transformers import AutoProcessor, ShieldGemma2ForImageClassification

model_id = "google/shieldgemma-2-4b-it"
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw)

custom_policies = {
    "key_a": "description_a",
    "key_b": "description_b",
}

inputs = processor(
    images=[image],
    custom_policies=custom_policies,
    policies=["dangerous", "key_a", "key_b"],
    return_tensors="pt",
).to(model.device)

output = model(**inputs)
print(output.probabilities)
```