Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 165 additions & 1 deletion docs/source/image_process.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ Both parameter values default to 1000, which can be expensive if you are storing

## Data augmentation

🤗 Datasets can apply data augmentations from any library or package to your dataset. This guide will use the transforms from [torchvision](https://pytorch.org/vision/stable/transforms.html).
🤗 Datasets can apply data augmentations from any library or package to your dataset.

### Image Classification

First let's see how you can transform image classification datasets. This guide will use the transforms from [torchvision](https://pytorch.org/vision/stable/transforms.html).

<Tip>

Expand Down Expand Up @@ -88,3 +92,163 @@ Now you can take a look at the augmented image by indexing into the `pixel_value
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_process_jitter.png">
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_process_jitter.png"/>
</div>

### Object Detection

Next, let's see how to apply transformations to object detection datasets. For this we'll use [Albumentations](https://albumentations.ai/docs/), following their object detection [tutorial](https://albumentations.ai/docs/examples/example_bboxes/).

To run these examples, make sure you have up to date versions of `albumentations` and `cv2` installed:

```
pip install -U albumentations opencv-python
```

In this example, the [`cppe-5`](https://huggingface.co/datasets/cppe-5) dataset is used, which is a dataset for identifying medical personal protective equipments (PPEs) in the context of the COVID-19 pandemic.

You can load the dataset and take a look at an example:

```py
from datasets import load_dataset

>>> ds = load_dataset("cppe-5")
>>> example = ds['train'][0]
>>> example
{'height': 663,
'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=943x663 at 0x7FC3DC756250>,
'image_id': 15,
'objects': {'area': [3796, 1596, 152768, 81002],
'bbox': [[302.0, 109.0, 73.0, 52.0],
[810.0, 100.0, 57.0, 28.0],
[160.0, 31.0, 248.0, 616.0],
[741.0, 68.0, 202.0, 401.0]],
'category': [4, 4, 0, 0],
'id': [114, 115, 116, 117]},
'width': 943}
```

The dataset has the following fields:

- `image`: PIL.Image.Image object containing the image.
- `image_id`: The image ID.
- `height`: The image height.
- `width`: The image width.
- `objects`: a dictionary containing bounding box metadata for the objects present on the image
- `id`: the annotation id
- `area`: the area of the bounding box
- `bbox`: the object's bounding box (in the [coco](https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/#coco) format)
- `category`: the object's category, with possible values including `Coverall (0)`, `Face_Shield (1)`, `Gloves (2)`, `Goggles (3)` and `Mask (4)`

You can visualize the bboxes on the image using some internal torch utilities. But, to do that, you will need to reference the `datasets.ClassLabel` feature associated with the category IDs so you can look up the string labels.


```py
>>> import torch
>>> from torchvision.ops import box_convert
>>> from torchvision.utils import draw_bounding_boxes
>>> from torchvision.transforms.functional import pil_to_tensor, to_pil_image

>>> categories = ds['train'].features['objects'].feature['category']

>>> boxes_xywh = torch.tensor(example['objects']['bbox'])
>>> boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
>>> labels = [categories.int2str(x) for x in example['objects']['category']]
>>> to_pil_image(
... draw_bounding_boxes(
... pil_to_tensor(example['image']),
... boxes_xyxy,
... colors="red",
... labels=labels,
... )
... )
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example.png">
</div>


Using `albumentations`, we can apply transforms that will effect the Image while also updating the bounding boxes accordingly. In this case, the image is resized to (480, 480), flipped horizontally, and brightened.

Note that `albumentations` is expecting the image to be in BGR format, not RGB, so we'll have to convert our image first before applying the transform.

```py
>>> import albumentations as A
>>> import numpy as np

>>> transform = A.Compose([
... A.Resize(480, 480),
... A.HorizontalFlip(p=1.0),
... A.RandomBrightnessContrast(p=1.0),
... ], bbox_params=A.BboxParams(format='coco', label_fields=['category']))

>>> # RGB PIL Image -> BGR Numpy array
>>> image = np.flip(np.array(example['image']), -1)
>>> out = transform(
... image=image,
... bboxes=example['objects']['bbox'],
... category=example['objects']['category'],
... )
```

Now when we visualize the result, the image should be flipped but the boxes should still be in the right places.

```py
>>> image = torch.tensor(out['image']).flip(-1).permute(2, 0, 1)
>>> boxes_xywh = torch.stack([torch.tensor(x) for x in out['bboxes']])
>>> boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
>>> labels = [categories.int2str(x) for x in out['category']]
>>> to_pil_image(
... draw_bounding_boxes(
... image,
... boxes_xyxy,
... colors='red',
... labels=labels
... )
... )
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example_transformed.png">
</div>

Create a function to apply the transform to a batch of examples:

```py
>>> def transforms(examples):
... images, bboxes, categories = [], [], []
... for image, objects in zip(examples['image'], examples['objects']):
... image = np.array(image.convert("RGB"))[:, :, ::-1]
... out = transform(
... image=image,
... bboxes=objects['bbox'],
... category=objects['category']
... )
... images.append(torch.tensor(out['image']).flip(-1).permute(2, 0, 1))
... bboxes.append(torch.tensor(out['bboxes']))
... categories.append(out['category'])
... return {'image': images, 'bbox': bboxes, 'category': categories}
```

Use the [`~Dataset.set_transform`] function to apply the transform on-the-fly which consumes less disk space. Note that the randomness of data augmentation may return a different image if you access the same example twice. It is especially useful when training a model with several epochs.

```py
>>> ds['train'].set_transform(transforms)
```

Verify the transform is working by visualizing the 10th example:

```py
>>> example = ds['train'][10]
>>> to_pil_image(
... draw_bounding_boxes(
... example['image'],
... box_convert(example['bbox'], 'xywh', 'xyxy'),
... colors='red',
... labels=[categories.int2str(x) for x in example['category']]
... )
... )
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example_transformed_2.png">
</div>