Commit c412a6f
feat: avoid some copies in torch formatter (#7787)
* feat: avoid some copies in torch formatter
* fix: handle kwargs
* fix: run ruff
* fix: handle dtype
* fix: handle non writable np arrays
* fix: remove comment map_nested
* fix: adjust import for lint
1 parent 27c2e70 commit c412a6f
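
The heart of the change is choosing between torch.from_numpy, which wraps a NumPy buffer without copying, and torch.tensor, which always copies. The sketch below is illustrative only, not code from this commit; the helper name to_tensor_low_copy is hypothetical.

    import numpy as np
    import torch

    def to_tensor_low_copy(arr: np.ndarray) -> torch.Tensor:
        # torch.from_numpy shares the array's memory, but PyTorch does not
        # support tensors over read-only buffers, so copy only in that case
        # (e.g. arrays backed by immutable Arrow data).
        if not arr.flags.writeable:
            arr = arr.copy()
        return torch.from_numpy(arr)

    a = np.arange(4, dtype=np.int64)
    t = to_tensor_low_copy(a)
    a[0] = 99
    assert t[0].item() == 99  # mutation is visible through the tensor: no copy was made

    # uint16/uint32 have no widely supported torch equivalent, so the
    # formatter widens to int64 in NumPy before wrapping (one allocation):
    b = np.arange(4, dtype=np.uint16)
    t2 = torch.from_numpy(b.astype(np.int64))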

File tree

1 file changed: +182 −40 lines

src/datasets/formatting/torch_formatter.py
Lines changed: 182 additions & 40 deletions
@@ -21,90 +21,232 @@
 import pyarrow as pa
 
 from .. import config
-from ..utils.py_utils import map_nested
 from .formatting import TensorFormatter
 
 
 if TYPE_CHECKING:
     import torch
 
+# Import torch once at module level
+try:
+    import torch
+
+    _torch_available = True
+except ImportError:
+    _torch_available = False
+    torch = None
+
 
 class TorchFormatter(TensorFormatter[Mapping, "torch.Tensor", Mapping]):
     def __init__(self, features=None, token_per_repo_id=None, **torch_tensor_kwargs):
         super().__init__(features=features, token_per_repo_id=token_per_repo_id)
         self.torch_tensor_kwargs = torch_tensor_kwargs
-        import torch  # noqa import torch at initialization
+
+        if not _torch_available:
+            raise ImportError("PyTorch is required but not available")
 
     def _consolidate(self, column):
-        import torch
-
-        if isinstance(column, list) and column:
-            if all(
-                isinstance(x, torch.Tensor) and x.shape == column[0].shape and x.dtype == column[0].dtype
-                for x in column
-            ):
-                return torch.stack(column)
+        """Smarter consolidation that only stacks when safe and beneficial."""
+        if not isinstance(column, list) or not column:
+            return column
+
+        # Check if all items are tensors with matching properties
+        first = column[0]
+        if not isinstance(first, torch.Tensor):
+            return column
+
+        # Fast check: if all tensors have same shape, dtype, and device, we can stack
+        if all(
+            isinstance(x, torch.Tensor)
+            and x.shape == first.shape
+            and x.dtype == first.dtype
+            and x.device == first.device
+            for x in column
+        ):
+            return torch.stack(column)
+
         return column
 
     def _tensorize(self, value):
-        import torch
-
+        """Zero/low-copy tensor conversion with smart dtype handling."""
+        # Fast path for strings, bytes, None
         if isinstance(value, (str, bytes, type(None))):
             return value
-        elif isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
-            return value.tolist()
-
-        default_dtype = {}
 
-        if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer):
-            default_dtype = {"dtype": torch.int64}
-
-            # Convert dtype to np.int64 if it's either np.uint16 or np.uint32 to ensure compatibility.
-            # np.uint64 is excluded from this conversion as there is no compatible PyTorch dtype that can handle it without loss.
-            if value.dtype in [np.uint16, np.uint32]:
-                value = value.astype(np.int64)
-
-        elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
-            default_dtype = {"dtype": torch.float32}
+        # Handle string arrays
+        if isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
+            return value.tolist()
 
+        # PIL Image fast path - avoid extra copies
         if config.PIL_AVAILABLE and "PIL" in sys.modules:
             import PIL.Image
 
             if isinstance(value, PIL.Image.Image):
-                value = np.asarray(value)
-                if value.ndim == 2:
-                    value = value[:, :, np.newaxis]
+                # Single conversion path: PIL -> numpy -> torch
+                arr = np.asarray(value)
+                if arr.ndim == 2:
+                    arr = arr[:, :, np.newaxis]
+                # Use moveaxis instead of transpose
+                arr = np.moveaxis(arr, -1, 0)  # HWC -> CHW
+                # Ensure contiguous for zero-copy conversion
+                if not arr.flags.c_contiguous:
+                    arr = np.ascontiguousarray(arr)
+                # Ensure array is writable for torch conversion
+                if not arr.flags.writeable:
+                    arr = arr.copy()
+                return torch.from_numpy(arr)
 
-                value = value.transpose((2, 0, 1))
+        # Video/Audio decoder passthrough
         if config.TORCHVISION_AVAILABLE and "torchvision" in sys.modules:
             from torchvision.io import VideoReader
 
             if isinstance(value, VideoReader):
-                return value  # TODO(QL): set output to torch tensors ?
+                return value
+
         if config.TORCHCODEC_AVAILABLE and "torchcodec" in sys.modules:
             from torchcodec.decoders import AudioDecoder, VideoDecoder
 
             if isinstance(value, (VideoDecoder, AudioDecoder)):
-                return value  # TODO(QL): set output to jax arrays ?
+                return value
+
+        # Support for other tensor libraries via __array__
+        if hasattr(value, "__array__") and not isinstance(value, torch.Tensor):
+            value = value.__array__()
+
+        # Fast numpy conversion paths
+        if isinstance(value, np.ndarray):
+            # Handle integer types with smart casting
+            if np.issubdtype(value.dtype, np.integer):
+                # Check if user specified a dtype, otherwise default to int64
+                kwargs = self.torch_tensor_kwargs.copy()
+                target_dtype = kwargs.get("dtype", torch.int64)
+
+                # Safe casting for unsigned types
+                if value.dtype in (np.uint16, np.uint32):
+                    # Cast to int64 in numpy (fast) then convert to torch
+                    value = value.astype(np.int64)
+                    if target_dtype == torch.int64:
+                        if not value.flags.writeable:
+                            value = value.copy()
+                        return torch.from_numpy(value)
+                    else:
+                        if not value.flags.writeable:
+                            value = value.copy()
+                        kwargs.setdefault("dtype", target_dtype)
+                        return torch.as_tensor(value, **kwargs)
+                elif value.dtype == np.uint64:
+                    # Check if values fit in int64 range
+                    if np.all(value <= np.iinfo(np.int64).max):
+                        value = value.astype(np.int64)
+                        if target_dtype == torch.int64:
+                            if not value.flags.writeable:
+                                value = value.copy()
+                            return torch.from_numpy(value)
+                        else:
+                            if not value.flags.writeable:
+                                value = value.copy()
+                            kwargs.setdefault("dtype", target_dtype)
+                            return torch.as_tensor(value, **kwargs)
+                    else:
+                        # Fallback to safe conversion via Python ints
+                        kwargs.setdefault("dtype", target_dtype)
+                        return torch.tensor(value, **kwargs)
+                else:
+                    # Use zero-copy conversion for compatible integer types
+                    if value.dtype == np.int64 and target_dtype == torch.int64:
+                        # Perfect match, zero-copy conversion
+                        if not value.flags.writeable:
+                            value = value.copy()
+                        return torch.from_numpy(value)
+                    else:
+                        # Need dtype conversion, use as_tensor for efficiency
+                        if not value.flags.writeable:
+                            value = value.copy()
+                        kwargs.setdefault("dtype", target_dtype)
+                        return torch.as_tensor(value, **kwargs)
+
+            # Handle floating point types
+            elif np.issubdtype(value.dtype, np.floating):
+                # Check if user specified a dtype, otherwise default to float32
+                kwargs = self.torch_tensor_kwargs.copy()
+                target_dtype = kwargs.get("dtype", torch.float32)
+
+                if value.dtype == np.float32 and target_dtype == torch.float32:
+                    # Zero-copy conversion, but ensure array is writable
+                    if not value.flags.writeable:
+                        value = value.copy()
+                    return torch.from_numpy(value)
+                else:
+                    # Need dtype conversion
+                    if not value.flags.writeable:
+                        value = value.copy()
+                    kwargs.setdefault("dtype", target_dtype)
+                    return torch.as_tensor(value, **kwargs)
+            else:
+                # Other numpy types, use zero-copy when possible
+                if not value.flags.writeable:
+                    value = value.copy()
+                return torch.from_numpy(value)
+
+        # Handle numpy scalars
+        elif isinstance(value, np.number):
+            kwargs = self.torch_tensor_kwargs.copy()
+            if np.issubdtype(value.dtype, np.integer):
+                # Use torch.as_tensor for scalar conversion with dtype control
+                kwargs.setdefault("dtype", torch.int64)
+                return torch.as_tensor(value, **kwargs)
+            elif np.issubdtype(value.dtype, np.floating):
+                kwargs.setdefault("dtype", torch.float32)
+                return torch.as_tensor(value, **kwargs)
+            else:
+                return torch.as_tensor(value, **kwargs)
+
+        # Handle Python lists/tuples of numbers efficiently
+        elif isinstance(value, (list, tuple)):
+            # Try to convert to numpy first for faster tensor creation
+            try:
+                arr = np.array(value)
+                if arr.dtype.kind in "iuf":  # integer, unsigned, float
+                    return self._tensorize(arr)  # Recursive call to handle numpy path
+            except (ValueError, TypeError):
+                pass  # Fall back to torch.tensor
+
+        # Default fallback with dtype defaults
+        default_dtype = {}
+        if isinstance(value, (int, float)):
+            if isinstance(value, int):
+                default_dtype = {"dtype": torch.int64}
+            else:
+                default_dtype = {"dtype": torch.float32}
 
         return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
 
     def _recursive_tensorize(self, data_struct):
-        import torch
-
-        # support for torch, tf, jax etc.
+        """Optimized recursive walker with reduced Python overhead."""
+        # Handle tensor-like objects with __array__ interface
         if hasattr(data_struct, "__array__") and not isinstance(data_struct, torch.Tensor):
             data_struct = data_struct.__array__()
-        # support for nested types like struct of list of struct
+
+        # Handle object arrays (nested structures)
         if isinstance(data_struct, np.ndarray):
-            if data_struct.dtype == object:  # torch tensors cannot be instantied from an array of objects
-                return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct])
+            if data_struct.dtype == object:
+                # Use list comprehension instead of map_nested
+                result = [self._recursive_tensorize(item) for item in data_struct]
+                return self._consolidate(result)
+        # Handle lists and tuples
         elif isinstance(data_struct, (list, tuple)):
-            return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct])
+            result = [self._recursive_tensorize(item) for item in data_struct]
+            return self._consolidate(result)
+        # Handle dictionaries
+        elif isinstance(data_struct, dict):
+            return {key: self._recursive_tensorize(value) for key, value in data_struct.items()}
+
+        # Base case: tensorize the leaf value
         return self._tensorize(data_struct)
 
     def recursive_tensorize(self, data_struct: dict):
-        return map_nested(self._recursive_tensorize, data_struct, map_list=False)
+        """Public interface maintaining compatibility."""
+        return self._recursive_tensorize(data_struct)
 
     def format_row(self, pa_table: pa.Table) -> Mapping:
         row = self.numpy_arrow_extractor().extract_row(pa_table)
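
For context, this is roughly how the formatter is exercised in practice; the toy dataset below is illustrative, and the dtype defaults shown (int64 for integers, float32 for floats) are the ones set in _tensorize above.

    import torch
    from datasets import Dataset

    ds = Dataset.from_dict({"x": [[1, 2], [3, 4]], "y": [0.5, 1.5]})
    ds = ds.with_format("torch")  # routes row/column formatting through TorchFormatter

    row = ds[0]
    print(row["x"].dtype)  # torch.int64, the integer default
    print(row["y"].dtype)  # torch.float32, the float default

    col = ds["x"]          # equal-shape tensors are stacked by _consolidate
    print(col.shape)       # torch.Size([2, 2])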
