40 changes: 40 additions & 0 deletions docs/source/image_dataset.mdx
@@ -148,6 +148,46 @@ Upload your dataset with [`~datasets.DatasetDict.push_to_hub`]:
>>> dataset.push_to_hub("stevhliu/my-image-captioning-dataset")
```

## WebDataset

The [Webdataset](https://github.com/webdataset/webdataset) format is based on TAR archives and is suitable for big image datasets.
You can group your images into TAR archives (e.g. 1GB of images per TAR archive) and have thousands of TAR archives:

```
folder/train/00000.tar
folder/train/00001.tar
folder/train/00002.tar
...
```

In the archives, each example is made of files sharing the same prefix:

```
e39871fd9fd74f55.jpg
e39871fd9fd74f55.json
f18b91585c4d3f3e.jpg
f18b91585c4d3f3e.json
ede6e66b2fb59aab.jpg
ede6e66b2fb59aab.json
ed600d57fcee4f94.jpg
ed600d57fcee4f94.json
...
```

You can store your images' labels, captions, or bounding boxes in JSON or text files, for example.
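
As a rough sketch of how such an archive could be produced, the snippet below writes one shard with Python's standard `tarfile` module. The helper name `write_shard`, the sample metadata, and the placeholder image bytes are illustrative assumptions, not part of the `webdataset` or `datasets` APIs.

```python
import io
import json
import tarfile


def write_shard(fileobj, examples):
    """Write (key, image_bytes, metadata) triples as <key>.jpg / <key>.json pairs."""
    with tarfile.open(fileobj=fileobj, mode="w") as tar:
        for key, image_bytes, metadata in examples:
            payloads = {
                "jpg": image_bytes,
                "json": json.dumps(metadata).encode("utf-8"),
            }
            # Files belonging to one example share a prefix and are stored back to back.
            for suffix, payload in payloads.items():
                member = tarfile.TarInfo(name=f"{key}.{suffix}")
                member.size = len(payload)
                tar.addfile(member, io.BytesIO(payload))


# Placeholder bytes stand in for real JPEG data in this sketch.
examples = [
    ("e39871fd9fd74f55", b"<jpeg bytes>", {"bbox": [[302.0, 109.0, 73.0, 52.0]], "categories": [0]}),
    ("f18b91585c4d3f3e", b"<jpeg bytes>", {"bbox": [], "categories": []}),
]
buffer = io.BytesIO()
write_shard(buffer, examples)

# Read the archive back to check the member layout.
buffer.seek(0)
with tarfile.open(fileobj=buffer, mode="r") as tar:
    shard_names = tar.getnames()
```

In practice you would pass a real file opened in binary write mode (for example a path such as `folder/train/00000.tar`) instead of the in-memory buffer.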

For more details on the Webdataset format and the Python library, please check the [Webdataset documentation](https://webdataset.github.io/webdataset).

If you have `webdataset` installed, load your Webdataset and it will create one column per file suffix (here "jpg" and "json"):

```python
>>> from datasets import load_dataset

>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", split="train")
>>> dataset[0]["json"]
{"bbox": [[302.0, 109.0, 73.0, 52.0]], "categories": [0]}
```
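
To illustrate what "one column per file suffix" means, here is a minimal sketch that groups archive member names by prefix using only the standard library. It is not how `datasets` or `webdataset` implement loading: `group_by_prefix` is a hypothetical helper, and splitting at the first dot of the file name is an assumption about the key convention.

```python
import posixpath


def group_by_prefix(member_names):
    """Group consecutive file names sharing a prefix into one example each."""
    examples = []
    current_key, current = None, None
    for name in member_names:
        directory, base = posixpath.split(name)
        # Assumption: the example key is everything before the first dot of the file name.
        stem, _, suffix = base.partition(".")
        key = posixpath.join(directory, stem)
        if key != current_key:
            current_key = key
            current = {"__key__": key}
            examples.append(current)
        current[suffix] = name
    return examples


names = [
    "e39871fd9fd74f55.jpg",
    "e39871fd9fd74f55.json",
    "f18b91585c4d3f3e.jpg",
    "f18b91585c4d3f3e.json",
]
rows = group_by_prefix(names)
# Every suffix seen across the examples becomes a column.
columns = sorted(set().union(*rows) - {"__key__"})
```

Each entry of `rows` corresponds to one example, and `columns` lists the suffixes that would become columns.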

## Loading script

Write a dataset loading script to share a dataset. It defines a dataset's splits and configurations, and handles downloading and generating a dataset. The script is located in the same folder or repository as the dataset and should have the same name.
17 changes: 15 additions & 2 deletions docs/source/image_load.mdx
@@ -1,14 +1,14 @@
# Load image data

-Image datasets are loaded from the `image` column, which contains a PIL object.
+Image datasets have [`Image`] type columns, which contain PIL objects.

<Tip>

To work with image datasets, you need to have the `vision` dependency installed. Check out the [installation](./installation#vision) guide to learn how to install it.

</Tip>

-When you load an image dataset and call the `image` column, the [`Image`] feature automatically decodes the PIL object into an image:
+When you load an image dataset and call the image column, the images are decoded as PIL Images:

```py
>>> from datasets import load_dataset, Image
@@ -93,3 +93,16 @@ To ignore the information in the metadata file, set `drop_labels=False` in [`loa
For more information about creating your own `ImageFolder` dataset, take a look at the [Create an image dataset](./image_dataset) guide.

</Tip>

## WebDataset

The [Webdataset](https://github.com/webdataset/webdataset) format is based on a folder of TAR archives and is suitable for big image datasets.
Because of their size, Webdatasets are generally loaded in streaming mode (using `streaming=True`).

If you have `webdataset` installed, you can load a Webdataset like this:

```python
>>> from datasets import load_dataset

>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", streaming=True)
```
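
One reason TAR shards pair well with streaming is that a TAR archive can be read front to back without seeking. The sketch below shows that access pattern with the standard `tarfile` module in stream mode (`mode="r|"`). It builds a tiny in-memory archive so it can run on its own; `iter_members` is a hypothetical helper, and this is only an illustration, not the code path used by `load_dataset`.

```python
import io
import tarfile

# Build a small in-memory archive so the example is self-contained.
buffer = io.BytesIO()
with tarfile.open(fileobj=buffer, mode="w") as tar:
    for name, payload in [("a.txt", b"first"), ("b.txt", b"second")]:
        member = tarfile.TarInfo(name=name)
        member.size = len(payload)
        tar.addfile(member, io.BytesIO(payload))
buffer.seek(0)


def iter_members(fileobj):
    """Yield (name, bytes) pairs while reading the archive sequentially."""
    # "r|" opens the archive as a forward-only stream: no random access is needed.
    with tarfile.open(fileobj=fileobj, mode="r|") as tar:
        for member in tar:
            extracted = tar.extractfile(member)
            if extracted is not None:  # skip directories and other non-file entries
                yield member.name, extracted.read()


streamed = list(iter_members(buffer))
```

Because members are consumed in order, the same loop would work on any file-like object that only supports sequential reads.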
25 changes: 25 additions & 0 deletions docs/source/loading.mdx
@@ -224,6 +224,31 @@ For more details, check out the [how to load tabular datasets from SQL databases

</Tip>

### WebDataset

The [Webdataset](https://github.com/webdataset/webdataset) format is based on TAR archives and is suitable for big image datasets.
Because of their size, Webdatasets are generally loaded in streaming mode (using `streaming=True`).

If you have `webdataset` installed, you can load a Webdataset like this:

```python
>>> from datasets import load_dataset
>>>
>>> path = "path/to/train/*.tar"
>>> dataset = load_dataset("webdataset", data_files={"train": path}, split="train", streaming=True)
```

To load remote Webdatasets via HTTP, pass the URLs instead:

```python
>>> from datasets import load_dataset
>>>
>>> base_url = "http://storage.googleapis.com/nvdata-openimages/openimages-train-{i:06d}.tar"
>>> urls = [base_url.format(i=i) for i in range(10)]
>>> dataset = load_dataset("webdataset", data_files={"train": urls}, split="train", streaming=True)
```


## Multiprocessing

When a dataset is made of several files (that we call "shards"), it is possible to significantly speed up the dataset downloading and preparation step.
4 changes: 4 additions & 0 deletions docs/source/package_reference/loading_methods.mdx
@@ -96,3 +96,7 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t")
[[autodoc]] datasets.packaged_modules.audiofolder.AudioFolderConfig

[[autodoc]] datasets.packaged_modules.audiofolder.AudioFolder

### Webdataset

[[autodoc]] datasets.packaged_modules.webdataset.Webdataset
2 changes: 2 additions & 0 deletions src/datasets/config.py
@@ -141,6 +141,8 @@
LZ4_AVAILABLE = importlib.util.find_spec("lz4") is not None
PY7ZR_AVAILABLE = importlib.util.find_spec("py7zr") is not None

# Optional data formats
WDS_AVAILABLE = importlib.util.find_spec("webdataset") is not None

# Cache location
DEFAULT_XDG_CACHE_HOME = "~/.cache"
6 changes: 4 additions & 2 deletions src/datasets/download/streaming_download_manager.py
@@ -468,7 +468,7 @@ def _prepare_single_hop_path_and_storage_options(
return urlpath, storage_options


-def xopen(file: str, mode="r", *args, download_config: Optional[DownloadConfig] = None, **kwargs):
+def xopen(file: str, mode="r", buffering=-1, *args, download_config: Optional[DownloadConfig] = None, **kwargs):
"""Extend `open` function to support remote files using `fsspec`.

It also has a retry mechanism in case connection fails.
@@ -492,8 +492,10 @@ def xopen(file: str, mode="r", *args, download_config: Optional[DownloadConfig]
     # add headers and cookies for authentication on the HF Hub and for Google Drive
     file, storage_options = _prepare_path_and_storage_options(file_str, download_config=download_config)
     kwargs = {**kwargs, **(storage_options or {})}
+    if buffering > 1:
+        kwargs["blocksize"] = buffering
     try:
-        file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open()
+        file_obj = fsspec.open(file, mode, *args, **kwargs).open()
     except ValueError as e:
         if str(e) == "Cannot seek streaming HTTP file":
             raise NonStreamableDatasetError(
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/__init__.py
@@ -12,6 +12,7 @@
from .parquet import parquet
from .sql import sql # noqa F401
from .text import text
from .webdataset import webdataset


def _hash_python_lines(lines: List[str]) -> str:
@@ -37,6 +38,7 @@ def _hash_python_lines(lines: List[str]) -> str:
"text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())),
"imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())),
"audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())),
"webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
}

# Used to infer the module to use based on the data files extensions
@@ -48,6 +50,7 @@ def _hash_python_lines(lines: List[str]) -> str:
".parquet": ("parquet", {}),
".arrow": ("arrow", {}),
".txt": ("text", {}),
".tar": ("webdataset", {}),
Collaborator:

Maybe we can make the module inference more robust by inspecting the contents of TAR archives (e.g., consecutive files with the same name (stem) but different extensions)

Member Author:

Some webdatasets may contain only one field (e.g. only images) so I'm not sure it would make sense.

Also I like the idea of keeping the TAR loading simple and only support webdataset for TAR

Collaborator:

But this would break some existing repos on the Hub, no?

Member Author:

TAR has never been supported on the Hub, so what would it break?

Member Author:

Ah you mean that some TAR datasets not in webdataset format won't work properly.

Maybe I can add an error message if the first webdataset examples don't have the same type

Collaborator:

Never mind, I thought this could fail on

    elif ext == ".zip":
        return infer_module_for_data_files_list_in_archives(data_files_list, download_config=download_config)

But this logic only inspects ZIP (and ignores TAR) archives :)

}
_EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
Empty file.
211 changes: 211 additions & 0 deletions src/datasets/packaged_modules/webdataset/webdataset.py
@@ -0,0 +1,211 @@
import re
from itertools import islice
from typing import List

import numpy as np
import pyarrow as pa
from packaging import version

import datasets
from datasets import config


logger = datasets.utils.logging.get_logger(__name__)


class Webdataset(datasets.GeneratorBasedBuilder):
    DEFAULT_WRITER_BATCH_SIZE = 100
    IMAGE_EXTENSIONS: List[str]  # definition at the bottom of the script
    ENABLED_BASIC_HANDLERS: List[str]  # definition at the bottom of the script

    def _basic_handlers(self, key, data):
        if not config.WDS_AVAILABLE:
            raise ImportError("Please install 'webdataset' to load this dataset.")

        from webdataset.autodecode import decoders

        extension = re.sub(r".*[.]", "", key)
        if extension in decoders and extension in self.ENABLED_BASIC_HANDLERS:
            return decoders[extension](data)
        return None

    def _info(self) -> datasets.DatasetInfo:
        return datasets.DatasetInfo()

    def _split_generators(self, dl_manager):
        """We handle string, list and dicts in datafiles"""
        if not config.WDS_AVAILABLE:
            raise ImportError("Please install 'webdataset' to load this dataset.")

        import webdataset as wds

        # Use the extended `open` to read hf:// files
        wds.gopen_schemes["hf"] = open

        # Download the data files
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        data_files = dl_manager.download(self.config.data_files)
        if isinstance(data_files, (str, list, tuple)):
            files = data_files
            if isinstance(files, str):
                files = [files]
            return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))

        # Get one example to get the feature types
        pipeline = wds.DataPipeline(
            wds.SimpleShardList(files[:1]), wds.tarfile_to_samples(), wds.decode(post=[self._basic_handlers])
        )
        first_examples = list(islice(pipeline, 5))
        if any(example.keys() != first_examples[0].keys() for example in first_examples):
            raise ValueError(
                "The TAR archives of the dataset should be in Webdataset format, "
                "but the files in the archive don't share the same prefix or the same types."
            )
        inferred_arrow_schema = pa.Table.from_pylist(first_examples[:1]).schema
        features = datasets.Features.from_arrow_schema(inferred_arrow_schema)

        # Set Image types
Collaborator:

We should also do the same for the Audio feature, no?

Member Author:

Yes correct, that's for another PR :p

Collaborator:

Ok, feel free to add a TO-DO.

        for key in first_examples[0]:
            extension = re.sub(r".*[.]", "", key)
            if extension in self.IMAGE_EXTENSIONS:
                features[key] = datasets.Image()
        self.info.features = features

        return splits

    def _generate_examples(self, files):
        if not config.WDS_AVAILABLE:
            raise ImportError("Please install 'webdataset' to load this dataset.")

        import webdataset as wds

        image_keys = [key for key, feature in self.info.features.items() if isinstance(feature, datasets.Image)]

        for file_idx, file in enumerate(files):
            pipeline = wds.DataPipeline(
                wds.SimpleShardList(file),
                wds.tarfile_to_samples(),
                wds.decode(post=[self._basic_handlers]),
            )
            for example_idx, example in enumerate(pipeline):
                for key in image_keys:
                    example[key] = {"path": example["__key__"] + "." + key, "bytes": example[key]}
                yield f"{file_idx}_{example_idx}", example


# Obtained with:
# ```
# import PIL.Image
# IMAGE_EXTENSIONS = []
# PIL.Image.init()
# for ext, format in PIL.Image.EXTENSION.items():
#     if format in PIL.Image.OPEN:
#         IMAGE_EXTENSIONS.append(ext[1:])
# ```
# We intentionally do not run this code on launch because:
# (1) Pillow is an optional dependency, so importing Pillow in global namespace is not allowed
# (2) To ensure the list of supported extensions is deterministic
IMAGE_EXTENSIONS = [
"blp",
"bmp",
"dib",
"bufr",
"cur",
"pcx",
"dcx",
"dds",
"ps",
"eps",
"fit",
"fits",
"fli",
"flc",
"ftc",
"ftu",
"gbr",
"gif",
"grib",
"h5",
"hdf",
"png",
"apng",
"jp2",
"j2k",
"jpc",
"jpf",
"jpx",
"j2c",
"icns",
"ico",
"im",
"iim",
"tif",
"tiff",
"jfif",
"jpe",
"jpg",
"jpeg",
"mpg",
"mpeg",
"msp",
"pcd",
"pxr",
"pbm",
"pgm",
"ppm",
"pnm",
"psd",
"bw",
"rgb",
"rgba",
"sgi",
"ras",
"tga",
"icb",
"vda",
"vst",
"webp",
"wmf",
"emf",
"xbm",
"xpm",
]
Webdataset.IMAGE_EXTENSIONS = IMAGE_EXTENSIONS


# Obtained by checking `decoders` in `webdataset.autodecode`
# and removing unsafe extension decoders.
# Removed Pickle decoders:
# - "pyd": lambda data: pickle.loads(data)
# - "pickle": lambda data: pickle.loads(data)
# Removed Torch decoders:
# - "pth": lambda data: torch_loads(data)
# Removed NumPy decoders for numpy < 1.16.3 (CVE-2019-6446):
# - "npy": npy_loads,
# - "npz": lambda data: np.load(io.BytesIO(data)),
ENABLED_BASIC_HANDLERS = [
"txt",
"text",
"transcript",
"cls",
"cls2",
"index",
"inx",
"id",
"json",
"jsn",
"ten",
"tb",
"mp",
"msg",
"cbor",
]
if version.parse(np.__version__) >= version.parse("1.16.3"):
    ENABLED_BASIC_HANDLERS.extend(["npy", "npz"])
Webdataset.ENABLED_BASIC_HANDLERS = ENABLED_BASIC_HANDLERS