diff --git a/docs/source/package_reference/main_classes.rst b/docs/source/package_reference/main_classes.rst
index 2eec94c0f02..0f43149fca3 100644
--- a/docs/source/package_reference/main_classes.rst
+++ b/docs/source/package_reference/main_classes.rst
@@ -15,7 +15,7 @@ The base class :class:`datasets.Dataset` implements a Dataset backed by an Apach
 
 .. autoclass:: datasets.Dataset
     :members:
-        add_item,
+        add_column, add_item,
         from_file, from_buffer, from_pandas, from_dict,
         data, cache_files, num_columns, num_rows, column_names, shape,
         unique,
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 8213d331ce8..f0e7fa93bb8 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -56,7 +56,7 @@
 from .info import DatasetInfo
 from .search import IndexableMixin
 from .splits import NamedSplit
-from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files
+from .table import ConcatenationTable, InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files
 from .utils import map_nested
 from .utils.deprecation_utils import deprecated
 from .utils.file_utils import estimate_dataset_size
@@ -2710,6 +2710,29 @@ def to_pandas(
             for offset in range(0, len(self), batch_size)
         )
 
+    @transmit_format
+    @fingerprint_transform(inplace=False)
+    def add_column(self, name: str, column: Union[list, np.array], new_fingerprint: str):
+        """Add column to Dataset.
+
+        .. versionadded:: 1.7
+
+        Args:
+            name (str): Column name.
+            column (list or np.array): Column data to be added.
+
+        Returns:
+            :class:`Dataset`
+        """
+        column_table = InMemoryTable.from_pydict({name: column})
+        # Concatenate tables horizontally
+        table = ConcatenationTable.from_tables([self._data, column_table], axis=1)
+        # Update features
+        info = copy.deepcopy(self.info)
+        info.features.update(Features.from_arrow_schema(column_table.schema))
+        table = update_metadata_with_features(table, info.features)
+        return Dataset(table, info=info, split=self.split, indices_table=self._indices, fingerprint=new_fingerprint)
+
     def add_faiss_index(
         self,
         column: str,
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index cdadac43158..5be9565642b 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -1992,6 +1992,45 @@ def test_concatenate_datasets_duplicate_columns(dataset):
     assert "duplicated" in str(excinfo.value)
 
 
+@pytest.mark.parametrize(
+    "column, expected_dtype",
+    [(["a", "b", "c", "d"], "string"), ([1, 2, 3, 4], "int64"), ([1.0, 2.0, 3.0, 4.0], "float64")],
+)
+@pytest.mark.parametrize("in_memory", [False, True])
+@pytest.mark.parametrize(
+    "transform",
+    [None, ("shuffle", (42,), {}), ("with_format", ("pandas",), {}), ("class_encode_column", ("col_2",), {})],
+)
+def test_dataset_add_column(column, expected_dtype, in_memory, transform, dataset_dict, arrow_path):
+    column_name = "col_4"
+    original_dataset = (
+        Dataset(InMemoryTable.from_pydict(dataset_dict))
+        if in_memory
+        else Dataset(MemoryMappedTable.from_file(arrow_path))
+    )
+    if transform is not None:
+        transform_name, args, kwargs = transform
+        original_dataset: Dataset = getattr(original_dataset, transform_name)(*args, **kwargs)
+    dataset = original_dataset.add_column(column_name, column)
+    assert dataset.data.shape == (4, 4)
+    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64", column_name: expected_dtype}
+    assert dataset.data.column_names == list(expected_features.keys())
+    for feature, expected_dtype in expected_features.items():
+        assert dataset.features[feature].dtype == expected_dtype
+    assert len(dataset.data.blocks) == (1 if in_memory else 2)  # multiple InMemoryTables are consolidated as one
+    assert dataset.format["type"] == original_dataset.format["type"]
+    assert dataset._fingerprint != original_dataset._fingerprint
+    dataset.reset_format()
+    original_dataset.reset_format()
+    assert all(dataset[col] == original_dataset[col] for col in original_dataset.column_names)
+    assert set(dataset["col_4"]) == set(column)
+    if dataset._indices is not None:
+        dataset_indices = dataset._indices["indices"].to_pylist()
+        expected_dataset_indices = original_dataset._indices["indices"].to_pylist()
+        assert dataset_indices == expected_dataset_indices
+    assert_arrow_metadata_are_synced_with_dataset_features(dataset)
+
+
 @pytest.mark.parametrize(
     "transform",
     [None, ("shuffle", (42,), {}), ("with_format", ("pandas",), {}), ("class_encode_column", ("col_2",), {})],
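
For context, a minimal usage sketch of the add_column API introduced above; the dataset contents and column names below are illustrative, not taken from the repo's test fixtures:

# add_column returns a new Dataset rather than mutating in place; the
# new_fingerprint argument is filled in automatically by the
# fingerprint_transform decorator, so callers only pass the name and data.
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c", "d"]})  # illustrative data
ds = ds.add_column("label", [0, 1, 0, 1])
print(ds.column_names)             # ['text', 'label']
print(ds.features["label"].dtype)  # 'int64', inferred from the column values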