Skip to content

Commit 1a6a5e6

Browse files
naterawlhoestqstevhliu
authored
Add object detection processing tutorial (#4710)
* 📝 Add object detection processing tutorial * Update docs/source/image_process.mdx Co-authored-by: Quentin Lhoest <[email protected]> * Apply suggestions from code review Co-authored-by: Steven Liu <[email protected]> * 📝 small formatting updates Co-authored-by: Quentin Lhoest <[email protected]> Co-authored-by: Steven Liu <[email protected]>
1 parent c15b391 commit 1a6a5e6

File tree

1 file changed

+165
-1
lines changed

1 file changed

+165
-1
lines changed

docs/source/image_process.mdx

Lines changed: 165 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ Both parameter values default to 1000, which can be expensive if you are storing
3939

4040
## Data augmentation
4141

42-
🤗 Datasets can apply data augmentations from any library or package to your dataset. This guide will use the transforms from [torchvision](https://pytorch.org/vision/stable/transforms.html).
42+
🤗 Datasets can apply data augmentations from any library or package to your dataset.
43+
44+
### Image Classification
45+
46+
First let's see how you can transform image classification datasets. This guide will use the transforms from [torchvision](https://pytorch.org/vision/stable/transforms.html).
4347

4448
<Tip>
4549

@@ -88,3 +92,163 @@ Now you can take a look at the augmented image by indexing into the `pixel_value
8892
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_process_jitter.png">
8993
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_process_jitter.png"/>
9094
</div>
95+
96+
### Object Detection
97+
98+
Object detection models identify something in an image, and object detection datasets are used for applications such as autonomous driving and detecting natural hazards like wildfire. This guide will show you how to apply transformations to an object detection dataset following the [tutorial](https://albumentations.ai/docs/examples/example_bboxes/) from [Albumentations](https://albumentations.ai/docs/).
99+
100+
To run these examples, make sure you have up-to-date versions of `albumentations` and `cv2` installed:
101+
102+
```
103+
pip install -U albumentations opencv-python
104+
```
105+
106+
In this example, you'll use the [`cppe-5`](https://huggingface.co/datasets/cppe-5) dataset for identifying medical personal protective equipment (PPE) in the context of the COVID-19 pandemic.
107+
108+
Load the dataset and take a look at an example:
109+
110+
```py
111+
from datasets import load_dataset
112+
113+
>>> ds = load_dataset("cppe-5")
114+
>>> example = ds['train'][0]
115+
>>> example
116+
{'height': 663,
117+
'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=943x663 at 0x7FC3DC756250>,
118+
'image_id': 15,
119+
'objects': {'area': [3796, 1596, 152768, 81002],
120+
'bbox': [[302.0, 109.0, 73.0, 52.0],
121+
[810.0, 100.0, 57.0, 28.0],
122+
[160.0, 31.0, 248.0, 616.0],
123+
[741.0, 68.0, 202.0, 401.0]],
124+
'category': [4, 4, 0, 0],
125+
'id': [114, 115, 116, 117]},
126+
'width': 943}
127+
```
128+
129+
The dataset has the following fields:
130+
131+
- `image`: PIL.Image.Image object containing the image.
132+
- `image_id`: The image ID.
133+
- `height`: The image height.
134+
- `width`: The image width.
135+
- `objects`: A dictionary containing bounding box metadata for the objects in the image:
136+
- `id`: The annotation id.
137+
- `area`: The area of the bounding box.
138+
- `bbox`: The object's bounding box (in the [coco](https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/#coco) format).
139+
- `category`: The object's category, with possible values including `Coverall (0)`, `Face_Shield (1)`, `Gloves (2)`, `Goggles (3)` and `Mask (4)`.
140+
141+
You can visualize the `bboxes` on the image using some internal torch utilities. To do that, you will need to reference the [`~datasets.ClassLabel`] feature associated with the category IDs so you can look up the string labels:
142+
143+
144+
```py
145+
>>> import torch
146+
>>> from torchvision.ops import box_convert
147+
>>> from torchvision.utils import draw_bounding_boxes
148+
>>> from torchvision.transforms.functional import pil_to_tensor, to_pil_image
149+
150+
>>> categories = ds['train'].features['objects'].feature['category']
151+
152+
>>> boxes_xywh = torch.tensor(example['objects']['bbox'])
153+
>>> boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
154+
>>> labels = [categories.int2str(x) for x in example['objects']['category']]
155+
>>> to_pil_image(
156+
... draw_bounding_boxes(
157+
... pil_to_tensor(example['image']),
158+
... boxes_xyxy,
159+
... colors="red",
160+
... labels=labels,
161+
... )
162+
... )
163+
```
164+
165+
<div class="flex justify-center">
166+
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example.png">
167+
</div>
168+
169+
170+
With `albumentations`, you can apply transforms that will affect the image while also updating the `bboxes` accordingly. In this case, the image is resized to (480, 480), flipped horizontally, and brightened.
171+
172+
`albumentations` expects the image to be in BGR format, not RGB, so you'll have to convert the image before applying the transform.
173+
174+
```py
175+
>>> import albumentations as A
176+
>>> import numpy as np
177+
178+
>>> transform = A.Compose([
179+
... A.Resize(480, 480),
180+
... A.HorizontalFlip(p=1.0),
181+
... A.RandomBrightnessContrast(p=1.0),
182+
... ], bbox_params=A.BboxParams(format='coco', label_fields=['category']))
183+
184+
>>> # RGB PIL Image -> BGR Numpy array
185+
>>> image = np.flip(np.array(example['image']), -1)
186+
>>> out = transform(
187+
... image=image,
188+
... bboxes=example['objects']['bbox'],
189+
... category=example['objects']['category'],
190+
... )
191+
```
192+
193+
Now when you visualize the result, the image should be flipped, but the `bboxes` should still be in the right places.
194+
195+
```py
196+
>>> image = torch.tensor(out['image']).flip(-1).permute(2, 0, 1)
197+
>>> boxes_xywh = torch.stack([torch.tensor(x) for x in out['bboxes']])
198+
>>> boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
199+
>>> labels = [categories.int2str(x) for x in out['category']]
200+
>>> to_pil_image(
201+
... draw_bounding_boxes(
202+
... image,
203+
... boxes_xyxy,
204+
... colors='red',
205+
... labels=labels
206+
... )
207+
... )
208+
```
209+
210+
<div class="flex justify-center">
211+
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example_transformed.png">
212+
</div>
213+
214+
Create a function to apply the transform to a batch of examples:
215+
216+
```py
217+
>>> def transforms(examples):
218+
... images, bboxes, categories = [], [], []
219+
... for image, objects in zip(examples['image'], examples['objects']):
220+
... image = np.array(image.convert("RGB"))[:, :, ::-1]
221+
... out = transform(
222+
... image=image,
223+
... bboxes=objects['bbox'],
224+
... category=objects['category']
225+
... )
226+
... images.append(torch.tensor(out['image']).flip(-1).permute(2, 0, 1))
227+
... bboxes.append(torch.tensor(out['bboxes']))
228+
... categories.append(out['category'])
229+
... return {'image': images, 'bbox': bboxes, 'category': categories}
230+
```
231+
232+
Use the [`~Dataset.set_transform`] function to apply the transform on-the-fly which consumes less disk space. The randomness of data augmentation may return a different image if you access the same example twice. It is especially useful when training a model for several epochs.
233+
234+
```py
235+
>>> ds['train'].set_transform(transforms)
236+
```
237+
238+
You can verify the transform works by visualizing the 10th example:
239+
240+
```py
241+
>>> example = ds['train'][10]
242+
>>> to_pil_image(
243+
... draw_bounding_boxes(
244+
... example['image'],
245+
... box_convert(example['bbox'], 'xywh', 'xyxy'),
246+
... colors='red',
247+
... labels=[categories.int2str(x) for x in example['category']]
248+
... )
249+
... )
250+
```
251+
252+
<div class="flex justify-center">
253+
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example_transformed_2.png">
254+
</div>

0 commit comments

Comments
 (0)