Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2929,14 +2929,15 @@ def init_buffer_and_writer():
# Loop over single examples or batches and write to buffer/file if examples are to be updated
if not batched:
pbar_total = len(input_dataset)
pbar_iterable = input_dataset._iter(decoded=False)
pbar_iterable = enumerate(input_dataset._iter(decoded=False))
else:
num_rows = (
len(input_dataset) if not drop_last_batch else len(input_dataset) // batch_size * batch_size
)
pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
pbar_iterable = input_dataset._iter_batches(
batch_size, decoded=False, drop_last_batch=drop_last_batch
pbar_iterable = zip(
range(0, num_rows, batch_size),
input_dataset._iter_batches(batch_size, decoded=False, drop_last_batch=drop_last_batch),
)
pbar_unit = "ex" if not batched else "ba"
pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
Expand All @@ -2949,7 +2950,7 @@ def init_buffer_and_writer():
desc=pbar_desc,
)
if not batched:
for i, example in enumerate(pbar):
for i, example in pbar:
example = apply_function_on_filtered_inputs(example, i, offset=offset)
if update_data:
if i == 0:
Expand All @@ -2960,7 +2961,7 @@ def init_buffer_and_writer():
else:
writer.write(example)
else:
for i, batch in zip(range(0, num_rows, batch_size), pbar):
for i, batch in pbar:
indices = list(
range(*(slice(i, i + batch_size).indices(input_dataset.num_rows)))
) # Something simpler?
Expand Down