Commit 2f71b9c

Add image classification processing guide (#4748)
* 📝 first draft
* 🖍 apply reviews
1 parent 696555f commit 2f71b9c

4 files changed: +252 -187 lines changed

docs/source/_toctree.yml

Lines changed: 4 additions & 0 deletions
````diff
@@ -56,6 +56,10 @@
     title: Load image data
   - local: image_process
     title: Process image data
+  - local: image_classification
+    title: Image classification
+  - local: object_detection
+    title: Object detection
   title: "Vision"
 - sections:
   - local: nlp_load
````

docs/source/image_classification.mdx

Lines changed: 81 additions & 0 deletions (new file, rendered below)
# Image classification

Image classification datasets are used to train a model to classify an entire image. These datasets enable a wide variety of applications, such as identifying endangered wildlife species or screening for disease in medical images. This guide will show you how to apply transformations to an image classification dataset.

Before you start, make sure you have up-to-date versions of `albumentations` and `cv2` installed:

```bash
pip install -U albumentations opencv-python
```
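
If you want to confirm which versions were picked up, a quick check from a Python session works (a minimal sketch; both packages expose a standard `__version__` attribute):

```py
>>> import albumentations
>>> import cv2

>>> albumentations.__version__, cv2.__version__  # confirm the upgrade took effect
```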

This guide uses the [Beans](https://huggingface.co/datasets/beans) dataset for identifying the type of bean plant disease based on an image of its leaf.

Load the dataset and take a look at an example:

```py
>>> from datasets import load_dataset

>>> dataset = load_dataset("beans")
>>> dataset["train"][10]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500 at 0x7F8D2F4D7A10>,
 'image_file_path': '/root/.cache/huggingface/datasets/downloads/extracted/b0a21163f78769a2cf11f58dfc767fb458fc7cea5c05dccc0144a2c0f0bc1292/train/angular_leaf_spot/angular_leaf_spot_train.204.jpg',
 'labels': 0}
```

The dataset has three fields:

* `image`: a PIL image object.
* `image_file_path`: the path to the image file.
* `labels`: the label or category of the image.

Next, check out an image:

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/img_clf.png">
</div>
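
To map an integer `labels` value back to a human-readable class name, you can use the dataset's [`~datasets.ClassLabel`] feature (a minimal sketch, assuming the `labels` column is typed as a `ClassLabel`, which it is for Beans):

```py
>>> labels = dataset["train"].features["labels"]    # the ClassLabel feature for this column
>>> labels.int2str(dataset["train"][10]["labels"])  # look up the string name for label 0
'angular_leaf_spot'
```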

Now apply some augmentations with `albumentations`. You'll randomly crop the image, flip it horizontally, and adjust its brightness.

```py
>>> import cv2
>>> import albumentations as A
>>> import numpy as np

>>> transform = A.Compose([
...     A.RandomCrop(width=256, height=256),
...     A.HorizontalFlip(p=0.5),
...     A.RandomBrightnessContrast(p=0.2),
... ])
```
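
Before wiring the pipeline into the dataset, you can sanity-check it on a single image (a minimal sketch; `albumentations` takes a NumPy array and returns a dict whose `"image"` key holds the augmented array):

```py
>>> sample = np.array(dataset["train"][10]["image"])  # PIL image -> (H, W, C) NumPy array
>>> augmented = transform(image=sample)["image"]      # run the Compose pipeline once
>>> augmented.shape                                   # RandomCrop yields a 256x256 image
(256, 256, 3)
```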

Create a function to apply the transformation to the images:

```py
>>> def transforms(examples):
...     examples["pixel_values"] = [
...         transform(image=np.array(image))["image"] for image in examples["image"]
...     ]
...
...     return examples
```

Use the [`~Dataset.set_transform`] function to apply the transformation on-the-fly to batches of the dataset, which consumes less disk space:

```py
>>> dataset.set_transform(transforms)
```
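
If you'd rather not modify the dataset in place, [`~Dataset.with_transform`] applies the same on-the-fly transformation but returns a new dataset object instead (a short sketch of the alternative):

```py
>>> augmented_dataset = dataset.with_transform(transforms)  # `dataset` itself is left unchanged
```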

You can verify the transformation worked by indexing into the `pixel_values` of the first example:

```py
>>> import numpy as np
>>> import matplotlib.pyplot as plt

>>> img = dataset["train"][0]["pixel_values"]
>>> plt.imshow(img)
```

<div class="flex justify-center">
    <img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/img_clf_aug.png">
    <img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/img_clf_aug.png"/>
</div>
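
Because the augmentations are random and re-applied on every access, indexing the same example twice may return different pixels. When you eventually feed the images to a PyTorch model, you'd typically also move the channel axis to the front; a minimal sketch (assuming `torch` is installed):

```py
>>> import torch

>>> batch = dataset["train"][:4]  # transforms run on access, so these are already augmented
>>> pixel_values = torch.stack([
...     torch.from_numpy(img).permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
...     for img in batch["pixel_values"]
... ])
>>> pixel_values.shape
torch.Size([4, 3, 256, 256])
```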

docs/source/image_process.mdx

Lines changed: 8 additions & 187 deletions

````diff
@@ -3,7 +3,7 @@
 This guide shows specific methods for processing image datasets. Learn how to:
 
 - Use [`~Dataset.map`] with image dataset.
-- Apply data augmentations to your dataset with [`~Dataset.set_transform`].
+- Apply data augmentations to a dataset with [`~Dataset.set_transform`].
 
 For a guide on how to process any type of dataset, take a look at the <a class="underline decoration-sky-400 decoration-2 font-semibold" href="./process">general process guide</a>.
````

````diff
@@ -37,21 +37,17 @@ The cache file saves time because you don't have to execute the same transform t
 
 Both parameter values default to 1000, which can be expensive if you are storing images. Lower these values to use less memory when you use [`~Dataset.map`].
 
-## Data augmentation
+## Apply transforms
 
-🤗 Datasets can apply data augmentations from any library or package to your dataset.
-
-### Image Classification
-
-First let's see how you can transform image classification datasets. This guide will use the transforms from [torchvision](https://pytorch.org/vision/stable/transforms.html).
+🤗 Datasets applies data augmentations from any library or package to your dataset. Transforms can be applied on-the-fly on batches of data with [`~Dataset.set_transform`], which consumes less disk space.
 
 <Tip>
 
-Feel free to use other data augmentation libraries like [Albumentations](https://albumentations.ai/docs/), [Kornia](https://kornia.readthedocs.io/en/latest/), and [imgaug](https://imgaug.readthedocs.io/en/latest/).
+The following example uses [torchvision](https://pytorch.org/vision/stable/index.html), but feel free to use other data augmentation libraries like [Albumentations](https://albumentations.ai/docs/), [Kornia](https://kornia.readthedocs.io/en/latest/), and [imgaug](https://imgaug.readthedocs.io/en/latest/).
 
 </Tip>
 
-As an example, try to apply a [`ColorJitter`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ColorJitter) transform to change the color properties of the image randomly:
+For example, if you'd like to change the color properties of an image randomly:
 
 ```py
 >>> from torchvision.transforms import Compose, ColorJitter, ToTensor
````
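
For context on the retained note about `batch_size` and `writer_batch_size` in the hunk above: lowering both in a [`~Dataset.map`] call looks like this (a minimal sketch, not taken from this commit's diff; `transforms` stands in for any batched function):

```py
>>> dataset = dataset.map(
...     transforms,
...     batched=True,
...     batch_size=100,         # examples sent to `transforms` per call (default 1000)
...     writer_batch_size=100,  # examples held in memory before flushing to disk (default 1000)
... )
```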

````diff
@@ -64,191 +60,16 @@ As an example, try to apply a [`ColorJitter`](https://pytorch.org/vision/stable/
 ... )
 ```
 
-Create a function to apply the `ColorJitter` transform to an image:
+Create a function to apply the `ColorJitter` transform:
 
 ```py
 >>> def transforms(examples):
 ...     examples["pixel_values"] = [jitter(image.convert("RGB")) for image in examples["image"]]
 ...     return examples
 ```
 
-Use the [`~Dataset.set_transform`] function to apply the transform on-the-fly which consumes less disk space. This function is useful if you only need to access the examples once:
+Apply the transform with the [`~Dataset.set_transform`] function:
 
 ```py
 >>> dataset.set_transform(transforms)
-```
-
-Now you can take a look at the augmented image by indexing into the `pixel_values`:
-
-```py
->>> import numpy as np
->>> import matplotlib.pyplot as plt
-
->>> img = dataset[0]["pixel_values"]
->>> plt.imshow(img.permute(1, 2, 0))
-```
-
-<div class="flex justify-center">
-    <img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_process_jitter.png">
-    <img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_process_jitter.png"/>
-</div>
-
-### Object Detection
-
-Object detection models identify something in an image, and object detection datasets are used for applications such as autonomous driving and detecting natural hazards like wildfire. This guide will show you how to apply transformations to an object detection dataset following the [tutorial](https://albumentations.ai/docs/examples/example_bboxes/) from [Albumentations](https://albumentations.ai/docs/).
-
-To run these examples, make sure you have up-to-date versions of `albumentations` and `cv2` installed:
-
-```
-pip install -U albumentations opencv-python
-```
-
-In this example, you'll use the [`cppe-5`](https://huggingface.co/datasets/cppe-5) dataset for identifying medical personal protective equipment (PPE) in the context of the COVID-19 pandemic.
-
-Load the dataset and take a look at an example:
-
-```py
->>> from datasets import load_dataset
-
->>> ds = load_dataset("cppe-5")
->>> example = ds['train'][0]
->>> example
-{'height': 663,
- 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=943x663 at 0x7FC3DC756250>,
- 'image_id': 15,
- 'objects': {'area': [3796, 1596, 152768, 81002],
-  'bbox': [[302.0, 109.0, 73.0, 52.0],
-   [810.0, 100.0, 57.0, 28.0],
-   [160.0, 31.0, 248.0, 616.0],
-   [741.0, 68.0, 202.0, 401.0]],
-  'category': [4, 4, 0, 0],
-  'id': [114, 115, 116, 117]},
- 'width': 943}
-```
-
-The dataset has the following fields:
-
-- `image`: PIL.Image.Image object containing the image.
-- `image_id`: The image ID.
-- `height`: The image height.
-- `width`: The image width.
-- `objects`: A dictionary containing bounding box metadata for the objects in the image:
-  - `id`: The annotation id.
-  - `area`: The area of the bounding box.
-  - `bbox`: The object's bounding box (in the [coco](https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/#coco) format).
-  - `category`: The object's category, with possible values including `Coverall (0)`, `Face_Shield (1)`, `Gloves (2)`, `Goggles (3)` and `Mask (4)`.
-
-You can visualize the `bboxes` on the image using some internal torch utilities. To do that, you will need to reference the [`~datasets.ClassLabel`] feature associated with the category IDs so you can look up the string labels:
-
-```py
->>> import torch
->>> from torchvision.ops import box_convert
->>> from torchvision.utils import draw_bounding_boxes
->>> from torchvision.transforms.functional import pil_to_tensor, to_pil_image
-
->>> categories = ds['train'].features['objects'].feature['category']
-
->>> boxes_xywh = torch.tensor(example['objects']['bbox'])
->>> boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
->>> labels = [categories.int2str(x) for x in example['objects']['category']]
->>> to_pil_image(
-...     draw_bounding_boxes(
-...         pil_to_tensor(example['image']),
-...         boxes_xyxy,
-...         colors="red",
-...         labels=labels,
-...     )
-... )
-```
-
-<div class="flex justify-center">
-    <img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example.png">
-</div>
-
-With `albumentations`, you can apply transforms that will affect the image while also updating the `bboxes` accordingly. In this case, the image is resized to (480, 480), flipped horizontally, and brightened.
-
-`albumentations` expects the image to be in BGR format, not RGB, so you'll have to convert the image before applying the transform.
-
-```py
->>> import albumentations as A
->>> import numpy as np
-
->>> transform = A.Compose([
-...     A.Resize(480, 480),
-...     A.HorizontalFlip(p=1.0),
-...     A.RandomBrightnessContrast(p=1.0),
-... ], bbox_params=A.BboxParams(format='coco', label_fields=['category']))
-
->>> # RGB PIL Image -> BGR Numpy array
->>> image = np.flip(np.array(example['image']), -1)
->>> out = transform(
-...     image=image,
-...     bboxes=example['objects']['bbox'],
-...     category=example['objects']['category'],
-... )
-```
-
-Now when you visualize the result, the image should be flipped, but the `bboxes` should still be in the right places.
-
-```py
->>> image = torch.tensor(out['image']).flip(-1).permute(2, 0, 1)
->>> boxes_xywh = torch.stack([torch.tensor(x) for x in out['bboxes']])
->>> boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
->>> labels = [categories.int2str(x) for x in out['category']]
->>> to_pil_image(
-...     draw_bounding_boxes(
-...         image,
-...         boxes_xyxy,
-...         colors='red',
-...         labels=labels
-...     )
-... )
-```
-
-<div class="flex justify-center">
-    <img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example_transformed.png">
-</div>
-
-Create a function to apply the transform to a batch of examples:
-
-```py
->>> def transforms(examples):
-...     images, bboxes, categories = [], [], []
-...     for image, objects in zip(examples['image'], examples['objects']):
-...         image = np.array(image.convert("RGB"))[:, :, ::-1]
-...         out = transform(
-...             image=image,
-...             bboxes=objects['bbox'],
-...             category=objects['category']
-...         )
-...         images.append(torch.tensor(out['image']).flip(-1).permute(2, 0, 1))
-...         bboxes.append(torch.tensor(out['bboxes']))
-...         categories.append(out['category'])
-...     return {'image': images, 'bbox': bboxes, 'category': categories}
-```
-
-Use the [`~Dataset.set_transform`] function to apply the transform on-the-fly which consumes less disk space. The randomness of data augmentation may return a different image if you access the same example twice. It is especially useful when training a model for several epochs.
-
-```py
->>> ds['train'].set_transform(transforms)
-```
-
-You can verify the transform works by visualizing the 10th example:
-
-```py
->>> example = ds['train'][10]
->>> to_pil_image(
-...     draw_bounding_boxes(
-...         example['image'],
-...         box_convert(example['bbox'], 'xywh', 'xyxy'),
-...         colors='red',
-...         labels=[categories.int2str(x) for x in example['category']]
-...     )
-... )
-```
-
-<div class="flex justify-center">
-    <img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example_transformed_2.png">
-</div>
+```
````
