-
Notifications
You must be signed in to change notification settings - Fork 3k
Webdataset dataset builder #6391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
ba4d055
179c3a8
50551dc
b156345
aca2f87
a451942
79a0a6d
d9c08a5
402f64a
76ff9e5
54f98b0
1b9ce2c
3420d2a
c6a32cc
d3e0dfb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |||||
| from .parquet import parquet | ||||||
| from .sql import sql # noqa F401 | ||||||
| from .text import text | ||||||
| from .webdataset import webdataset | ||||||
|
|
||||||
|
|
||||||
| def _hash_python_lines(lines: List[str]) -> str: | ||||||
|
|
@@ -37,6 +38,7 @@ def _hash_python_lines(lines: List[str]) -> str: | |||||
| "text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())), | ||||||
| "imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())), | ||||||
| "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())), | ||||||
| "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())), | ||||||
| } | ||||||
|
|
||||||
| # Used to infer the module to use based on the data files extensions | ||||||
|
|
@@ -48,6 +50,7 @@ def _hash_python_lines(lines: List[str]) -> str: | |||||
| ".parquet": ("parquet", {}), | ||||||
| ".arrow": ("arrow", {}), | ||||||
| ".txt": ("text", {}), | ||||||
| ".tar": ("webdataset", {}), | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we can make the module inference more robust by inspecting the contents of TAR archives (e.g., consecutive files with the same name (stem) but different extensions).
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Some webdatasets may contain only one field (e.g. only images), so I'm not sure it would make sense. Also, I like the idea of keeping the TAR loading simple and only supporting the webdataset format for TAR.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But this would break some existing repos on the Hub, no?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TAR has never been supported on the Hub, what would it break ?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah you mean that some TAR datasets not in webdataset format won't work properly. Maybe I can add an error message if the first webdataset examples don't have the same type
Collaborator
There was a problem hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Never mind — I thought this could fail on lines 449–450 (commit 27d1fe5),
But this logic only inspects ZIP (and ignores TAR) archives :) |
||||||
| } | ||||||
| _EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS}) | ||||||
| _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS}) | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,211 @@ | ||
| import re | ||
| from itertools import islice | ||
| from typing import List | ||
|
|
||
| import numpy as np | ||
| import pyarrow as pa | ||
| from packaging import version | ||
|
|
||
| import datasets | ||
| from datasets import config | ||
|
|
||
|
|
||
| logger = datasets.utils.logging.get_logger(__name__) | ||
|
|
||
|
|
||
class Webdataset(datasets.GeneratorBasedBuilder):
    """Dataset builder for TAR archives in the WebDataset format.

    In a WebDataset TAR archive, consecutive files sharing the same stem
    (e.g. ``0001.jpg`` + ``0001.json``) form one example; each file extension
    becomes a column name. Only a safe subset of the ``webdataset`` decoders
    is enabled (pickle/torch decoders are excluded — see
    ``ENABLED_BASIC_HANDLERS`` at the bottom of the script).
    """

    DEFAULT_WRITER_BATCH_SIZE = 100
    # Both lists are defined at the bottom of the script.
    IMAGE_EXTENSIONS: List[str]
    ENABLED_BASIC_HANDLERS: List[str]

    def _basic_handlers(self, key, data):
        """Decode `data` with the safe `webdataset` decoder matching `key`'s extension.

        Returns the decoded value, or None so that other (default) decoders
        can handle the field.
        """
        if not config.WDS_AVAILABLE:
            raise ImportError("Please install 'webdataset' to load this dataset.")

        from webdataset.autodecode import decoders

        extension = re.sub(r".*[.]", "", key)
        if extension in decoders and extension in self.ENABLED_BASIC_HANDLERS:
            return decoders[extension](data)
        return None

    def _info(self) -> datasets.DatasetInfo:
        # Features are inferred later, in `_split_generators`, by peeking at
        # the first examples of the first TAR archive.
        return datasets.DatasetInfo()

    def _split_generators(self, dl_manager):
        """We handle string, list and dicts in datafiles"""
        if not config.WDS_AVAILABLE:
            raise ImportError("Please install 'webdataset' to load this dataset.")

        import webdataset as wds

        # Use the extended `open` to read hf:// files
        wds.gopen_schemes["hf"] = open

        # Download the data files
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        data_files = dl_manager.download(self.config.data_files)
        # Normalize every supported `data_files` shape (str, list/tuple, dict)
        # to a dict of split -> files, so that feature inference below runs in
        # all cases. Previously the str/list case returned early without
        # inferring features, leaving `self.info.features` unset and making
        # `_generate_examples` crash on `self.info.features.items()`.
        if isinstance(data_files, (str, list, tuple)):
            data_files = {datasets.Split.TRAIN: data_files}
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))

        # Peek at a few examples of the first file to infer the feature types
        first_files = splits[0].gen_kwargs["files"]
        pipeline = wds.DataPipeline(
            wds.SimpleShardList(first_files[:1]), wds.tarfile_to_samples(), wds.decode(post=[self._basic_handlers])
        )
        first_examples = list(islice(pipeline, 5))
        if not first_examples:
            # An empty iterator would otherwise raise a bare IndexError below.
            raise ValueError(
                "The TAR archives of the dataset are empty or not in Webdataset format."
            )
        if any(example.keys() != first_examples[0].keys() for example in first_examples):
            raise ValueError(
                "The TAR archives of the dataset should be in Webdataset format, "
                "but the files in the archive don't share the same prefix or the same types."
            )
        inferred_arrow_schema = pa.Table.from_pylist(first_examples[:1]).schema
        features = datasets.Features.from_arrow_schema(inferred_arrow_schema)

        # Set Image types
        # TODO: do the same for Audio fields (requested in review)
        for key in first_examples[0]:
            extension = re.sub(r".*[.]", "", key)
            if extension in self.IMAGE_EXTENSIONS:
                features[key] = datasets.Image()
        self.info.features = features

        return splits

    def _generate_examples(self, files):
        """Yield `(example_id, example)` pairs from each TAR shard in `files`."""
        if not config.WDS_AVAILABLE:
            raise ImportError("Please install 'webdataset' to load this dataset.")

        import webdataset as wds

        image_keys = [key for key, feature in self.info.features.items() if isinstance(feature, datasets.Image)]

        for file_idx, file in enumerate(files):
            pipeline = wds.DataPipeline(
                wds.SimpleShardList(file),
                wds.tarfile_to_samples(),
                wds.decode(post=[self._basic_handlers]),
            )
            for example_idx, example in enumerate(pipeline):
                # Wrap the undecoded bytes so `datasets.Image` can decode them.
                for key in image_keys:
                    example[key] = {"path": example["__key__"] + "." + key, "bytes": example[key]}
                yield f"{file_idx}_{example_idx}", example
|
|
||
|
|
||
# Obtained with:
# ```
# import PIL.Image
# IMAGE_EXTENSIONS = []
# PIL.Image.init()
# for ext, format in PIL.Image.EXTENSION.items():
#     if format in PIL.Image.OPEN:
#         IMAGE_EXTENSIONS.append(ext[1:])
# ```
# We intentionally do not run this code on launch because:
# (1) Pillow is an optional dependency, so importing Pillow in global namespace is not allowed
# (2) To ensure the list of supported extensions is deterministic
IMAGE_EXTENSIONS = [
    "blp",
    "bmp",
    "dib",
    "bufr",
    "cur",
    "pcx",
    "dcx",
    "dds",
    "ps",
    "eps",
    "fit",
    "fits",
    "fli",
    "flc",
    "ftc",
    "ftu",
    "gbr",
    "gif",
    "grib",
    "h5",
    "hdf",
    "png",
    "apng",
    "jp2",
    "j2k",
    "jpc",
    "jpf",
    "jpx",
    "j2c",
    "icns",
    "ico",
    "im",
    "iim",
    "tif",
    "tiff",
    "jfif",
    "jpe",
    "jpg",
    "jpeg",
    "mpg",
    "mpeg",
    "msp",
    "pcd",
    "pxr",
    "pbm",
    "pgm",
    "ppm",
    "pnm",
    "psd",
    "bw",
    "rgb",
    "rgba",
    "sgi",
    "ras",
    "tga",
    "icb",
    "vda",
    "vst",
    "webp",
    "wmf",
    "emf",
    "xbm",
    "xpm",
]
| Webdataset.IMAGE_EXTENSIONS = IMAGE_EXTENSIONS | ||
|
|
||
|
|
||
# Obtained by checking `decoders` in `webdataset.autodecode`
# and removing unsafe extension decoders.
# Removed Pickle decoders:
# - "pyd": lambda data: pickle.loads(data)
# - "pickle": lambda data: pickle.loads(data)
# Removed Torch decoders:
# - "pth": lambda data: torch_loads(data)
# Removed NumPy decoders for numpy < 1.16.3 (CVE-2019-6446):
# - "npy": npy_loads,
# - "npz": lambda data: np.load(io.BytesIO(data)),
ENABLED_BASIC_HANDLERS = [
    "txt",
    "text",
    "transcript",
    "cls",
    "cls2",
    "index",
    "inx",
    "id",
    "json",
    "jsn",
    "ten",
    "tb",
    "mp",
    "msg",
    "cbor",
]
| if version.parse(np.__version__) >= version.parse("1.16.3"): | ||
| ENABLED_BASIC_HANDLERS.extend(["npy", "npz"]) | ||
| Webdataset.ENABLED_BASIC_HANDLERS = ENABLED_BASIC_HANDLERS | ||
Uh oh!
There was an error while loading. Please reload this page.