Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class IndexedTableMixin:
def __init__(self, table: pa.Table):
self._schema = table.schema
self._batches = table.to_batches()
self._offsets: np.ndarray = np.cumsum([0] + [len(b) for b in self._batches])
self._offsets: List[int] = np.cumsum([0] + [len(b) for b in self._batches]).tolist()

def fast_gather(self, indices: Union[List[int], np.ndarray]) -> pa.Table:
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def assert_pickle_does_bring_data_in_memory(table: MemoryMappedTable):

def assert_index_attributes_equal(table: Table, other: Table):
assert table._batches == other._batches
np.testing.assert_array_equal(table._offsets, other._offsets)
assert table._offsets == other._offsets
assert table._schema == other._schema


Expand Down Expand Up @@ -901,8 +901,8 @@ def getitem_unique_count(self):

@pytest.mark.parametrize(
"arr, x",
[(np.arange(0, 14, 3), x) for x in range(-1, 22)]
+ [(list(np.arange(-5, 5)), x) for x in range(-6, 6)]
[(list(range(0, 14, 3)), x) for x in range(-1, 22)]
+ [(list(range(-5, 5)), x) for x in range(-6, 6)]
+ [([0, 1_000, 1_001, 1_003], x) for x in [-1, 0, 2, 100, 999, 1_000, 1_001, 1_002, 1_003, 1_004]]
+ [(list(range(1_000)), x) for x in [-1, 0, 1, 10, 666, 999, 1_000, 1_0001]],
)
Expand All @@ -926,6 +926,6 @@ def test_indexed_table_mixin():
pa_table = pa.Table.from_pydict({"col": [0] * n_rows_per_chunk})
pa_table = pa.concat_tables([pa_table] * n_chunks)
table = Table(pa_table)
assert all(table._offsets.tolist() == np.cumsum([0] + [n_rows_per_chunk] * n_chunks))
assert all(table._offsets == np.cumsum([0] + [n_rows_per_chunk] * n_chunks))
assert table.fast_slice(5) == pa_table.slice(5)
assert table.fast_slice(2, 13) == pa_table.slice(2, 13)