
Commit 9d6c064

Fix code snippet for Grounding DINO (#32229)
Fix code snippet for grounding-dino
1 parent 3a83ec4 commit 9d6c064

File tree

1 file changed: 34 additions, 27 deletions


docs/source/en/model_doc/grounding-dino.md

Lines changed: 34 additions & 27 deletions
````diff
@@ -41,33 +41,40 @@ The original code can be found [here](https://github.com/IDEA-Research/Grounding
 Here's how to use the model for zero-shot object detection:
 
 ```python
-import requests
-
-import torch
-from PIL import Image
-from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection,
-
-model_id = "IDEA-Research/grounding-dino-tiny"
-
-processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
-
-image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-image = Image.open(requests.get(image_url, stream=True).raw)
-# Check for cats and remote controls
-text = "a cat. a remote control."
-
-inputs = processor(images=image, text=text, return_tensors="pt").to(device)
-with torch.no_grad():
-    outputs = model(**inputs)
-
-results = processor.post_process_grounded_object_detection(
-    outputs,
-    inputs.input_ids,
-    box_threshold=0.4,
-    text_threshold=0.3,
-    target_sizes=[image.size[::-1]]
-)
+>>> import requests
+
+>>> import torch
+>>> from PIL import Image
+>>> from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+
+>>> model_id = "IDEA-Research/grounding-dino-tiny"
+>>> device = "cuda"
+
+>>> processor = AutoProcessor.from_pretrained(model_id)
+>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
+
+>>> image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+>>> image = Image.open(requests.get(image_url, stream=True).raw)
+>>> # Check for cats and remote controls
+>>> text = "a cat. a remote control."
+
+>>> inputs = processor(images=image, text=text, return_tensors="pt").to(device)
+>>> with torch.no_grad():
+...     outputs = model(**inputs)
+
+>>> results = processor.post_process_grounded_object_detection(
+...     outputs,
+...     inputs.input_ids,
+...     box_threshold=0.4,
+...     text_threshold=0.3,
+...     target_sizes=[image.size[::-1]]
+... )
+>>> print(results)
+[{'boxes': tensor([[344.6959,  23.1090, 637.1833, 374.2751],
+        [ 12.2666,  51.9145, 316.8582, 472.4392],
+        [ 38.5742,  70.0015, 176.7838, 118.1806]], device='cuda:0'),
+ 'labels': ['a cat', 'a cat', 'a remote control'],
+ 'scores': tensor([0.4785, 0.4381, 0.4776], device='cuda:0')}]
 ```
````
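Two conventions in the fixed snippet are easy to trip over: `image.size[::-1]` and the score thresholds. Here is a minimal sketch illustrating both, with made-up scores and labels standing in for real model output (no model download required):

```python
# 1) PIL's Image.size is (width, height), but the post-processing step expects
#    target sizes as (height, width) -- hence `image.size[::-1]` in the snippet.
pil_size = (640, 480)            # (width, height), as returned by Image.size
target_size = pil_size[::-1]     # reversed to (height, width)
print(target_size)               # (480, 640)

# 2) box_threshold keeps only detections whose confidence meets the cutoff.
#    Scores and labels below are hypothetical stand-ins, not model output.
scores = [0.4785, 0.4381, 0.4776, 0.12]
labels = ["a cat", "a cat", "a remote control", "clutter"]
box_threshold = 0.4
kept = [(label, s) for label, s in zip(labels, scores) if s >= box_threshold]
print(kept)                      # the 0.12 detection is filtered out
```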
## Grounded SAM
