Support numpy/torch/tf/jax formatting for IterableDataset #5083

@lhoestq

Right now IterableDataset doesn't do any formatting.

In particular this code should return a numpy array:

from datasets import load_dataset

ds = load_dataset("imagenet-1k", split="train", streaming=True).with_format("np")
print(next(iter(ds))["image"])

Right now it returns a PIL.Image.

Setting streaming=False does return a numpy array, after #5072.
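Until formatting is supported in streaming mode, one possible workaround is to convert each example lazily while iterating. This is only a minimal sketch: `as_numpy` is a hypothetical helper, not part of `datasets`, and it assumes every value in an example can be passed to `np.asarray` (PIL images can). A small in-memory iterable stands in for the streaming dataset so the snippet runs without downloading anything.

```python
import numpy as np


def as_numpy(examples):
    """Lazily yield examples with every value converted to a NumPy array.

    Hypothetical helper for illustration only, not a `datasets` API.
    """
    for example in examples:
        yield {key: np.asarray(value) for key, value in example.items()}


# Stand-in for a streaming dataset: any iterable of dict examples works.
stream = iter([{"image": [[0, 255], [255, 0]], "label": 3}])

first = next(as_numpy(stream))
print(type(first["image"]), first["image"].shape)
```

Note that scalar values such as the label become 0-dimensional arrays with this approach, which may or may not match what `with_format("np")` should eventually return.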

Labels: enhancement, good second issue, streaming
