Skip to content

Commit 1b935da

Browse files
authored
Image & Audio formatting for numpy/torch/tf/jax (#5072)
* wip * np formatting * jax formatting * tf formatting * torch formatting * support str and bytes in jax * support str and bytes in torch * add tests * torch/tf/jax: support None * update docs * fix tests * don't apply dtype in numpy extractor * add consolidation * update tests * fix consolidate for np and jax * only rag 1-d tensors in tf * update tests * fix tf export * don't convert np numbers to arrays * update tests * docs: batched tensors * use object dtype for ragged numpy arrays * update tests * fix np.array(..., object)
1 parent 6be722d commit 1b935da

File tree

15 files changed

+606
-165
lines changed

15 files changed

+606
-165
lines changed

docs/source/use_with_pytorch.mdx

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,28 +78,54 @@ To get a single tensor, you must explicitly use the [`Array`] feature type and s
7878

7979
```py
8080
>>> from datasets import Dataset, Features, ClassLabel
81-
>>> data = [0, 0, 1]
82-
>>> features = Features({"data": ClassLabel(names=["negative", "positive"])})
83-
>>> ds = Dataset.from_dict({"data": data}, features=features)
81+
>>> labels = [0, 0, 1]
82+
>>> features = Features({"label": ClassLabel(names=["negative", "positive"])})
83+
>>> ds = Dataset.from_dict({"label": labels}, features=features)
8484
>>> ds = ds.with_format("torch")
8585
>>> ds[:3]
86-
{'data': tensor([0, 0, 1])}
86+
{'label': tensor([0, 0, 1])}
8787
```
8888

89-
However, since it's not possible to convert text data to PyTorch tensors, you can't format a `string` column to PyTorch.
90-
Instead, you can explicitly format certain columns and leave the other columns unformatted:
89+
String and binary objects are unchanged, since PyTorch only supports numbers.
90+
91+
The [`Image`] and [`Audio`] feature types are also supported:
9192

9293
```py
93-
>>> from datasets import Dataset, Features
94-
>>> text = ["foo", "bar"]
95-
>>> data = [0, 1]
96-
>>> ds = Dataset.from_dict({"text": text, "data": data})
97-
>>> ds = ds.with_format("torch", columns=["data"], output_all_columns=True)
98-
>>> ds[:2]
99-
{'data': tensor([0, 1]), 'text': ['foo', 'bar']}
94+
>>> from datasets import Dataset, Features, Audio, Image
95+
>>> images = ["path/to/image.png"] * 10
96+
>>> features = Features({"image": Image()})
97+
>>> ds = Dataset.from_dict({"image": images}, features=features)
98+
>>> ds = ds.with_format("torch")
99+
>>> ds[0]["image"].shape
100+
torch.Size([512, 512, 4])
101+
>>> ds[0]
102+
{'image': tensor([[[255, 215, 106, 255],
103+
[255, 215, 106, 255],
104+
...,
105+
[255, 255, 255, 255],
106+
[255, 255, 255, 255]]], dtype=torch.uint8)}
107+
>>> ds[:2]["image"].shape
108+
torch.Size([2, 512, 512, 4])
109+
>>> ds[:2]
110+
{'image': tensor([[[[255, 215, 106, 255],
111+
[255, 215, 106, 255],
112+
...,
113+
[255, 255, 255, 255],
114+
[255, 255, 255, 255]]]], dtype=torch.uint8)}
100115
```
101116

102-
The [`Image`] and [`Audio`] feature types are not supported yet.
117+
```py
118+
>>> from datasets import Dataset, Features, Audio, Image
119+
>>> audio = ["path/to/audio.wav"] * 10
120+
>>> features = Features({"audio": Audio()})
121+
>>> ds = Dataset.from_dict({"audio": audio}, features=features)
122+
>>> ds = ds.with_format("torch")
123+
>>> ds[0]["audio"]["array"]
124+
tensor([ 6.1035e-05, 1.5259e-05, 1.6785e-04, ..., -1.5259e-05,
125+
-1.5259e-05, 1.5259e-05])
126+
>>> ds[0]["audio"]["sampling_rate"]
127+
tensor(44100)
128+
```
103129

104130
## Data loading
105131

docs/source/use_with_tensorflow.mdx

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,15 @@ To get a single tensor, you must explicitly use the Array feature type and speci
8181

8282
```py
8383
>>> from datasets import Dataset, Features, ClassLabel
84-
>>> data = [0, 0, 1]
85-
>>> features = Features({"data": ClassLabel(names=["negative", "positive"])})
86-
>>> ds = Dataset.from_dict({"data": data}, features=features)
84+
>>> labels = [0, 0, 1]
85+
>>> features = Features({"label": ClassLabel(names=["negative", "positive"])})
86+
>>> ds = Dataset.from_dict({"label": labels}, features=features)
8787
>>> ds = ds.with_format("tf")
8888
>>> ds[:3]
89-
{'data': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 0, 1])>
89+
{'label': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 0, 1])>}
9090
```
9191

92-
Strings are also supported:
92+
Strings and binary objects are also supported:
9393

9494
```py
9595
>>> from datasets import Dataset, Features
@@ -111,7 +111,45 @@ You can also explicitly format certain columns and leave the other columns unfor
111111
'text': ['foo', 'bar']}
112112
```
113113

114-
The [`Image`] and [`Audio`] feature types are not supported yet.
114+
String and binary objects are converted to string tensors (`dtype=string`), since TensorFlow, unlike PyTorch, supports them natively.
115+
116+
The [`Image`] and [`Audio`] feature types are also supported:
117+
118+
```py
119+
>>> from datasets import Dataset, Features, Audio, Image
120+
>>> images = ["path/to/image.png"] * 10
121+
>>> features = Features({"image": Image()})
122+
>>> ds = Dataset.from_dict({"image": images}, features=features)
123+
>>> ds = ds.with_format("tf")
124+
>>> ds[0]
125+
{'image': <tf.Tensor: shape=(512, 512, 4), dtype=uint8, numpy=
126+
array([[[255, 215, 106, 255],
127+
[255, 215, 106, 255],
128+
...,
129+
[255, 255, 255, 255],
130+
[255, 255, 255, 255]]], dtype=uint8)>}
131+
>>> ds[:2]
132+
{'image': <tf.Tensor: shape=(2, 512, 512, 4), dtype=uint8, numpy=
133+
array([[[[255, 215, 106, 255],
134+
[255, 215, 106, 255],
135+
...,
136+
[255, 255, 255, 255],
137+
[255, 255, 255, 255]]]], dtype=uint8)>}
138+
```
139+
140+
```py
141+
>>> from datasets import Dataset, Features, Audio, Image
142+
>>> audio = ["path/to/audio.wav"] * 10
143+
>>> features = Features({"audio": Audio()})
144+
>>> ds = Dataset.from_dict({"audio": audio}, features=features)
145+
>>> ds = ds.with_format("tf")
146+
>>> ds[0]["audio"]["array"]
147+
<tf.Tensor: shape=(202311,), dtype=float32, numpy=
148+
array([ 6.1035156e-05, 1.5258789e-05, 1.6784668e-04, ...,
149+
-1.5258789e-05, -1.5258789e-05, 1.5258789e-05], dtype=float32)>
150+
>>> ds[0]["audio"]["sampling_rate"]
151+
<tf.Tensor: shape=(), dtype=int32, numpy=44100>
152+
```
115153

116154
## Data loading
117155

src/datasets/arrow_dataset.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3985,9 +3985,14 @@ def _int64_feature(values):
39853985
"""Returns an int64_list from a list of bool / enum / int / uint."""
39863986
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
39873987

3988-
def _feature(values: Union[float, int, str, np.ndarray]) -> "tf.train.Feature":
3988+
def _feature(values: Union[float, int, str, np.ndarray, list]) -> "tf.train.Feature":
39893989
"""Typechecks `values` and returns the corresponding tf.train.Feature."""
3990-
if isinstance(values, np.ndarray):
3990+
if isinstance(values, list):
3991+
if values and isinstance(values[0], str):
3992+
return _bytes_feature([v.encode() for v in values])
3993+
else:
3994+
raise ValueError(f"values={values} is empty or contains items that cannot be serialized")
3995+
elif isinstance(values, np.ndarray):
39913996
if values.dtype == np.dtype(float):
39923997
return _float_feature(values)
39933998
elif values.dtype == np.int64:
@@ -3998,9 +4003,9 @@ def _feature(values: Union[float, int, str, np.ndarray]) -> "tf.train.Feature":
39984003
return _bytes_feature([v.encode() for v in values])
39994004
else:
40004005
raise ValueError(
4001-
f"values={values} is an np.ndarray with items of dtype {values[0].dtype}, which cannot be serialized"
4006+
f"values={values} is empty or is an np.ndarray with items of dtype {values[0].dtype}, which cannot be serialized"
40024007
)
4003-
if hasattr(values, "dtype"):
4008+
elif hasattr(values, "dtype"):
40044009
if np.issubdtype(values.dtype, np.floating):
40054010
return _float_feature([values.item()])
40064011
elif np.issubdtype(values.dtype, np.integer):
@@ -4010,7 +4015,7 @@ def _feature(values: Union[float, int, str, np.ndarray]) -> "tf.train.Feature":
40104015
else:
40114016
raise ValueError(f"values={values} has dtype {values.dtype}, which cannot be serialized")
40124017
else:
4013-
raise ValueError(f"values={values} are not numpy objects, and so cannot be serialized")
4018+
raise ValueError(f"values={values} are not numpy objects or strings, and so cannot be serialized")
40144019

40154020
def serialize_example(ex):
40164021
feature = {key: _feature(value) for key, value in ex.items()}

src/datasets/features/image.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def encode_example(self, value: Union[str, dict, np.ndarray, "PIL.Image.Image"])
8181
else:
8282
raise ImportError("To support encoding images, please install 'Pillow'.")
8383

84+
if isinstance(value, list):
85+
value = np.array(value)
86+
8487
if isinstance(value, str):
8588
return {"path": value, "bytes": None}
8689
elif isinstance(value, np.ndarray):

src/datasets/formatting/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
ArrowFormatter,
2424
CustomFormatter,
2525
Formatter,
26-
NumpyFormatter,
2726
PandasFormatter,
2827
PythonFormatter,
2928
format_table,
3029
query_table,
3130
)
31+
from .np_formatter import NumpyFormatter
3232

3333

3434
logger = logging.get_logger(__name__)

src/datasets/formatting/formatting.py

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pandas as pd
2222
import pyarrow as pa
2323

24+
from ..features import Features
2425
from ..features.features import _ArrayXDExtensionType, _is_zero_copy_only, decode_nested_example, pandas_types_mapper
2526
from ..table import Table
2627
from ..utils.py_utils import no_op_if_value_is_null
@@ -198,8 +199,8 @@ def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray:
198199
or (isinstance(x, float) and np.isnan(x))
199200
for x in array
200201
):
201-
return np.array(array, copy=False, **{**self.np_array_kwargs, "dtype": object})
202-
return np.array(array, copy=False, **self.np_array_kwargs)
202+
return np.array(array, copy=False, dtype=object)
203+
return np.array(array, copy=False)
203204

204205

205206
class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]):
@@ -214,7 +215,7 @@ def extract_batch(self, pa_table: pa.Table) -> pd.DataFrame:
214215

215216

216217
class PythonFeaturesDecoder:
217-
def __init__(self, features):
218+
def __init__(self, features: Features):
218219
self.features = features
219220

220221
def decode_row(self, row: dict) -> dict:
@@ -228,7 +229,7 @@ def decode_batch(self, batch: dict) -> dict:
228229

229230

230231
class PandasFeaturesDecoder:
231-
def __init__(self, features):
232+
def __init__(self, features: Features):
232233
self.features = features
233234

234235
def decode_row(self, row: pd.DataFrame) -> pd.DataFrame:
@@ -325,30 +326,6 @@ def format_batch(self, pa_table: pa.Table) -> dict:
325326
return batch
326327

327328

328-
class NumpyFormatter(Formatter[dict, np.ndarray, dict]):
329-
def __init__(self, features=None, decoded=True, **np_array_kwargs):
330-
super().__init__(features=features, decoded=decoded)
331-
self.np_array_kwargs = np_array_kwargs
332-
333-
def format_row(self, pa_table: pa.Table) -> dict:
334-
row = self.numpy_arrow_extractor(**self.np_array_kwargs).extract_row(pa_table)
335-
if self.decoded:
336-
row = self.python_features_decoder.decode_row(row)
337-
return row
338-
339-
def format_column(self, pa_table: pa.Table) -> np.ndarray:
340-
column = self.numpy_arrow_extractor(**self.np_array_kwargs).extract_column(pa_table)
341-
if self.decoded:
342-
column = self.python_features_decoder.decode_column(column, pa_table.column_names[0])
343-
return column
344-
345-
def format_batch(self, pa_table: pa.Table) -> dict:
346-
batch = self.numpy_arrow_extractor(**self.np_array_kwargs).extract_batch(pa_table)
347-
if self.decoded:
348-
batch = self.python_features_decoder.decode_batch(batch)
349-
return batch
350-
351-
352329
class PandasFormatter(Formatter):
353330
def format_row(self, pa_table: pa.Table) -> pd.DataFrame:
354331
row = self.pandas_arrow_extractor().extract_row(pa_table)

src/datasets/formatting/jax_formatter.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
# limitations under the License.
1414

1515
# Lint as: python3
16+
import sys
1617
from typing import TYPE_CHECKING
1718

1819
import numpy as np
1920
import pyarrow as pa
2021

22+
from .. import config
2123
from ..utils.py_utils import map_nested
2224
from .formatting import Formatter
2325

@@ -28,47 +30,80 @@
2830

2931
class JaxFormatter(Formatter[dict, "jnp.ndarray", dict]):
3032
def __init__(self, features=None, decoded=True, **jnp_array_kwargs):
33+
super().__init__(features=features, decoded=decoded)
3134
self.jnp_array_kwargs = jnp_array_kwargs
3235
import jax.numpy as jnp # noqa import jax at initialization
3336

37+
def _consolidate(self, column):
38+
import jax.numpy as jnp
39+
40+
if isinstance(column, list) and column:
41+
if all(
42+
isinstance(x, jnp.ndarray) and x.shape == column[0].shape and x.dtype == column[0].dtype
43+
for x in column
44+
):
45+
return jnp.stack(column)
46+
return column
47+
3448
def _tensorize(self, value):
3549
import jax
3650
import jax.numpy as jnp
3751

52+
if isinstance(value, (str, bytes, type(None))):
53+
return value
54+
elif isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
55+
return value.tolist()
56+
3857
default_dtype = {}
39-
if np.issubdtype(value.dtype, np.integer):
58+
59+
if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer):
4060
# the default int precision depends on the jax config
4161
# see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
4262
if jax.config.jax_enable_x64:
4363
default_dtype = {"dtype": jnp.int64}
4464
else:
4565
default_dtype = {"dtype": jnp.int32}
46-
elif np.issubdtype(value.dtype, np.floating):
66+
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
4767
default_dtype = {"dtype": jnp.float32}
68+
elif config.PIL_AVAILABLE and "PIL" in sys.modules:
69+
import PIL.Image
70+
71+
if isinstance(value, PIL.Image.Image):
72+
value = np.asarray(value)
4873

4974
# calling jnp.array on a np.ndarray does copy the data
5075
# see https://github.com/google/jax/issues/4486
5176
return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs})
5277

5378
def _recursive_tensorize(self, data_struct: dict):
5479
# support for nested types like struct of list of struct
55-
if isinstance(data_struct, (list, np.ndarray)):
56-
data_struct = np.array(data_struct, copy=False)
80+
if isinstance(data_struct, np.ndarray):
5781
if data_struct.dtype == object: # jax arrays cannot be instantied from an array of objects
58-
return [self.recursive_tensorize(substruct) for substruct in data_struct]
82+
return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct])
5983
return self._tensorize(data_struct)
6084

6185
def recursive_tensorize(self, data_struct: dict):
62-
return map_nested(self._recursive_tensorize, data_struct, map_list=False)
86+
return map_nested(self._recursive_tensorize, data_struct)
6387

6488
def format_row(self, pa_table: pa.Table) -> dict:
6589
row = self.numpy_arrow_extractor().extract_row(pa_table)
90+
if self.decoded:
91+
row = self.python_features_decoder.decode_row(row)
6692
return self.recursive_tensorize(row)
6793

6894
def format_column(self, pa_table: pa.Table) -> "jnp.ndarray":
69-
col = self.numpy_arrow_extractor().extract_column(pa_table)
70-
return self.recursive_tensorize(col)
95+
column = self.numpy_arrow_extractor().extract_column(pa_table)
96+
if self.decoded:
97+
column = self.python_features_decoder.decode_column(column, pa_table.column_names[0])
98+
column = self.recursive_tensorize(column)
99+
column = self._consolidate(column)
100+
return column
71101

72102
def format_batch(self, pa_table: pa.Table) -> dict:
73103
batch = self.numpy_arrow_extractor().extract_batch(pa_table)
74-
return self.recursive_tensorize(batch)
104+
if self.decoded:
105+
batch = self.python_features_decoder.decode_batch(batch)
106+
batch = self.recursive_tensorize(batch)
107+
for column_name in batch:
108+
batch[column_name] = self._consolidate(batch[column_name])
109+
return batch

0 commit comments

Comments
 (0)