@@ -1854,17 +1854,64 @@ def __len__(self):
18541854 """
18551855 return self .num_rows
18561856
def _iter_batches(self, batch_size: int, decoded: bool = True):
    """Iterate through the dataset in batches of `batch_size` rows.

    If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
    selected format.

    Every yielded batch holds exactly `batch_size` rows, except possibly the last one, which
    holds the remaining rows. This is true for both the fast path and the fallback path, so
    the number of batches is always ``ceil(num_rows / batch_size)``.

    Args:
        batch_size (:obj:`int`): Number of rows per batch. Must be strictly positive.
        decoded (:obj:`bool`, defaults to :obj:`True`): Forwarded to the formatter (fast path)
            or to ``_getitem`` (fallback path).

    Yields:
        One formatted batch at a time, in row order.

    Raises:
        ValueError: If `batch_size` is not strictly positive. As this is a generator, the
            error is raised when the first batch is requested.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a strictly positive integer, got {batch_size}")

    if self._indices is None and config.PYARROW_VERSION.major >= 8:
        # Fast iteration
        # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
        format_kwargs = self._format_kwargs if self._format_kwargs is not None else {}
        formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **format_kwargs)

        def _format(pa_subtable):
            return format_table(
                pa_subtable,
                range(pa_subtable.num_rows),
                formatter=formatter,
                format_columns=self._format_columns,
                output_all_columns=self._output_all_columns,
            )

        # `max_chunksize` is only an upper bound: record batches are also cut at the
        # table's own chunk boundaries. Buffer and re-slice them so that each yielded
        # batch has exactly `batch_size` rows (the last one may be shorter).
        buffered = []
        buffered_rows = 0
        for record_batch in self.data.to_reader(max_chunksize=batch_size):
            offset = 0
            # Emit as many full batches as the buffer plus this record batch can provide.
            while record_batch.num_rows - offset >= batch_size - buffered_rows:
                take = batch_size - buffered_rows
                buffered.append(record_batch.slice(offset, take))
                yield _format(pa.Table.from_batches(buffered))
                buffered = []
                buffered_rows = 0
                offset += take
            # Keep the leftover rows for the next batch.
            leftover = record_batch.num_rows - offset
            if leftover:
                buffered.append(record_batch.slice(offset, leftover))
                buffered_rows += leftover
        if buffered:
            # Final, possibly shorter, batch.
            yield _format(pa.Table.from_batches(buffered))
    else:
        for i in range(0, self.num_rows, batch_size):
            yield self._getitem(
                slice(i, i + batch_size),
                decoded=decoded,
            )
def _iter(self, decoded: bool = True):
    """Yield the rows of the dataset one at a time, in order.

    If a formatting is set with :meth:`Dataset.set_format` rows will be returned with the
    selected format.

    Args:
        decoded (:obj:`bool`, defaults to :obj:`True`): Forwarded to the formatter (fast path)
            or to ``_getitem`` (fallback path).

    Yields:
        One formatted row at a time.
    """
    # Fallback: row-by-row `_getitem` when `self._indices` is set or PyArrow is older than 8.
    if self._indices is not None or config.PYARROW_VERSION.major < 8:
        for row_idx in range(self.num_rows):
            yield self._getitem(row_idx, decoded=decoded)
        return

    # Fast iteration
    # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
    formatter_kwargs = {} if self._format_kwargs is None else self._format_kwargs
    formatter = get_formatter(self._format_type, features=self.features, decoded=decoded, **formatter_kwargs)
    reader = self.data.to_reader(max_chunksize=config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER)
    for record_batch in reader:
        for offset in range(record_batch.num_rows):
            # Wrap a single-row slice in a table so it can be formatted as one example (key 0).
            single_row = pa.Table.from_batches([record_batch.slice(offset, 1)])
            yield format_table(
                single_row,
                0,
                formatter=formatter,
                format_columns=self._format_columns,
                output_all_columns=self._output_all_columns,
            )
18681915
18691916 def __iter__ (self ):
18701917 """Iterate through the examples.
@@ -2805,14 +2852,16 @@ def init_buffer_and_writer():
28052852
28062853 # Loop over single examples or batches and write to buffer/file if examples are to be updated
28072854 if not batched :
2808- pbar_iterable = input_dataset ._iter (decoded = False )
28092855 pbar_total = len (input_dataset )
2856+ pbar_iterable = input_dataset ._iter (decoded = False )
28102857 else :
28112858 num_rows = (
28122859 len (input_dataset ) if not drop_last_batch else len (input_dataset ) // batch_size * batch_size
28132860 )
2814- pbar_iterable = range (0 , num_rows , batch_size )
28152861 pbar_total = (num_rows // batch_size ) + 1 if num_rows % batch_size else num_rows // batch_size
2862+ pbar_iterable = itertools .islice (
2863+ input_dataset ._iter_batches (batch_size , decoded = False ), pbar_total
2864+ )
28162865 pbar_unit = "ex" if not batched else "ba"
28172866 pbar_desc = (desc + " " if desc is not None else "" ) + "#" + str (rank ) if rank is not None else desc
28182867 pbar = logging .tqdm (
@@ -2835,11 +2884,7 @@ def init_buffer_and_writer():
28352884 else :
28362885 writer .write (example )
28372886 else :
2838- for i in pbar :
2839- batch = input_dataset ._getitem (
2840- slice (i , i + batch_size ),
2841- decoded = False ,
2842- )
2887+ for i , batch in enumerate (pbar ):
28432888 indices = list (
28442889 range (* (slice (i , i + batch_size ).indices (input_dataset .num_rows )))
28452890 ) # Something simpler?
0 commit comments