docs/source/package_reference/main_classes.rst (3 additions, 1 deletion)

@@ -14,7 +14,9 @@ Main classes
 The base class :class:`datasets.Dataset` implements a Dataset backed by an Apache Arrow table.

 .. autoclass:: datasets.Dataset
-    :members: from_file, from_buffer, from_pandas, from_dict,
+    :members:
+        add_item,
+        from_file, from_buffer, from_pandas, from_dict,
         data, cache_files, num_columns, num_rows, column_names, shape,
         unique,
         flatten_, cast_, remove_columns_, rename_column_,
src/datasets/arrow_dataset.py (20 additions, 0 deletions)

@@ -2848,6 +2848,26 @@ def add_elasticsearch_index(
         )
         return self
+
+    def add_item(self, item: dict):
+        """Add a single item (row) to the dataset.
+
+        .. versionadded:: 1.6
+
+        Args:
+            item (dict): Item data to be added.
+
+        Returns:
+            :class:`Dataset`
+        """
+        item_table = InMemoryTable.from_pydict({k: [v] for k, v in item.items()})
+        # Cast the item to the dataset's schema
+        features_type = self.features.type
+        schema = pa.schema({col_name: features_type[col_name].type for col_name in self._data.column_names})
+        item_table = item_table.cast(schema)
+        # Concatenate the item table with the dataset's underlying table
+        table = concat_tables([self._data, item_table])
+        return Dataset(table)


 def concatenate_datasets(
     dsets: List[Dataset],
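A minimal usage sketch of the new `add_item` API (column names and values here are illustrative, assuming a build of `datasets` with this patch applied):

from datasets import Dataset

# Build a small in-memory dataset, then append one row with the new API.
ds = Dataset.from_dict({"col_1": ["1", "2", "3"], "col_2": [1, 2, 3]})
ds = ds.add_item({"col_1": "4", "col_2": 4})  # returns a new Dataset

assert ds.num_rows == 4
assert ds[3] == {"col_1": "4", "col_2": 4}

Note that the item's values are cast to the dataset's existing feature types, so the new row must be castable to the current schema.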
tests/test_arrow_dataset.py (28 additions, 0 deletions)

@@ -1948,6 +1948,34 @@ def test_concatenate_datasets_duplicate_columns(dataset):
assert "duplicated" in str(excinfo.value)


@pytest.mark.parametrize("in_memory", [False, True])
@pytest.mark.parametrize(
"item",
[
{"col_1": "4", "col_2": 4, "col_3": 4.0},
{"col_1": "4", "col_2": "4", "col_3": "4"},
{"col_1": 4, "col_2": 4, "col_3": 4},
{"col_1": 4.0, "col_2": 4.0, "col_3": 4.0},
],
)
def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path):
dataset = (
Dataset(InMemoryTable.from_pydict(dataset_dict))
if in_memory
else Dataset(MemoryMappedTable.from_file(arrow_path))
)
dataset = dataset.add_item(item)
assert dataset.data.shape == (5, 3)
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
assert dataset.data.column_names == list(expected_features.keys())
for feature, expected_dtype in expected_features.items():
assert dataset.features[feature].dtype == expected_dtype
assert len(dataset.data.blocks) == 1 if in_memory else 2 # multiple InMemoryTables are consolidated as one
dataset = dataset.add_item(item)
assert dataset.data.shape == (6, 3)
assert len(dataset.data.blocks) == 1 if in_memory else 2 # multiple InMemoryTables are consolidated as one


@pytest.mark.parametrize("keep_in_memory", [False, True])
@pytest.mark.parametrize(
"features",
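The cast step in `add_item` is plain `pyarrow` schema casting, which is what lets items like `{"col_1": 4, ...}` in the tests above land in a string column. A standalone sketch of that mechanism (illustrative column names, no `datasets` code involved):

import pyarrow as pa

# A one-row table built from a Python item, before the cast:
# pyarrow infers int64 for col_1 and double for col_3.
item_table = pa.Table.from_pydict({"col_1": [4], "col_3": [4.0]})

# Target schema of the existing dataset: col_1 is string, col_3 is float64.
schema = pa.schema({"col_1": pa.string(), "col_3": pa.float64()})

cast_table = item_table.cast(schema)
assert cast_table.schema.field("col_1").type == pa.string()
assert cast_table.to_pydict() == {"col_1": ["4"], "col_3": [4.0]}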