Skip to content

Commit 5ebad86

Browse files
committed
update docs
1 parent 1940aa8 commit 5ebad86

File tree

2 files changed

+59
-12
lines changed

2 files changed

+59
-12
lines changed

docs/source/use_with_pytorch.mdx

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,36 @@ To get a single tensor, you must explicitly use the [`Array`] feature type and s
8686
{'data': tensor([0, 0, 1])}
8787
```
8888

89-
However, since it's not possible to convert text data to PyTorch tensors, you can't format a `string` column to PyTorch.
90-
Instead, you can explicitly format certain columns and leave the other columns unformatted:
89+
String and binary objects are unchanged, since PyTorch only supports numbers.
90+
91+
The [`Image`] and [`Audio`] feature types are also supported:
9192

9293
```py
93-
>>> from datasets import Dataset, Features
94-
>>> text = ["foo", "bar"]
95-
>>> data = [0, 1]
96-
>>> ds = Dataset.from_dict({"text": text, "data": data})
97-
>>> ds = ds.with_format("torch", columns=["data"], output_all_columns=True)
98-
>>> ds[:2]
99-
{'data': tensor([0, 1]), 'text': ['foo', 'bar']}
94+
>>> from datasets import Dataset, Features, Audio, Image
95+
>>> data = ["path/to/image.png"]
96+
>>> features = Features({"data": Image()})
97+
>>> ds = Dataset.from_dict({"data": data}, features=features)
98+
>>> ds = ds.with_format("torch")
99+
>>> ds[0]
100+
{'data': tensor([[[255, 215, 106, 255],
101+
[255, 215, 106, 255],
102+
...,
103+
[255, 255, 255, 255],
104+
[255, 255, 255, 255]]], dtype=torch.uint8)}
100105
```
101106

102-
The [`Image`] and [`Audio`] feature types are not supported yet.
107+
```py
108+
>>> from datasets import Dataset, Features, Audio, Image
109+
>>> data = ["path/to/audio.wav"]
110+
>>> features = Features({"data": Audio()})
111+
>>> ds = Dataset.from_dict({"data": data}, features=features)
112+
>>> ds = ds.with_format("torch")
113+
>>> ds[0]["data"]["array"]
114+
tensor([ 6.1035e-05, 1.5259e-05, 1.6785e-04, ..., -1.5259e-05,
115+
-1.5259e-05, 1.5259e-05])
116+
>>> ds[0]["data"]["sampling_rate"]
117+
tensor(44100)
118+
```
103119

104120
## Data loading
105121

docs/source/use_with_tensorflow.mdx

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ To get a single tensor, you must explicitly use the Array feature type and speci
8989
{'data': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 0, 1])>
9090
```
9191

92-
Strings are also supported:
92+
Strings and binary objects are also supported:
9393

9494
```py
9595
>>> from datasets import Dataset, Features
@@ -111,7 +111,38 @@ You can also explicitly format certain columns and leave the other columns unfor
111111
'text': ['foo', 'bar']}
112112
```
113113

114-
The [`Image`] and [`Audio`] feature types are not supported yet.
114+
String and binary objects are unchanged, since PyTorch only supports numbers.
115+
116+
The [`Image`] and [`Audio`] feature types are also supported:
117+
118+
```py
119+
>>> from datasets import Dataset, Features, Audio, Image
120+
>>> data = ["path/to/image.png"]
121+
>>> features = Features({"data": Image()})
122+
>>> ds = Dataset.from_dict({"data": data}, features=features)
123+
>>> ds = ds.with_format("tf")
124+
>>> ds[0]
125+
{'data': <tf.Tensor: shape=(215, 1200, 4), dtype=uint8, numpy=
126+
array([[[255, 215, 106, 255],
127+
[255, 215, 106, 255],
128+
...,
129+
[255, 255, 255, 255],
130+
[255, 255, 255, 255]]], dtype=uint8)>}
131+
```
132+
133+
```py
134+
>>> from datasets import Dataset, Features, Audio, Image
135+
>>> data = ["path/to/audio.wav"]
136+
>>> features = Features({"data": Audio()})
137+
>>> ds = Dataset.from_dict({"data": data}, features=features)
138+
>>> ds = ds.with_format("tf")
139+
>>> ds[0]["data"]["array"]
140+
<tf.Tensor: shape=(202311,), dtype=float32, numpy=
141+
array([ 6.1035156e-05, 1.5258789e-05, 1.6784668e-04, ...,
142+
-1.5258789e-05, -1.5258789e-05, 1.5258789e-05], dtype=float32)>
143+
>>> ds[0]["data"]["sampling_rate"]
144+
<tf.Tensor: shape=(), dtype=int32, numpy=44100>
145+
```
115146

116147
## Data loading
117148

0 commit comments

Comments
 (0)