diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index cef58f560ed..0f0cb7ae3c9 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -775,6 +775,7 @@ def class_encode_column(self, column: str) -> "Dataset":
         class_names = sorted(dset.unique(column))
         dst_feat = ClassLabel(names=class_names)
         dset = dset.map(lambda batch: {column: dst_feat.str2int(batch)}, input_columns=column, batched=True)
+        dset = concatenate_datasets([self.remove_columns([column]), dset], axis=1)
 
         new_features = copy.deepcopy(dset.features)
         new_features[column] = dst_feat
@@ -2877,10 +2878,12 @@ def add_elasticsearch_index(
         )
         return self
 
-    def add_item(self, item: dict):
+    @transmit_format
+    @fingerprint_transform(inplace=False)
+    def add_item(self, item: dict, new_fingerprint: str):
         """Add item to Dataset.
 
-        .. versionadded:: 1.6
+        .. versionadded:: 1.7
 
         Args:
             item (dict): Item data to be added.
@@ -2894,7 +2897,19 @@ def add_item(self, item: dict):
         item_table = item_table.cast(schema)
         # Concatenate tables
         table = concat_tables([self._data, item_table])
-        return Dataset(table)
+        if self._indices is None:
+            indices_table = None
+        else:
+            item_indices_array = pa.array([len(self._data)], type=pa.uint64())
+            item_indices_table = InMemoryTable.from_arrays([item_indices_array], names=["indices"])
+            indices_table = concat_tables([self._indices, item_indices_table])
+        return Dataset(
+            table,
+            info=copy.deepcopy(self.info),
+            split=self.split,
+            indices_table=indices_table,
+            fingerprint=new_fingerprint,
+        )
 
 
 def concatenate_datasets(
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 245136e48b7..2b0367218e7 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -1949,6 +1949,10 @@ def test_concatenate_datasets_duplicate_columns(dataset):
     assert "duplicated" in str(excinfo.value)
 
 
+@pytest.mark.parametrize(
+    "transform",
+    [None, ("shuffle", (42,), {}), ("with_format", ("pandas",), {}), ("class_encode_column", ("col_2",), {})],
+)
 @pytest.mark.parametrize("in_memory", [False, True])
 @pytest.mark.parametrize(
     "item",
@@ -1959,22 +1963,32 @@ def test_concatenate_datasets_duplicate_columns(dataset):
         {"col_1": 4.0, "col_2": 4.0, "col_3": 4.0},
     ],
 )
-def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path):
-    dataset = (
+def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path, transform):
+    dataset_to_test = (
         Dataset(InMemoryTable.from_pydict(dataset_dict))
         if in_memory
         else Dataset(MemoryMappedTable.from_file(arrow_path))
     )
-    dataset = dataset.add_item(item)
+    if transform is not None:
+        transform_name, args, kwargs = transform
+        dataset_to_test: Dataset = getattr(dataset_to_test, transform_name)(*args, **kwargs)
+    dataset = dataset_to_test.add_item(item)
     assert dataset.data.shape == (5, 3)
-    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
-    assert dataset.data.column_names == list(expected_features.keys())
+    expected_features = dataset_to_test.features
+    assert sorted(dataset.data.column_names) == sorted(expected_features.keys())
     for feature, expected_dtype in expected_features.items():
-        assert dataset.features[feature].dtype == expected_dtype
-    assert len(dataset.data.blocks) == 1 if in_memory else 2  # multiple InMemoryTables are consolidated as one
-    dataset = dataset.add_item(item)
-    assert dataset.data.shape == (6, 3)
+        assert dataset.features[feature] == expected_dtype
     assert len(dataset.data.blocks) == 1 if in_memory else 2  # multiple InMemoryTables are consolidated as one
+    assert dataset.format["type"] == dataset_to_test.format["type"]
+    assert dataset._fingerprint != dataset_to_test._fingerprint
+    dataset.reset_format()
+    dataset_to_test.reset_format()
+    assert dataset[:-1] == dataset_to_test[:]
+    assert {k: int(v) for k, v in dataset[-1].items()} == {k: int(v) for k, v in item.items()}
+    if dataset._indices is not None:
+        dataset_indices = dataset._indices["indices"].to_pylist()
+        dataset_to_test_indices = dataset_to_test._indices["indices"].to_pylist()
+        assert dataset_indices == dataset_to_test_indices + [len(dataset_to_test._data)]
 
 
 @pytest.mark.parametrize("keep_in_memory", [False, True])