Merged
57 changes: 49 additions & 8 deletions docs/source/use_dataset.mdx
@@ -175,22 +175,63 @@ Most image models expect the image to be in the RGB mode. The Beans images are a
>>> dataset = dataset.cast_column("image", Image(mode="RGB"))
```

**3**. Now, you can apply some transforms to the image. Feel free to take a look at the [various transforms available](https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py) in torchvision and choose one you'd like to experiment with. This example applies a transform that randomly rotates the image:
**3**. Now let's apply data augmentations to your images. 🤗 Datasets works with any augmentation library, and in this example we'll use Albumentations.

### Using Albumentations

[Albumentations](https://albumentations.ai) is a popular image augmentation library that provides a [rich set of transforms](https://albumentations.ai/docs/reference/supported-targets-by-transform/), including spatial-level, pixel-level, and mixing-level transforms. On CPU, where augmentation typically runs in transformers pipelines, the project's [own benchmarks](https://albumentations.ai/docs/benchmarks/image-benchmarks/) report Albumentations to be faster than torchvision.

Install Albumentations:

```bash
pip install albumentations
```

**4**. Create a typical augmentation pipeline with Albumentations:

```py
>>> from torchvision.transforms import RandomRotation
>>> import albumentations as A
>>> import numpy as np
>>> from PIL import Image

>>> transform = A.Compose([
...     A.RandomCrop(height=256, width=256, pad_if_needed=True, p=1),
...     A.HorizontalFlip(p=0.5),
...     A.ColorJitter(p=0.5)
... ])
```

**5**. 🤗 Datasets decodes images as PIL objects, while Albumentations operates on NumPy arrays (the array format OpenCV uses), so you need to convert between the two formats:

>>> rotate = RandomRotation(degrees=(0, 90))
>>> def transforms(examples):
...     examples["pixel_values"] = [rotate(image) for image in examples["image"]]
```py
>>> def albumentations_transforms(examples):
...     # Apply Albumentations transforms
...     transformed_images = []
...     for image in examples["image"]:
...         # Convert PIL to numpy array (OpenCV format)
...         image_np = np.array(image.convert("RGB"))
...
...         # Apply Albumentations transforms
...         transformed_image = transform(image=image_np)["image"]
...
...         # Convert back to PIL Image
...         pil_image = Image.fromarray(transformed_image)
...         transformed_images.append(pil_image)
...
...     examples["pixel_values"] = transformed_images
...     return examples
```
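
The PIL-to-NumPy conversion in the function above can be checked in isolation. The following is a minimal sketch that uses a small synthetic image rather than a Beans sample; the image size and color are arbitrary assumptions, and `PILImage` is just a local alias chosen here to keep the PIL module distinct from other `Image` names. It shows that the array comes out as `(height, width, channels)` while PIL reports `(width, height)`, and that the round trip preserves the pixels.

```python
import numpy as np
from PIL import Image as PILImage

# Hypothetical 4x2 (width x height) solid-red test image, not taken from the Beans dataset
pil_in = PILImage.new("RGB", (4, 2), color=(255, 0, 0))

# PIL -> NumPy: shape is (height, width, channels), dtype uint8, RGB channel order
arr = np.array(pil_in)

# NumPy -> PIL: a uint8 array of shape (H, W, 3) is read back as an RGB image
pil_out = PILImage.fromarray(arr)

print(pil_in.size, arr.shape, arr.dtype, pil_out.mode)
```

Note the swapped axis order: Albumentations parameters such as `height` and `width` refer to the first and second array axes, respectively.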

**4**. Use the [`~Dataset.set_transform`] function to apply the transform on-the-fly. When you index into the image `pixel_values`, the transform is applied, and your image gets rotated.
**6**. Apply the transform using [`~Dataset.set_transform`]:

```py
>>> dataset.set_transform(transforms)
>>> dataset.set_transform(albumentations_transforms)
>>> dataset[0]["pixel_values"]
```
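
As a rough mental model of what applying a transform on-the-fly means, the toy class below stores the function when `set_transform` is called and only runs it when a row is accessed, passing the row as a batch of size one. This is a simplified sketch, not the 🤗 Datasets implementation: `LazyTransformDataset` and `toy_transforms` are hypothetical names, and plain Python lists stand in for images.

```python
import random


class LazyTransformDataset:
    """Toy stand-in for a dataset with an on-the-fly transform (not the real library class)."""

    def __init__(self, columns):
        self.columns = columns  # dict: column name -> list of values
        self.transform = None

    def set_transform(self, transform):
        # Only store the function; nothing is computed or cached here
        self.transform = transform

    def __getitem__(self, index):
        # Wrap the single row as a batch of size 1 (dict of lists)
        batch = {name: [values[index]] for name, values in self.columns.items()}
        if self.transform is not None:
            batch = self.transform(batch)
        # Unwrap the batch back into a single row
        return {name: values[0] for name, values in batch.items()}


calls = []  # records the batch size seen on each transform call


def toy_transforms(examples):
    # Hypothetical stand-in for albumentations_transforms: random "horizontal flip" of a list
    calls.append(len(examples["image"]))
    examples["pixel_values"] = [
        image[::-1] if random.random() < 0.5 else image for image in examples["image"]
    ]
    return examples


toy = LazyTransformDataset({"image": [[1, 2, 3], [4, 5, 6]]})
toy.set_transform(toy_transforms)
calls_before_access = len(calls)  # the transform has not run yet

row = toy[0]        # the transform runs now, on a batch of one row
row_again = toy[0]  # and runs again, so random augmentations are re-drawn
```

Because nothing is cached in this model, each access can yield a differently augmented sample, which is the behavior you want for training-time augmentation.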

**5**. The dataset is now ready for training with your machine learning framework!
**Key points when using Albumentations with 🤗 Datasets:**
- Convert PIL images to numpy arrays before applying transforms
- Albumentations returns a dictionary with the transformed image under the "image" key
- Convert the result back to PIL format after transformation

**7**. The dataset is now ready for training with your machine learning framework!