Commit d60f5ff

Fix filter indices when batched (#5113)
* Test filter indices
* Fix filter indices when batched
* Rename test
1 parent 3ad9644 commit d60f5ff

2 files changed: +7 -1 lines changed

src/datasets/arrow_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -2961,7 +2961,7 @@ def init_buffer_and_writer():
                         else:
                             writer.write(example)
             else:
-                for i, batch in enumerate(pbar):
+                for i, batch in zip(range(0, num_rows, batch_size), pbar):
                     indices = list(
                         range(*(slice(i, i + batch_size).indices(input_dataset.num_rows)))
                     )  # Something simpler?
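
Why the one-line change matters: with enumerate(pbar), i counts batches (0, 1, 2, ...), so slice(i, i + batch_size) points at the wrong rows for every batch after the first; zipping the progress bar with range(0, num_rows, batch_size) makes i the absolute row offset of each batch. Below is a minimal sketch of that index math only, not the library code; num_rows, batch_size, and the dummy batches list are stand-ins chosen for illustration.

# Sketch of how the per-batch row indices are computed before and after the fix.
num_rows, batch_size = 4, 2
batches = [object(), object()]  # stand-in for the batches yielded by the progress bar

# Before the fix: i is the batch counter, so the slices overlap and drift.
old = [list(range(*slice(i, i + batch_size).indices(num_rows)))
       for i, _ in enumerate(batches)]
assert old == [[0, 1], [1, 2]]  # the second batch resolves to the wrong rows

# After the fix: i is the row offset of each batch.
new = [list(range(*slice(i, i + batch_size).indices(num_rows)))
       for i, _ in zip(range(0, num_rows, batch_size), batches)]
assert new == [[0, 1], [2, 3]]  # correct, non-overlapping row indices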

tests/test_arrow_dataset.py

Lines changed: 6 additions & 0 deletions
@@ -3081,6 +3081,12 @@ def test_dataset_add_item_introduce_feature_type():
     assert dataset[:] == {"col_1": [None, None, None, "a"]}


+def test_dataset_filter_batched_indices():
+    ds = Dataset.from_dict({"num": [0, 1, 2, 3]})
+    ds = ds.filter(lambda num: num % 2 == 0, input_columns="num", batch_size=2)
+    assert all(item["num"] % 2 == 0 for item in ds)
+
+
 @pytest.mark.parametrize("in_memory", [False, True])
 def test_dataset_from_file(in_memory, dataset, arrow_file):
     filename = arrow_file
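
For a quick manual check of what the new test asserts, the same filter call can be run on its own (a minimal repro sketch using only the API already shown in the test; the expected output assumes the even-number predicate above):

from datasets import Dataset

ds = Dataset.from_dict({"num": [0, 1, 2, 3]})
ds = ds.filter(lambda num: num % 2 == 0, input_columns="num", batch_size=2)
print(ds["num"])  # with the fix, only the even rows survive: [0, 2]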

0 commit comments
