Skip to content

Commit 1ea4d09

Browse files
mariosaskolhoestq
andauthored
Fast dataset iter (#5030)
* Fast dataset iter * Final improvements + some minor fixes * Update src/datasets/arrow_dataset.py Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> * Address comments Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
1 parent 7ee558f commit 1ea4d09

3 files changed

Lines changed: 84 additions & 12 deletions

File tree

src/datasets/arrow_dataset.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,17 +1854,64 @@ def __len__(self):
18541854
"""
18551855
return self.num_rows
18561856

1857+
def _iter_batches(self, batch_size: int, decoded: bool = True):
1858+
"""Iterate through the batches of size `batch_size`.
1859+
1860+
If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
1861+
selected format.
1862+
"""
1863+
if self._indices is None and config.PYARROW_VERSION.major >= 8:
1864+
# Fast iteration
1865+
# Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
1866+
format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
1867+
formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
1868+
for batch in self.data.to_reader(max_chunksize=batch_size):
1869+
pa_subtable = pa.Table.from_batches([batch])
1870+
formatted_output = format_table(
1871+
pa_subtable,
1872+
range(pa_subtable.num_rows),
1873+
formatter=formatter,
1874+
format_columns=self._format_columns,
1875+
output_all_columns=self._output_all_columns,
1876+
)
1877+
yield formatted_output
1878+
else:
1879+
for i in range(0, self.num_rows, batch_size):
1880+
yield self._getitem(
1881+
slice(i, i + batch_size),
1882+
decoded=decoded,
1883+
)
1884+
18571885
def _iter(self, decoded: bool = True):
18581886
"""Iterate through the examples.
18591887
18601888
If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
18611889
selected format.
18621890
"""
1863-
for index in range(self.num_rows):
1864-
yield self._getitem(
1865-
index,
1866-
decoded=decoded,
1867-
)
1891+
if self._indices is None and config.PYARROW_VERSION.major >= 8:
1892+
# Fast iteration
1893+
# Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
1894+
format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
1895+
formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)
1896+
batch_size = config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER
1897+
for batch in self.data.to_reader(max_chunksize=batch_size):
1898+
for i in range(batch.num_rows):
1899+
batch_ex = batch.slice(i, 1)
1900+
pa_subtable = pa.Table.from_batches([batch_ex])
1901+
formatted_output = format_table(
1902+
pa_subtable,
1903+
0,
1904+
formatter=formatter,
1905+
format_columns=self._format_columns,
1906+
output_all_columns=self._output_all_columns,
1907+
)
1908+
yield formatted_output
1909+
else:
1910+
for i in range(self.num_rows):
1911+
yield self._getitem(
1912+
i,
1913+
decoded=decoded,
1914+
)
18681915

18691916
def __iter__(self):
18701917
"""Iterate through the examples.
@@ -2805,14 +2852,16 @@ def init_buffer_and_writer():
28052852

28062853
# Loop over single examples or batches and write to buffer/file if examples are to be updated
28072854
if not batched:
2808-
pbar_iterable = input_dataset._iter(decoded=False)
28092855
pbar_total = len(input_dataset)
2856+
pbar_iterable = input_dataset._iter(decoded=False)
28102857
else:
28112858
num_rows = (
28122859
len(input_dataset) if not drop_last_batch else len(input_dataset) // batch_size * batch_size
28132860
)
2814-
pbar_iterable = range(0, num_rows, batch_size)
28152861
pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
2862+
pbar_iterable = itertools.islice(
2863+
input_dataset._iter_batches(batch_size, decoded=False), pbar_total
2864+
)
28162865
pbar_unit = "ex" if not batched else "ba"
28172866
pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
28182867
pbar = logging.tqdm(
@@ -2835,11 +2884,7 @@ def init_buffer_and_writer():
28352884
else:
28362885
writer.write(example)
28372886
else:
2838-
for i in pbar:
2839-
batch = input_dataset._getitem(
2840-
slice(i, i + batch_size),
2841-
decoded=False,
2842-
)
2887+
for i, batch in enumerate(pbar):
28432888
indices = list(
28442889
range(*(slice(i, i + batch_size).indices(input_dataset.num_rows)))
28452890
) # Something simpler?

src/datasets/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@
168168
# https://github.com/apache/arrow/blob/master/docs/source/cpp/arrays.rst#size-limitations-and-recommendations)
169169
DEFAULT_MAX_BATCH_SIZE = 10_000
170170

171+
# Size of the preloaded record batch in `Dataset.__iter__`
172+
ARROW_READER_BATCH_SIZE_IN_DATASET_ITER = 10
173+
171174
# Pickling tables works only for small tables (<4GiB)
172175
# For big tables, we write them on disk instead
173176
MAX_TABLE_NBYTES_FOR_PICKLING = 4 << 30

src/datasets/table.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,30 @@ def to_pandas(self, *args, **kwargs):
330330
def to_string(self, *args, **kwargs):
331331
return self.table.to_string(*args, **kwargs)
332332

333+
def to_reader(self, *args, **kwargs):
334+
"""
335+
Convert the Table to a RecordBatchReader.
336+
337+
Note that this method is zero-copy, it merely exposes the same data under a different API.
338+
339+
Args:
340+
max_chunksize (:obj:`int`, defaults to :obj:`None`)
341+
Maximum size for RecordBatch chunks. Individual chunks may be smaller depending
342+
on the chunk layout of individual columns.
343+
344+
Returns:
345+
:obj:`pyarrow.RecordBatchReader`
346+
347+
<Tip warning={true}>
348+
349+
pyarrow >= 8.0.0 needs to be installed to use this method.
350+
351+
</Tip>
352+
"""
353+
if config.PYARROW_VERSION.major < 8:
354+
raise NotImplementedError("`pyarrow>=8.0.0` is required to use this method")
355+
return self.table.to_reader(*args, **kwargs)
356+
333357
def field(self, *args, **kwargs):
334358
"""
335359
Select a schema field by its column name or numeric index.

0 commit comments

Comments
 (0)