Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2961,7 +2961,7 @@ def init_buffer_and_writer():
else:
writer.write(example)
else:
for i, batch in enumerate(pbar):
for i, batch in zip(range(0, num_rows, batch_size), pbar):
indices = list(
range(*(slice(i, i + batch_size).indices(input_dataset.num_rows)))
) # Something simpler?
Expand Down
6 changes: 6 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3081,6 +3081,12 @@ def test_dataset_add_item_introduce_feature_type():
assert dataset[:] == {"col_1": [None, None, None, "a"]}


def test_dataset_filter_indices():
ds = Dataset.from_dict({"num": [0, 1, 2, 3]})
ds = ds.filter(lambda num: num % 2 == 0, input_columns="num", batch_size=2)
assert all(item["num"] % 2 == 0 for item in ds)


@pytest.mark.parametrize("in_memory", [False, True])
def test_dataset_from_file(in_memory, dataset, arrow_file):
filename = arrow_file
Expand Down