diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 909f666d4dc..bc7ab6c9492 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -2929,14 +2929,15 @@ def init_buffer_and_writer(): # Loop over single examples or batches and write to buffer/file if examples are to be updated if not batched: pbar_total = len(input_dataset) - pbar_iterable = input_dataset._iter(decoded=False) + pbar_iterable = enumerate(input_dataset._iter(decoded=False)) else: num_rows = ( len(input_dataset) if not drop_last_batch else len(input_dataset) // batch_size * batch_size ) pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size - pbar_iterable = input_dataset._iter_batches( - batch_size, decoded=False, drop_last_batch=drop_last_batch + pbar_iterable = zip( + range(0, num_rows, batch_size), + input_dataset._iter_batches(batch_size, decoded=False, drop_last_batch=drop_last_batch), ) pbar_unit = "ex" if not batched else "ba" pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc @@ -2949,7 +2950,7 @@ def init_buffer_and_writer(): desc=pbar_desc, ) if not batched: - for i, example in enumerate(pbar): + for i, example in pbar: example = apply_function_on_filtered_inputs(example, i, offset=offset) if update_data: if i == 0: @@ -2960,7 +2961,7 @@ def init_buffer_and_writer(): else: writer.write(example) else: - for i, batch in zip(range(0, num_rows, batch_size), pbar): + for i, batch in pbar: indices = list( range(*(slice(i, i + batch_size).indices(input_dataset.num_rows))) ) # Something simpler?