Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,6 +2368,10 @@ def add_elasticsearch_index(
)
return self

def add_item(self, item):
    """Append a single example to the dataset's underlying Arrow table.

    Args:
        item: Mapping of column name to the value for the new row. Each
            value is wrapped in a one-element list so that the mapping
            becomes a one-row ``pa.Table``.

    The result is stored in ``self._data``; nothing is returned.
    """
    # Build a one-row table: every column holds exactly the given value.
    row_table = pa.Table.from_pydict({name: [value] for name, value in item.items()})

    if self._data.shape == (0, 0):
        # The current table has no rows and no columns, so the new row
        # simply becomes the table instead of being concatenated.
        self._data = row_table
    else:
        self._data = pa.concat_tables([self._data, row_table])


def concatenate_datasets(
dsets: List[Dataset],
Expand Down
8 changes: 8 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,3 +1666,11 @@ def test_dataset_from_file(in_memory, dataset, arrow_file):
assert dataset_from_file.features == dataset.features
assert dataset_from_file.cache_files == ([{"filename": filename}] if not in_memory else [])
assert increased_allocated_memory == in_memory


def test_dataset_add_item():
    """Adding one item to an empty dataset yields a 1x1 table with that column."""
    new_row = {"input_ids": np.array([4, 4, 2])}
    empty_table = pa.Table.from_pydict({})
    ds = Dataset(empty_table)

    ds.add_item(new_row)

    # One row and one column, named after the only key of the added item.
    expected_shape = (1, 1)
    expected_columns = ["input_ids"]
    assert ds.data.shape == expected_shape
    assert ds.data.column_names == expected_columns