Skip to content

Commit 28bcb2f

Browse files
authored
Fix tqdm zip bug (#5120)
* Fix tqdm zip bug * Change pbar_iterable directly, to account for non batch cases * Add enumerate to pbar iterable (non batch mode)
1 parent 85cd129 commit 28bcb2f

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/datasets/arrow_dataset.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2929,14 +2929,15 @@ def init_buffer_and_writer():
29292929
# Loop over single examples or batches and write to buffer/file if examples are to be updated
29302930
if not batched:
29312931
pbar_total = len(input_dataset)
2932-
pbar_iterable = input_dataset._iter(decoded=False)
2932+
pbar_iterable = enumerate(input_dataset._iter(decoded=False))
29332933
else:
29342934
num_rows = (
29352935
len(input_dataset) if not drop_last_batch else len(input_dataset) // batch_size * batch_size
29362936
)
29372937
pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
2938-
pbar_iterable = input_dataset._iter_batches(
2939-
batch_size, decoded=False, drop_last_batch=drop_last_batch
2938+
pbar_iterable = zip(
2939+
range(0, num_rows, batch_size),
2940+
input_dataset._iter_batches(batch_size, decoded=False, drop_last_batch=drop_last_batch),
29402941
)
29412942
pbar_unit = "ex" if not batched else "ba"
29422943
pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
@@ -2949,7 +2950,7 @@ def init_buffer_and_writer():
29492950
desc=pbar_desc,
29502951
)
29512952
if not batched:
2952-
for i, example in enumerate(pbar):
2953+
for i, example in pbar:
29532954
example = apply_function_on_filtered_inputs(example, i, offset=offset)
29542955
if update_data:
29552956
if i == 0:
@@ -2960,7 +2961,7 @@ def init_buffer_and_writer():
29602961
else:
29612962
writer.write(example)
29622963
else:
2963-
for i, batch in zip(range(0, num_rows, batch_size), pbar):
2964+
for i, batch in pbar:
29642965
indices = list(
29652966
range(*(slice(i, i + batch_size).indices(input_dataset.num_rows)))
29662967
) # Something simpler?

0 commit comments

Comments
 (0)