21 changes: 18 additions & 3 deletions src/datasets/arrow_dataset.py
@@ -775,6 +775,7 @@ def class_encode_column(self, column: str) -> "Dataset":
         class_names = sorted(dset.unique(column))
         dst_feat = ClassLabel(names=class_names)
         dset = dset.map(lambda batch: {column: dst_feat.str2int(batch)}, input_columns=column, batched=True)
+        dset = concatenate_datasets([self.remove_columns([column]), dset], axis=1)

         new_features = copy.deepcopy(dset.features)
         new_features[column] = dst_feat
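
For orientation (not part of the diff): the added line rebuilds the dataset by concatenating the untouched columns with the re-encoded column along the column axis, so class_encode_column keeps every other column intact. A minimal usage sketch, with illustrative values:

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b"], "label": ["neg", "pos"]})
ds = ds.class_encode_column("label")
print(ds.features["label"])  # a ClassLabel with names=['neg', 'pos']
print(ds["label"])           # [0, 1] -- the strings mapped to class ids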
@@ -2877,10 +2878,12 @@ def add_elasticsearch_index(
         )
         return self

-    def add_item(self, item: dict):
+    @transmit_format
+    @fingerprint_transform(inplace=False)
+    def add_item(self, item: dict, new_fingerprint: str):
         """Add item to Dataset.

-        .. versionadded:: 1.6
+        .. versionadded:: 1.7

         Args:
             item (dict): Item data to be added.
@@ -2894,7 +2897,19 @@ def add_item(self, item: dict):
         item_table = item_table.cast(schema)
         # Concatenate tables
         table = concat_tables([self._data, item_table])
-        return Dataset(table)
+        if self._indices is None:
+            indices_table = None
+        else:
+            item_indices_array = pa.array([len(self._data)], type=pa.uint64())
+            item_indices_table = InMemoryTable.from_arrays([item_indices_array], names=["indices"])
+            indices_table = concat_tables([self._indices, item_indices_table])
+        return Dataset(
+            table,
+            info=copy.deepcopy(self.info),
+            split=self.split,
+            indices_table=indices_table,
+            fingerprint=new_fingerprint,
+        )


 def concatenate_datasets(
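Before the test changes, a brief sketch (not from the PR) of the user-facing behavior of the reworked add_item: the appended row comes back as the last item, and the returned dataset keeps the original info and split while receiving a new fingerprint.

from datasets import Dataset

ds = Dataset.from_dict({"col_1": ["0", "1", "2"], "col_2": [0, 1, 2]})
ds = ds.add_item({"col_1": "3", "col_2": 3})
print(len(ds))  # 4
print(ds[-1])   # {'col_1': '3', 'col_2': 3}
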
32 changes: 23 additions & 9 deletions tests/test_arrow_dataset.py
@@ -1949,6 +1949,10 @@ def test_concatenate_datasets_duplicate_columns(dataset):
     assert "duplicated" in str(excinfo.value)


+@pytest.mark.parametrize(
+    "transform",
+    [None, ("shuffle", (42,), {}), ("with_format", ("pandas",), {}), ("class_encode_column", ("col_2",), {})],
+)
 @pytest.mark.parametrize("in_memory", [False, True])
 @pytest.mark.parametrize(
     "item",
@@ -1959,22 +1963,32 @@ def test_concatenate_datasets_duplicate_columns(dataset):
         {"col_1": 4.0, "col_2": 4.0, "col_3": 4.0},
     ],
 )
-def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path):
-    dataset = (
+def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path, transform):
+    dataset_to_test = (
         Dataset(InMemoryTable.from_pydict(dataset_dict))
         if in_memory
         else Dataset(MemoryMappedTable.from_file(arrow_path))
     )
-    dataset = dataset.add_item(item)
+    if transform is not None:
+        transform_name, args, kwargs = transform
+        dataset_to_test: Dataset = getattr(dataset_to_test, transform_name)(*args, **kwargs)
+    dataset = dataset_to_test.add_item(item)
     assert dataset.data.shape == (5, 3)
-    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
-    assert dataset.data.column_names == list(expected_features.keys())
+    expected_features = dataset_to_test.features
+    assert sorted(dataset.data.column_names) == sorted(expected_features.keys())
     for feature, expected_dtype in expected_features.items():
-        assert dataset.features[feature].dtype == expected_dtype
-    assert len(dataset.data.blocks) == 1 if in_memory else 2  # multiple InMemoryTables are consolidated as one
-    dataset = dataset.add_item(item)
-    assert dataset.data.shape == (6, 3)
+        assert dataset.features[feature] == expected_dtype
+    assert len(dataset.data.blocks) == 1 if in_memory else 2  # multiple InMemoryTables are consolidated as one
+    assert dataset.format["type"] == dataset_to_test.format["type"]
+    assert dataset._fingerprint != dataset_to_test._fingerprint
+    dataset.reset_format()
+    dataset_to_test.reset_format()
+    assert dataset[:-1] == dataset_to_test[:]
+    assert {k: int(v) for k, v in dataset[-1].items()} == {k: int(v) for k, v in item.items()}

Reviewer comment: I think it would be better to introduce an explicit test for the _indices.

For example, this test passes even if, in add_item, I wrongly set:

new_indices_array = pa.array([9], type=pa.uint64())

if dataset._indices is not None:
    dataset_indices = dataset._indices["indices"].to_pylist()
    dataset_to_test_indices = dataset_to_test._indices["indices"].to_pylist()
    assert dataset_indices == dataset_to_test_indices + [len(dataset_to_test._data)]
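
A self-contained sketch of why this mapping-level assertion catches what the row-level assertions can miss (the dataset values are illustrative; shuffle is used to create an indices mapping):

from datasets import Dataset

ds = Dataset.from_dict({"col_1": ["0", "1", "2", "3"], "col_2": [0, 1, 2, 3]}).shuffle(seed=42)
before = ds._indices["indices"].to_pylist()
ds2 = ds.add_item({"col_1": "4", "col_2": 4})
after = ds2._indices["indices"].to_pylist()
# The appended row's physical position in the data table (len(ds._data) == 4)
# must be appended to the mapping; a hardcoded pa.array([9]) would fail here.
assert after == before + [len(ds._data)]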


@pytest.mark.parametrize("keep_in_memory", [False, True])