Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@
title: Load image data
- local: image_process
title: Process image data
- local: image_classification
title: Image classification
- local: object_detection
title: Object detection
title: "Vision"
- sections:
- local: nlp_load
Expand Down
81 changes: 81 additions & 0 deletions docs/source/image_classification.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Image classification

Image classification datasets are used to train a model to classify an entire image. There are a wide variety of applications enabled by these datasets such as identifying endangered wildlife species or screening for disease in medical images. This guide will show you how to apply transformations to an image classification dataset.

Before you start, make sure you have up-to-date versions of `albumentations` and `cv2` installed:

```bash
pip install -U albumentations opencv-python
```

This guide uses the [Beans](https://huggingface.co/datasets/beans) dataset for identifying the type of disease of a bean plant based on an image of its leaf.

Load the dataset and take a look at an example:

```py
>>> from datasets import load_dataset

>>> dataset = load_dataset("beans")
>>> dataset["train"][10]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500 at 0x7F8D2F4D7A10>,
'image_file_path': '/root/.cache/huggingface/datasets/downloads/extracted/b0a21163f78769a2cf11f58dfc767fb458fc7cea5c05dccc0144a2c0f0bc1292/train/angular_leaf_spot/angular_leaf_spot_train.204.jpg',
'labels': 0}
```

The dataset has three fields:

* `image`: a PIL image object.
* `image_file_path`: the path to the image file.
* `labels`: the label or category of the image.

Next, check out an image:

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/img_clf.png">
</div>

Now apply some augmentations with `albumentations`. You'll randomly crop the image, flip it horizontally, and adjust its brightness.

```py
>>> import cv2
>>> import albumentations as A
>>> import numpy as np

>>> transform = A.Compose([
... A.RandomCrop(width=256, height=256),
... A.HorizontalFlip(p=0.5),
... A.RandomBrightnessContrast(p=0.2),
... ])
```

Create a function to apply the transformation to the images:

```py
>>> def transforms(examples):
... examples["pixel_values"] = [
... transform(image=np.array(image))["image"] for image in examples["image"]
... ]
...
... return examples
```

Use the [`~Dataset.set_transform`] function to apply the transformation on-the-fly to batches of the dataset to consume less disk space:

```py
>>> dataset.set_transform(transforms)
```

You can verify the transformation worked by indexing into the `pixel_values` of the first example:

```py
>>> import numpy as np
>>> import matplotlib.pyplot as plt

>>> img = dataset["train"][0]["pixel_values"]
>>> plt.imshow(img)
```

<div class="flex justify-center">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/img_clf_aug.png">
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/img_clf_aug.png"/>
</div>
201 changes: 8 additions & 193 deletions docs/source/image_process.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This guide shows specific methods for processing image datasets. Learn how to:

- Use [`~Dataset.map`] with image dataset.
- Apply data augmentations to your dataset with [`~Dataset.set_transform`].
- Apply data augmentations to a dataset with [`~Dataset.set_transform`].

For a guide on how to process any type of dataset, take a look at the <a class="underline decoration-sky-400 decoration-2 font-semibold" href="./process">general process guide</a>.

Expand Down Expand Up @@ -37,21 +37,11 @@ The cache file saves time because you don't have to execute the same transform t

Both parameter values default to 1000, which can be expensive if you are storing images. Lower these values to use less memory when you use [`~Dataset.map`].

## Data augmentation
## Apply transforms

🤗 Datasets can apply data augmentations from any library or package to your dataset.
🤗 Datasets can apply data augmentations from any library or package to your dataset. Transforms can be applied on-the-fly on batches of data with [`~Dataset.set_transform`], and consumes less disk space.

### Image Classification

First let's see how you can transform image classification datasets. This guide will use the transforms from [torchvision](https://pytorch.org/vision/stable/transforms.html).

<Tip>

Feel free to use other data augmentation libraries like [Albumentations](https://albumentations.ai/docs/), [Kornia](https://kornia.readthedocs.io/en/latest/), and [imgaug](https://imgaug.readthedocs.io/en/latest/).

</Tip>

As an example, try to apply a [`ColorJitter`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ColorJitter) transform to change the color properties of the image randomly:
For example, if you'd like to change the color properties of an image randomly:

```py
>>> from torchvision.transforms import Compose, ColorJitter, ToTensor
Expand All @@ -64,191 +54,16 @@ As an example, try to apply a [`ColorJitter`](https://pytorch.org/vision/stable/
... )
```

Create a function to apply the `ColorJitter` transform to an image:
Create a function to apply the `ColorJitter` transform:

```py
>>> def transforms(examples):
... examples["pixel_values"] = [jitter(image.convert("RGB")) for image in examples["image"]]
... return examples
```

Use the [`~Dataset.set_transform`] function to apply the transform on-the-fly which consumes less disk space. This function is useful if you only need to access the examples once:

```py
>>> dataset.set_transform(transforms)
```

Now you can take a look at the augmented image by indexing into the `pixel_values`:

```py
>>> import numpy as np
>>> import matplotlib.pyplot as plt

>>> img = dataset[0]["pixel_values"]
>>> plt.imshow(img.permute(1, 2, 0))
```

<div class="flex justify-center">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_process_jitter.png">
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_process_jitter.png"/>
</div>

### Object Detection

Object detection models identify something in an image, and object detection datasets are used for applications such as autonomous driving and detecting natural hazards like wildfire. This guide will show you how to apply transformations to an object detection dataset following the [tutorial](https://albumentations.ai/docs/examples/example_bboxes/) from [Albumentations](https://albumentations.ai/docs/).

To run these examples, make sure you have up-to-date versions of `albumentations` and `cv2` installed:

```
pip install -U albumentations opencv-python
```

In this example, you'll use the [`cppe-5`](https://huggingface.co/datasets/cppe-5) dataset for identifying medical personal protective equipment (PPE) in the context of the COVID-19 pandemic.

Load the dataset and take a look at an example:
Apply the transform with the [`~Dataset.set_transform`] function:

```py
from datasets import load_dataset

>>> ds = load_dataset("cppe-5")
>>> example = ds['train'][0]
>>> example
{'height': 663,
'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=943x663 at 0x7FC3DC756250>,
'image_id': 15,
'objects': {'area': [3796, 1596, 152768, 81002],
'bbox': [[302.0, 109.0, 73.0, 52.0],
[810.0, 100.0, 57.0, 28.0],
[160.0, 31.0, 248.0, 616.0],
[741.0, 68.0, 202.0, 401.0]],
'category': [4, 4, 0, 0],
'id': [114, 115, 116, 117]},
'width': 943}
```

The dataset has the following fields:

- `image`: PIL.Image.Image object containing the image.
- `image_id`: The image ID.
- `height`: The image height.
- `width`: The image width.
- `objects`: A dictionary containing bounding box metadata for the objects in the image:
- `id`: The annotation id.
- `area`: The area of the bounding box.
- `bbox`: The object's bounding box (in the [coco](https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/#coco) format).
- `category`: The object's category, with possible values including `Coverall (0)`, `Face_Shield (1)`, `Gloves (2)`, `Goggles (3)` and `Mask (4)`.

You can visualize the `bboxes` on the image using some internal torch utilities. To do that, you will need to reference the [`~datasets.ClassLabel`] feature associated with the category IDs so you can look up the string labels:


```py
>>> import torch
>>> from torchvision.ops import box_convert
>>> from torchvision.utils import draw_bounding_boxes
>>> from torchvision.transforms.functional import pil_to_tensor, to_pil_image

>>> categories = ds['train'].features['objects'].feature['category']

>>> boxes_xywh = torch.tensor(example['objects']['bbox'])
>>> boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
>>> labels = [categories.int2str(x) for x in example['objects']['category']]
>>> to_pil_image(
... draw_bounding_boxes(
... pil_to_tensor(example['image']),
... boxes_xyxy,
... colors="red",
... labels=labels,
... )
... )
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example.png">
</div>


With `albumentations`, you can apply transforms that will affect the image while also updating the `bboxes` accordingly. In this case, the image is resized to (480, 480), flipped horizontally, and brightened.

`albumentations` expects the image to be in BGR format, not RGB, so you'll have to convert the image before applying the transform.

```py
>>> import albumentations as A
>>> import numpy as np

>>> transform = A.Compose([
... A.Resize(480, 480),
... A.HorizontalFlip(p=1.0),
... A.RandomBrightnessContrast(p=1.0),
... ], bbox_params=A.BboxParams(format='coco', label_fields=['category']))

>>> # RGB PIL Image -> BGR Numpy array
>>> image = np.flip(np.array(example['image']), -1)
>>> out = transform(
... image=image,
... bboxes=example['objects']['bbox'],
... category=example['objects']['category'],
... )
```

Now when you visualize the result, the image should be flipped, but the `bboxes` should still be in the right places.

```py
>>> image = torch.tensor(out['image']).flip(-1).permute(2, 0, 1)
>>> boxes_xywh = torch.stack([torch.tensor(x) for x in out['bboxes']])
>>> boxes_xyxy = box_convert(boxes_xywh, 'xywh', 'xyxy')
>>> labels = [categories.int2str(x) for x in out['category']]
>>> to_pil_image(
... draw_bounding_boxes(
... image,
... boxes_xyxy,
... colors='red',
... labels=labels
... )
... )
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example_transformed.png">
</div>

Create a function to apply the transform to a batch of examples:

```py
>>> def transforms(examples):
... images, bboxes, categories = [], [], []
... for image, objects in zip(examples['image'], examples['objects']):
... image = np.array(image.convert("RGB"))[:, :, ::-1]
... out = transform(
... image=image,
... bboxes=objects['bbox'],
... category=objects['category']
... )
... images.append(torch.tensor(out['image']).flip(-1).permute(2, 0, 1))
... bboxes.append(torch.tensor(out['bboxes']))
... categories.append(out['category'])
... return {'image': images, 'bbox': bboxes, 'category': categories}
```

Use the [`~Dataset.set_transform`] function to apply the transform on-the-fly which consumes less disk space. The randomness of data augmentation may return a different image if you access the same example twice. It is especially useful when training a model for several epochs.

```py
>>> ds['train'].set_transform(transforms)
```

You can verify the transform works by visualizing the 10th example:

```py
>>> example = ds['train'][10]
>>> to_pil_image(
... draw_bounding_boxes(
... example['image'],
... box_convert(example['bbox'], 'xywh', 'xyxy'),
... colors='red',
... labels=[categories.int2str(x) for x in example['category']]
... )
... )
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/visualize_detection_example_transformed_2.png">
</div>
dataset.set_transform(transforms)
```
Loading