Merged

Commits (24)
41afaed  Add dataset_dict and arrow_path for tests (albertvillanova, Mar 30, 2021)
7f227a0  Test Dataset.add_column (albertvillanova, Mar 30, 2021)
9f99119  Implement Dataset.add_column (albertvillanova, Mar 30, 2021)
7dd5146  Add docstring (albertvillanova, Mar 30, 2021)
11e334f  Merge remote-tracking branch 'upstream/master' into dataset-add-column (albertvillanova, Apr 20, 2021)
76adb79  Use ConcatenationTable.from_tables(axis=1) (albertvillanova, Apr 20, 2021)
b5a0572  Return new dataset (albertvillanova, Apr 20, 2021)
a2b3eef  Test multiple InMemoryTables are consolidated (albertvillanova, Apr 20, 2021)
f73577f  Test for consolidated InMemoryTables after multiple calls (albertvillanova, Apr 20, 2021)
2c4dc74  Add versionadded to docstring (albertvillanova, Apr 20, 2021)
9eb611f  Add method docstring to the docs (albertvillanova, Apr 20, 2021)
8b42f25  Pass column with 2 parameters as column_name and column (albertvillanova, Apr 20, 2021)
76631e4  Merge remote-tracking branch 'upstream/master' into dataset-add-column (albertvillanova, Apr 26, 2021)
a5148ea  Change versionadded (albertvillanova, Apr 26, 2021)
bbb2eed  Update features with new one instead of replacing them all (albertvillanova, Apr 26, 2021)
9e38c29  Test more column dtypes (albertvillanova, Apr 26, 2021)
ee0e1c8  Test also format and fingerprint (albertvillanova, Apr 26, 2021)
8d791eb  Transmit format and fingerprint in add_column (albertvillanova, Apr 26, 2021)
dec889b  Update table metadata with features (albertvillanova, Apr 27, 2021)
9924dc9  Test indices (albertvillanova, Apr 27, 2021)
fdef700  Merge remote-tracking branch 'upstream/master' into dataset-add-column (albertvillanova, Apr 27, 2021)
365d903  Merge remote-tracking branch 'upstream/master' into dataset-add-column (albertvillanova, Apr 27, 2021)
1c00d25  Merge remote-tracking branch 'upstream/master' into dataset-add-column (albertvillanova, Apr 29, 2021)
2196802  Test metadata are synced with features (albertvillanova, Apr 29, 2021)
Files changed
2 changes: 1 addition & 1 deletion docs/source/package_reference/main_classes.rst
@@ -15,7 +15,7 @@ The base class :class:`datasets.Dataset` implements a Dataset backed by an Apache

 .. autoclass:: datasets.Dataset
     :members:
-        add_item,
+        add_column, add_item,
         from_file, from_buffer, from_pandas, from_dict,
         data, cache_files, num_columns, num_rows, column_names, shape,
         unique,
24 changes: 23 additions & 1 deletion src/datasets/arrow_dataset.py
@@ -56,7 +56,7 @@
 from .info import DATASET_INFO_FILENAME, DatasetInfo
 from .search import IndexableMixin
 from .splits import NamedSplit
-from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files
+from .table import ConcatenationTable, InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files
 from .utils import map_nested
 from .utils.deprecation_utils import deprecated
 from .utils.file_utils import estimate_dataset_size
@@ -2664,6 +2664,28 @@ def to_pandas(
             for offset in range(0, len(self), batch_size)
         )

+    @transmit_format
+    @fingerprint_transform(inplace=False)
+    def add_column(self, name: str, column: Union[list, np.ndarray], new_fingerprint: str):
+        """Add column to Dataset.
+
+        .. versionadded:: 1.7
+
+        Args:
+            name (str): Column name.
+            column (list or np.ndarray): Column data to be added.
+
+        Returns:
+            :class:`Dataset`: A new dataset with the added column.
+        """
+        column_table = InMemoryTable.from_pydict({name: column})
+        # Concatenate tables horizontally
+        table = ConcatenationTable.from_tables([self._data, column_table], axis=1)
+        # Update features with the new column instead of replacing them all
+        info = copy.deepcopy(self.info)
+        info.features.update(Features.from_arrow_schema(column_table.schema))
+        return Dataset(table, info=info, split=self.split, indices_table=self._indices, fingerprint=new_fingerprint)
+
     def add_faiss_index(
         self,
         column: str,
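For reference, a minimal usage sketch of the new method (the toy column values are invented for illustration; Dataset.from_dict is existing library API):

from datasets import Dataset

# Build a small in-memory dataset
ds = Dataset.from_dict({"col_1": ["a", "b", "c", "d"]})

# add_column returns a new Dataset rather than mutating in place
new_ds = ds.add_column("col_2", [1, 2, 3, 4])
print(new_ds.column_names)             # ['col_1', 'col_2']
print(new_ds.features["col_2"].dtype)  # 'int64'

Because of the @transmit_format and @fingerprint_transform(inplace=False) decorators, the returned dataset keeps the caller's output format but gets a fresh fingerprint, which is exactly what the tests below assert.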
34 changes: 34 additions & 0 deletions tests/test_arrow_dataset.py
@@ -1948,6 +1948,40 @@ def test_concatenate_datasets_duplicate_columns(dataset):
     assert "duplicated" in str(excinfo.value)


+@pytest.mark.parametrize(
+    "column, expected_dtype",
+    [(["a", "b", "c", "d"], "string"), ([1, 2, 3, 4], "int64"), ([1.0, 2.0, 3.0, 4.0], "float64")],
+)
+@pytest.mark.parametrize("in_memory", [False, True])
+@pytest.mark.parametrize(
+    "transform",
+    [None, ("shuffle", (42,), {}), ("with_format", ("pandas",), {}), ("class_encode_column", ("col_2",), {})],
+)
+def test_dataset_add_column(column, expected_dtype, in_memory, transform, dataset_dict, arrow_path):
+    column_name = "col_4"
+    original_dataset = (
+        Dataset(InMemoryTable.from_pydict(dataset_dict))
+        if in_memory
+        else Dataset(MemoryMappedTable.from_file(arrow_path))
+    )
+    if transform is not None:
+        transform_name, args, kwargs = transform
+        original_dataset: Dataset = getattr(original_dataset, transform_name)(*args, **kwargs)
+    dataset = original_dataset.add_column(column_name, column)
+    assert dataset.data.shape == (4, 4)
+    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64", column_name: expected_dtype}
+    assert dataset.data.column_names == list(expected_features.keys())
+    for feature, dtype in expected_features.items():
+        assert dataset.features[feature].dtype == dtype
+    assert len(dataset.data.blocks[0]) == (1 if in_memory else 2)  # multiple InMemoryTables are consolidated as one
+    assert dataset.format["type"] == original_dataset.format["type"]
+    assert dataset._fingerprint != original_dataset._fingerprint
+    dataset.reset_format()
+    original_dataset.reset_format()
+    assert all(dataset[col] == original_dataset[col] for col in original_dataset.column_names)
+    assert set(dataset["col_4"]) == set(column)
+
+
 @pytest.mark.parametrize("in_memory", [False, True])
 @pytest.mark.parametrize(
     "item",
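The consolidation behaviour the test asserts can also be exercised end to end. A sketch, assuming (as the test above does) that the resulting ConcatenationTable exposes its blocks as a list of row groups, each holding a list of table blocks; the extra_* column names are invented:

from datasets import Dataset

ds = Dataset.from_dict({"col_1": [0, 1, 2, 3]})
for i in range(3):
    # Each call concatenates horizontally via ConcatenationTable.from_tables(axis=1)
    ds = ds.add_column(f"extra_{i}", [i] * 4)

# With purely in-memory data, repeated calls should leave a single
# consolidated InMemoryTable block instead of one block per call
print(len(ds.data.blocks[0]))  # expected: 1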