Commit cfa7c7c

📝 Add object detection processing tutorial
1 parent a8fb860 commit cfa7c7c

1 file changed: +165, -1 lines changed

docs/source/image_process.mdx

Lines changed: 165 additions & 1 deletion
@@ -39,7 +39,11 @@ Both parameter values default to 1000, which can be expensive if you are storing

## Data augmentation

🤗 Datasets can apply data augmentations from any library or package to your dataset.

### Image Classification

First let's see how you can transform image classification datasets. This guide will use the transforms from [torchvision](https://pytorch.org/vision/stable/transforms.html).

<Tip>

@@ -88,3 +92,163 @@ Now you can take a look at the augmented image by indexing into the `pixel_value
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_process_jitter.png">
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_process_jitter.png"/>
</div>

### Object Detection

Next, let's see how to apply transformations to object detection datasets. For this, we'll use [Albumentations](https://albumentations.ai/docs/), following their object detection [tutorial](https://albumentations.ai/docs/examples/example_bboxes/).

To run these examples, make sure you have up-to-date versions of `albumentations` and `opencv-python` installed:

```bash
pip install -U albumentations opencv-python
```

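If you're not sure which versions you already have, here's a small sketch of one way to check from Python (assuming both packages import cleanly; `cv2` is the import name of the `opencv-python` package):

```py
>>> import albumentations
>>> import cv2
>>> print(albumentations.__version__, cv2.__version__)
```
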
In this example, we'll use the [`cppe-5`](https://huggingface.co/datasets/cppe-5) dataset, a dataset for identifying medical personal protective equipment (PPE) in the context of the COVID-19 pandemic.

You can load the dataset and take a look at an example:

```py
>>> from datasets import load_dataset

>>> ds = load_dataset("cppe-5")
>>> example = ds['train'][0]
>>> example
{'height': 663,
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=943x663 at 0x7FC3DC756250>,
 'image_id': 15,
 'objects': {'area': [3796, 1596, 152768, 81002],
  'bbox': [[302.0, 109.0, 73.0, 52.0],
   [810.0, 100.0, 57.0, 28.0],
   [160.0, 31.0, 248.0, 616.0],
   [741.0, 68.0, 202.0, 401.0]],
  'category': [4, 4, 0, 0],
  'id': [114, 115, 116, 117]},
 'width': 943}
```

The dataset has the following fields:

- `image`: PIL.Image.Image object containing the image.
- `image_id`: the image ID.
- `height`: the image height.
- `width`: the image width.
- `objects`: a dictionary containing bounding box metadata for the objects in the image:
  - `id`: the annotation ID.
  - `area`: the area of the bounding box.
  - `bbox`: the object's bounding box, in the [coco](https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/#coco) format (see the short sketch after this list).
  - `category`: the object's category, with possible values including `Coverall (0)`, `Face_Shield (1)`, `Gloves (2)`, `Goggles (3)` and `Mask (4)`.

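For reference, a box in the coco format is `[x_min, y_min, width, height]` in absolute pixel coordinates. As a minimal sketch of what that means, here is the first box from the example above converted to corner coordinates by hand; the torchvision helper `box_convert` used below does the same conversion:

```py
>>> x_min, y_min, w, h = example['objects']['bbox'][0]  # [302.0, 109.0, 73.0, 52.0]
>>> x_max, y_max = x_min + w, y_min + h                 # corner ("xyxy") coordinates
>>> (x_min, y_min, x_max, y_max)
(302.0, 109.0, 375.0, 161.0)
```
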

You can visualize the bounding boxes on the image using some internal torch utilities. To do that, you will need to reference the `datasets.ClassLabel` feature associated with the category IDs so you can look up the string labels.

```py
>>> import torch
>>> from torchvision.ops import box_convert
>>> from torchvision.utils import draw_bounding_boxes
>>> from torchvision.transforms.functional import pil_to_tensor, to_pil_image

>>> categories = ds['train'].features['objects'].feature['category']

>>> boxes_xywh = torch.tensor(example['objects']['bbox'])
>>> boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
>>> labels = [categories.int2str(x) for x in example['objects']['category']]
>>> to_pil_image(
...     draw_bounding_boxes(
...         pil_to_tensor(example['image']),
...         boxes_xyxy,
...         colors="red",
...         labels=labels,
...     )
... )
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example.png">
</div>

Using `albumentations`, we can apply transforms that will affect the image while also updating the bounding boxes accordingly. In this case, the image is resized to (480, 480), flipped horizontally, and brightened.

Note that `albumentations` expects the image to be in BGR format, not RGB, so we'll convert the image from RGB to BGR before applying the transform.

```py
>>> import albumentations as A
>>> import numpy as np

>>> transform = A.Compose([
...     A.Resize(480, 480),
...     A.HorizontalFlip(p=1.0),
...     A.RandomBrightnessContrast(p=1.0),
... ], bbox_params=A.BboxParams(format='coco', label_fields=['category']))

>>> # RGB PIL Image -> BGR Numpy array
>>> image = np.flip(np.array(example['image']), -1)
>>> out = transform(
...     image=image,
...     bboxes=example['objects']['bbox'],
...     category=example['objects']['category'],
... )
```

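As an optional sanity check (an addition to the original walkthrough, not part of it), you can confirm that the transform kept all of the boxes and still returns them in the coco format, now in coordinates relative to the 480x480 output:

```py
>>> len(out['bboxes']) == len(example['objects']['bbox'])  # no boxes are dropped by this transform
True
>>> [[round(v, 1) for v in box] for box in out['bboxes']]  # still [x_min, y_min, width, height], scaled to 480x480
```
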
Now when we visualize the result, the image should be flipped but the boxes should still be in the right places.

```py
>>> image = torch.tensor(out['image']).flip(-1).permute(2, 0, 1)
>>> boxes_xywh = torch.stack([torch.tensor(x) for x in out['bboxes']])
>>> boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
>>> labels = [categories.int2str(x) for x in out['category']]
>>> to_pil_image(
...     draw_bounding_boxes(
...         image,
...         boxes_xyxy,
...         colors='red',
...         labels=labels
...     )
... )
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example_transformed.png">
</div>

Create a function to apply the transform to a batch of examples:

```py
>>> def transforms(examples):
...     images, bboxes, categories = [], [], []
...     for image, objects in zip(examples['image'], examples['objects']):
...         # RGB PIL Image -> BGR Numpy array, as above
...         image = np.array(image.convert("RGB"))[:, :, ::-1]
...         out = transform(
...             image=image,
...             bboxes=objects['bbox'],
...             category=objects['category']
...         )
...         # convert back to RGB, channels-first tensors
...         images.append(torch.tensor(out['image']).flip(-1).permute(2, 0, 1))
...         bboxes.append(torch.tensor(out['bboxes']))
...         categories.append(out['category'])
...     return {'image': images, 'bbox': bboxes, 'category': categories}
```

Use the [`~Dataset.set_transform`] function to apply the transform on the fly, which consumes less disk space. This function is useful if you only need to access the examples once:

```py
>>> ds['train'].set_transform(transforms)
```

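If you would rather leave the original dataset object untouched, 🤗 Datasets also provides [`~Dataset.with_transform`], which applies the same on-the-fly behavior but returns a new dataset instead of modifying `ds['train']` in place. A small sketch (check that the method is available in your installed version):

```py
>>> transformed_ds = ds['train'].with_transform(transforms)
```
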
Verify the transform is working by visualizing the example at index 10:

```py
>>> example = ds['train'][10]
>>> to_pil_image(
...     draw_bounding_boxes(
...         example['image'],
...         box_convert(example['bbox'], 'xywh', 'xyxy'),
...         colors='red',
...         labels=[categories.int2str(x) for x in example['category']]
...     )
... )
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example_transformed_2.png">
</div>
