Merged
Commits (24 total; the diff below shows the changes from 12 of them)
- 41afaed  Add dataset_dict and arrow_path for tests (albertvillanova, Mar 30, 2021)
- 7f227a0  Test Dataset.add_column (albertvillanova, Mar 30, 2021)
- 9f99119  Implement Dataset.add_column (albertvillanova, Mar 30, 2021)
- 7dd5146  Add docstring (albertvillanova, Mar 30, 2021)
- 11e334f  Merge remote-tracking branch 'upstream/master' into dataset-add-column (albertvillanova, Apr 20, 2021)
- 76adb79  Use ConcatenationTable.from_tables(axis=1) (albertvillanova, Apr 20, 2021)
- b5a0572  Return new dataset (albertvillanova, Apr 20, 2021)
- a2b3eef  Test multiple InMemoryTables are consolidated (albertvillanova, Apr 20, 2021)
- f73577f  Test for consolidated InMemoryTables after multiple calls (albertvillanova, Apr 20, 2021)
- 2c4dc74  Add versionadded to docstring (albertvillanova, Apr 20, 2021)
- 9eb611f  Add method docstring to the docs (albertvillanova, Apr 20, 2021)
- 8b42f25  Pass column with 2 parameters as column_name and column (albertvillanova, Apr 20, 2021)
- 76631e4  Merge remote-tracking branch 'upstream/master' into dataset-add-column (albertvillanova, Apr 26, 2021)
- a5148ea  Change versionadded (albertvillanova, Apr 26, 2021)
- bbb2eed  Update features with new one instead of replacing them all (albertvillanova, Apr 26, 2021)
- 9e38c29  Test more column dtypes (albertvillanova, Apr 26, 2021)
- ee0e1c8  Test also format and fingerprint (albertvillanova, Apr 26, 2021)
- 8d791eb  Transmit format and fingerprint in add_column (albertvillanova, Apr 26, 2021)
- dec889b  Update table metadata with features (albertvillanova, Apr 27, 2021)
- 9924dc9  Test indices (albertvillanova, Apr 27, 2021)
- fdef700  Merge remote-tracking branch 'upstream/master' into dataset-add-column (albertvillanova, Apr 27, 2021)
- 365d903  Merge remote-tracking branch 'upstream/master' into dataset-add-column (albertvillanova, Apr 27, 2021)
- 1c00d25  Merge remote-tracking branch 'upstream/master' into dataset-add-column (albertvillanova, Apr 29, 2021)
- 2196802  Test metadata are synced with features (albertvillanova, Apr 29, 2021)
4 changes: 3 additions & 1 deletion docs/source/package_reference/main_classes.rst
@@ -14,7 +14,9 @@ Main classes
 The base class :class:`datasets.Dataset` implements a Dataset backed by an Apache Arrow table.
 
 .. autoclass:: datasets.Dataset
-    :members: from_file, from_buffer, from_pandas, from_dict,
+    :members:
+        add_column,
+        from_file, from_buffer, from_pandas, from_dict,
         data, cache_files, num_columns, num_rows, column_names, shape,
         unique,
         flatten_, cast_, remove_columns_, rename_column_,
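With `add_column` now listed among the documented `Dataset` members, a minimal usage sketch looks like this (the data here is made up for illustration, not taken from the PR):

```python
from datasets import Dataset

# Hypothetical example data
ds = Dataset.from_dict({"text": ["foo", "bar", "baz"]})
ds = ds.add_column("label", [0, 1, 0])  # returns a new Dataset with the extra column

print(ds.column_names)               # ['text', 'label']
print(ds.features["label"].dtype)    # 'int64'
```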
22 changes: 21 additions & 1 deletion src/datasets/arrow_dataset.py
@@ -56,7 +56,7 @@
 from .info import DATASET_INFO_FILENAME, DatasetInfo
 from .search import IndexableMixin
 from .splits import NamedSplit
-from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files
+from .table import ConcatenationTable, InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files
 from .utils import map_nested
 from .utils.deprecation_utils import deprecated
 from .utils.logging import WARNING, get_logger, get_verbosity, set_verbosity_warning
@@ -2654,6 +2654,26 @@ def to_pandas(
                 for offset in range(0, len(self), batch_size)
             )
 
+    def add_column(self, name: str, column: Union[list, np.array]):
+        """Add column to Dataset.
+
+        .. versionadded:: 1.6
+
+        Args:
+            name (str): Column name.
+            column (list or np.array): Column data to be added.
+
+        Returns:
+            :class:`Dataset`
+        """
+        column_table = InMemoryTable.from_pydict({name: column})
+        # Concatenate tables horizontally
+        table = ConcatenationTable.from_tables([self._data, column_table], axis=1)
+        # Update features
+        info = self.info
+        info.features = Features.from_arrow_schema(table.schema)
+        return Dataset(table, info=info)
+
     def add_faiss_index(
         self,
         column: str,
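The new method builds a one-column `InMemoryTable` from the passed data, joins it to the dataset's existing table with `ConcatenationTable.from_tables(..., axis=1)`, and refreshes the features from the resulting Arrow schema. A rough standalone sketch of that horizontal concatenation, assuming two in-memory tables with the same number of rows:

```python
from datasets import Features
from datasets.table import ConcatenationTable, InMemoryTable

left = InMemoryTable.from_pydict({"col_1": ["a", "b", "c", "d"]})
right = InMemoryTable.from_pydict({"col_4": ["w", "x", "y", "z"]})

# axis=1 places the tables side by side (adds columns) instead of appending rows
combined = ConcatenationTable.from_tables([left, right], axis=1)

print(combined.column_names)                     # ['col_1', 'col_4']
print(Features.from_arrow_schema(combined.schema))  # features rebuilt from the merged schema
```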
21 changes: 21 additions & 0 deletions tests/test_arrow_dataset.py
@@ -1948,6 +1948,27 @@ def test_concatenate_datasets_duplicate_columns(dataset):
assert "duplicated" in str(excinfo.value)


@pytest.mark.parametrize("in_memory", [False, True])
def test_dataset_add_column(in_memory, dataset_dict, arrow_path):
column_name = "col_4"
column = ["a", "b", "c", "d"]
dataset = (
Dataset(InMemoryTable.from_pydict(dataset_dict))
if in_memory
else Dataset(MemoryMappedTable.from_file(arrow_path))
)
dataset = dataset.add_column(column_name, column)
assert dataset.data.shape == (4, 4)
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64", "col_4": "string"}
assert dataset.data.column_names == list(expected_features.keys())
for feature, expected_dtype in expected_features.items():
assert dataset.features[feature].dtype == expected_dtype
assert len(dataset.data.blocks) == 1 if in_memory else 2 # multiple InMemoryTables are consolidated as one
column_name = "col_5"
dataset = dataset.add_column(column_name, column)
assert len(dataset.data.blocks) == 1 if in_memory else 2 # multiple InMemoryTables are consolidated as one


@pytest.mark.parametrize("keep_in_memory", [False, True])
@pytest.mark.parametrize(
"features",
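For reference, the behaviour this test exercises can be reproduced at the user level roughly as follows; the data below is made up to mirror the shape of the test fixture, and the block-consolidation detail (repeated `add_column` calls not piling up one `InMemoryTable` each) is what the `len(dataset.data.blocks)` assertions above are meant to capture:

```python
from datasets import Dataset

# Made-up data mirroring the three-column fixture used in the test
ds = Dataset.from_dict(
    {"col_1": ["a", "b", "c", "d"], "col_2": [0, 1, 2, 3], "col_3": [0.0, 1.0, 2.0, 3.0]}
)
ds = ds.add_column("col_4", ["a", "b", "c", "d"])
ds = ds.add_column("col_5", ["a", "b", "c", "d"])

assert ds.data.shape == (4, 5)
assert ds.features["col_4"].dtype == "string"  # dtype inferred from the Arrow schema of the new column
```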