@@ -2929,14 +2929,15 @@ def init_buffer_and_writer():
29292929 # Loop over single examples or batches and write to buffer/file if examples are to be updated
29302930 if not batched :
29312931 pbar_total = len (input_dataset )
2932- pbar_iterable = input_dataset ._iter (decoded = False )
2932+ pbar_iterable = enumerate ( input_dataset ._iter (decoded = False ) )
29332933 else :
29342934 num_rows = (
29352935 len (input_dataset ) if not drop_last_batch else len (input_dataset ) // batch_size * batch_size
29362936 )
29372937 pbar_total = (num_rows // batch_size ) + 1 if num_rows % batch_size else num_rows // batch_size
2938- pbar_iterable = input_dataset ._iter_batches (
2939- batch_size , decoded = False , drop_last_batch = drop_last_batch
2938+ pbar_iterable = zip (
2939+ range (0 , num_rows , batch_size ),
2940+ input_dataset ._iter_batches (batch_size , decoded = False , drop_last_batch = drop_last_batch ),
29402941 )
29412942 pbar_unit = "ex" if not batched else "ba"
29422943 pbar_desc = (desc + " " if desc is not None else "" ) + "#" + str (rank ) if rank is not None else desc
@@ -2949,7 +2950,7 @@ def init_buffer_and_writer():
29492950 desc = pbar_desc ,
29502951 )
29512952 if not batched :
2952- for i , example in enumerate ( pbar ) :
2953+ for i , example in pbar :
29532954 example = apply_function_on_filtered_inputs (example , i , offset = offset )
29542955 if update_data :
29552956 if i == 0 :
@@ -2960,7 +2961,7 @@ def init_buffer_and_writer():
29602961 else :
29612962 writer .write (example )
29622963 else :
2963- for i , batch in zip ( range ( 0 , num_rows , batch_size ), pbar ) :
2964+ for i , batch in pbar :
29642965 indices = list (
29652966 range (* (slice (i , i + batch_size ).indices (input_dataset .num_rows )))
29662967 ) # Something simpler?
0 commit comments