4 changes: 4 additions & 0 deletions docs/source/package_reference/main_classes.mdx
@@ -30,6 +30,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
- remove_columns
- rename_column
- rename_columns
- select_columns
- class_encode_column
- __len__
- __iter__
@@ -137,6 +138,7 @@ It also has dataset transform methods like map or filter, to process all the splits.
- remove_columns
- rename_column
- rename_columns
- select_columns
- class_encode_column
- push_to_hub
- save_to_disk
@@ -156,6 +158,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by python generators.
[[autodoc]] datasets.IterableDataset
- from_generator
- remove_columns
- select_columns
- cast_column
- cast
- __iter__
@@ -196,6 +199,7 @@ Dictionary with split names as keys ('train', 'test' for example), and `IterableDataset` objects as values.
- remove_columns
- rename_column
- rename_columns
- select_columns

## Features

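For orientation, `select_columns` is the same entry point on every container class documented above. A minimal sketch with hypothetical in-memory data (toy column names standing in for the `rotten_tomatoes` examples used in the docstrings below):

```py
from datasets import Dataset, DatasetDict

# toy data, used instead of a Hub dataset
ds = Dataset.from_dict({"text": ["hello", "world"], "label": [0, 1]})
dd = DatasetDict({"train": ds})

# a bare string or a list of names are both accepted
assert ds.select_columns("text").column_names == ["text"]
assert dd.select_columns("text")["train"].column_names == ["text"]
```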
49 changes: 49 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -2145,6 +2145,55 @@ def rename(columns):
dataset._fingerprint = new_fingerprint
return dataset

@transmit_tasks
@transmit_format
@fingerprint_transform(inplace=False)
def select_columns(self, column_names: Union[str, List[str]], new_fingerprint: Optional[str] = None) -> "Dataset":
"""Select one or several column(s) in the dataset and the features
associated to them.

Args:
column_names (`Union[str, List[str]]`):
Name of the column(s) to keep.
new_fingerprint (`str`, *optional*):
The new fingerprint of the dataset after transform. If `None`,
the new fingerprint is computed using a hash of the previous
fingerprint, and the transform arguments.

Returns:
[`Dataset`]: A copy of the dataset object containing only the
selected columns.

Example:

```py
>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="validation")
>>> ds.select_columns(['text'])
Dataset({
features: ['text'],
num_rows: 1066
})
```
"""
if isinstance(column_names, str):
column_names = [column_names]

for column_name in column_names:
if column_name not in self._data.column_names:
raise ValueError(
f"Column name {column_name} not in the "
"dataset. Current columns in the dataset: "
f"{self._data.column_names}."
)

dataset = copy.deepcopy(self)
dataset._info.features = Features({k: v for k, v in dataset._info.features.items() if k in column_names})
dataset._data = dataset._data.select(column_names)
dataset._data = update_metadata_with_features(dataset._data, dataset.features)
dataset._fingerprint = new_fingerprint
return dataset

def __len__(self):
"""Number of rows in the dataset.

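A minimal usage sketch for the new `Dataset.select_columns` above, with a hypothetical three-column dataset: the method validates the requested names, operates on a deep copy, and is effectively the complement of `remove_columns`.

```py
from datasets import Dataset

# hypothetical three-column dataset for illustration
ds = Dataset.from_dict({"col_1": [0, 1], "col_2": ["a", "b"], "col_3": [0.0, 1.0]})

# keeping a subset is the complement of removing the rest
kept = ds.select_columns(["col_1"])
dropped = ds.remove_columns(["col_2", "col_3"])
assert kept.column_names == dropped.column_names == ["col_1"]

# the original dataset is untouched: select_columns works on a deep copy
assert ds.column_names == ["col_1", "col_2", "col_3"]

# unknown names raise the ValueError added in the diff
try:
    ds.select_columns(["missing"])
except ValueError as err:
    print(err)  # Column name missing not in the dataset. ...
```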
62 changes: 62 additions & 0 deletions src/datasets/dataset_dict.py
@@ -432,6 +432,42 @@ def rename_columns(self, column_mapping: Dict[str, str]) -> "DatasetDict":
self._check_values_type()
return DatasetDict({k: dataset.rename_columns(column_mapping=column_mapping) for k, dataset in self.items()})

def select_columns(self, column_names: Union[str, List[str]]) -> "DatasetDict":
"""Select one or several column(s) from each split in the dataset and
the features associated to the column(s).

The transformation is applied to all the splits of the dataset
dictionary.

Args:
column_names (`Union[str, List[str]]`):
Name of the column(s) to keep.

Example:

```py
>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes")
>>> ds.select_columns("text")
DatasetDict({
train: Dataset({
features: ['text'],
num_rows: 8530
})
validation: Dataset({
features: ['text'],
num_rows: 1066
})
test: Dataset({
features: ['text'],
num_rows: 1066
})
})
```
"""
self._check_values_type()
return DatasetDict({k: dataset.select_columns(column_names=column_names) for k, dataset in self.items()})

def class_encode_column(self, column: str, include_nulls: bool = False) -> "DatasetDict":
"""Casts the given column as `datasets.features.ClassLabel` and updates the tables.

@@ -1911,6 +1947,32 @@ def remove_columns(self, column_names: Union[str, List[str]]) -> "IterableDatasetDict":
"""
return IterableDatasetDict({k: dataset.remove_columns(column_names) for k, dataset in self.items()})

def select_columns(self, column_names: Union[str, List[str]]) -> "IterableDatasetDict":
"""Select one or several column(s) in the dataset and the features
associated to them. The selection is done on-the-fly on the examples
when iterating over the dataset. The selection is applied to all the
datasets of the dataset dictionary.


Args:
column_names (`Union[str, List[str]]`):
Name of the column(s) to keep.

Returns:
[`IterableDatasetDict`]: A copy of the dataset object with only the selected columns.

Example:

```py
>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", streaming=True)
>>> ds = ds.select_columns("text")
>>> next(iter(ds["train"]))
{'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'}
```
"""
return IterableDatasetDict({k: dataset.select_columns(column_names) for k, dataset in self.items()})

def cast_column(self, column: str, feature: FeatureType) -> "IterableDatasetDict":
"""Cast column to feature for decoding.
The type casting is applied to all the datasets of the dataset dictionary.
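A short sketch of the `DatasetDict` variant above, with hypothetical splits: the call simply delegates to each split's `select_columns`, and a bare string is accepted because `Dataset.select_columns` wraps it in a list.

```py
from datasets import Dataset, DatasetDict

# hypothetical two-split dictionary
dd = DatasetDict({
    "train": Dataset.from_dict({"text": ["a", "b"], "label": [0, 1]}),
    "test": Dataset.from_dict({"text": ["c"], "label": [1]}),
})

# every split is transformed independently
dd = dd.select_columns("text")
assert all(split.column_names == ["text"] for split in dd.values())
```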
71 changes: 71 additions & 0 deletions src/datasets/iterable_dataset.py
@@ -146,6 +146,26 @@ def shard_data_sources(self, shard_indices: List[int]) -> "ExamplesIterable":
)


class SelectColumnsIterable(_BaseExamplesIterable):
def __init__(self, ex_iterable: _BaseExamplesIterable, column_names: List[str]):
self.ex_iterable = ex_iterable
self.column_names = column_names

def __iter__(self):
for idx, row in self.ex_iterable:
yield idx, {c: row[c] for c in self.column_names}

def shuffle_data_sources(self, generator: np.random.Generator) -> "SelectColumnsIterable":
return SelectColumnsIterable(self.ex_iterable.shuffle_data_sources(generator), self.column_names)

def shard_data_sources(self, shard_indices: List[int]) -> "SelectColumnsIterable":
return SelectColumnsIterable(self.ex_iterable.shard_data_sources(shard_indices), self.column_names)

@property
def n_shards(self) -> int:
return self.ex_iterable.n_shards


class StepExamplesIterable(_BaseExamplesIterable):
def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int):
self.ex_iterable = ex_iterable
@@ -1494,6 +1514,57 @@ def remove_columns(self, column_names: Union[str, List[str]]) -> "IterableDataset":
del ds_iterable._info.features[col]
return ds_iterable

def select_columns(self, column_names: Union[str, List[str]]) -> "IterableDataset":
"""Select one or several column(s) in the dataset and the features
associated to them. The selection is done on-the-fly on the examples
when iterating over the dataset.


Args:
column_names (`Union[str, List[str]]`):
Name of the column(s) to select.

Returns:
`IterableDataset`: A copy of the dataset object with only the selected columns.

Example:

```py
>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="train", streaming=True)
>>> next(iter(ds))
{'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .', 'label': 1}
>>> ds = ds.select_columns("text")
>>> next(iter(ds))
{'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'}
```
"""
if isinstance(column_names, str):
column_names = [column_names]

if self._info:
info = copy.deepcopy(self._info)
if self._info.features is not None:
for column_name in column_names:
if column_name not in self._info.features:
raise ValueError(
f"Column name {column_name} not in the "
"dataset. Columns in the dataset: "
f"{list(self._info.features.keys())}."
)
info.features = Features({c: info.features[c] for c in column_names})

ex_iterable = SelectColumnsIterable(self._ex_iterable, column_names)
return IterableDataset(
ex_iterable=ex_iterable,
info=info,
split=self._split,
format_type=self._format_type,
shuffling=self._shuffling,
distributed=self._distributed,
token_per_repo_id=self._token_per_repo_id,
)

def cast_column(self, column: str, feature: FeatureType) -> "IterableDataset":
"""Cast column to feature for decoding.

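The core of the streaming path is `SelectColumnsIterable.__iter__`, which just rebuilds each row dict with the requested keys. A self-contained stand-in for the wrapped `ex_iterable` (the function and sample rows below are illustrative, not from the PR):

```py
from typing import Any, Dict, Iterator, List, Tuple

def select_columns_examples(
    rows: Iterator[Tuple[int, Dict[str, Any]]], column_names: List[str]
) -> Iterator[Tuple[int, Dict[str, Any]]]:
    # mirrors SelectColumnsIterable.__iter__: keep only the requested keys,
    # one example at a time, without materializing the dataset
    for idx, row in rows:
        yield idx, {c: row[c] for c in column_names}

rows = iter([(0, {"id": 0, "text": "a"}), (1, {"id": 1, "text": "b"})])
print(list(select_columns_examples(rows, ["text"])))
# [(0, {'text': 'a'}), (1, {'text': 'b'})]
```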
34 changes: 34 additions & 0 deletions tests/test_arrow_dataset.py
@@ -622,6 +622,40 @@ def test_rename_columns(self, in_memory):
with self.assertRaises(ValueError):
dset.rename_columns({"col_1": "new_name", "col_2": "new_name"})

def test_select_columns(self, in_memory):
with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
fingerprint = dset._fingerprint
with dset.select_columns(column_names=[]) as new_dset:
self.assertEqual(new_dset.num_columns, 0)
self.assertListEqual(list(new_dset.column_names), [])
self.assertNotEqual(new_dset._fingerprint, fingerprint)
assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
fingerprint = dset._fingerprint
with dset.select_columns(column_names="col_1") as new_dset:
self.assertEqual(new_dset.num_columns, 1)
self.assertListEqual(list(new_dset.column_names), ["col_1"])
self.assertNotEqual(new_dset._fingerprint, fingerprint)
assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
with dset.select_columns(column_names=["col_1", "col_2", "col_3"]) as new_dset:
self.assertEqual(new_dset.num_columns, 3)
self.assertListEqual(list(new_dset.column_names), ["col_1", "col_2", "col_3"])
self.assertNotEqual(new_dset._fingerprint, fingerprint)
assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
dset._format_columns = ["col_1", "col_2", "col_3"]
with dset.select_columns(column_names=["col_1"]) as new_dset:
self.assertListEqual(new_dset._format_columns, ["col_1"])
self.assertEqual(new_dset.num_columns, 1)
self.assertListEqual(list(new_dset.column_names), ["col_1"])
self.assertNotEqual(new_dset._fingerprint, fingerprint)
assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

def test_concatenate(self, in_memory):
data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
info1 = DatasetInfo(description="Dataset1")
26 changes: 26 additions & 0 deletions tests/test_dataset_dict.py
@@ -233,6 +233,32 @@ def test_rename_column(self):
self.assertListEqual(list(dset_split.column_names), ["new_name", "col_2"])
del dset

def test_select_columns(self):
dset = self._create_dummy_dataset_dict(multiple_columns=True)
dset = dset.select_columns(column_names=[])
for dset_split in dset.values():
self.assertEqual(dset_split.num_columns, 0)

dset = self._create_dummy_dataset_dict(multiple_columns=True)
dset = dset.select_columns(column_names="col_1")
for dset_split in dset.values():
self.assertEqual(dset_split.num_columns, 1)
self.assertListEqual(list(dset_split.column_names), ["col_1"])

dset = self._create_dummy_dataset_dict(multiple_columns=True)
dset = dset.select_columns(column_names=["col_1", "col_2"])
for dset_split in dset.values():
self.assertEqual(dset_split.num_columns, 2)

dset = self._create_dummy_dataset_dict(multiple_columns=True)
for dset_split in dset.values():
dset_split._format_columns = ["col_1", "col_2"]
dset = dset.select_columns(column_names=["col_1"])
for dset_split in dset.values():
self.assertEqual(dset_split.num_columns, 1)
self.assertListEqual(list(dset_split.column_names), ["col_1"])
self.assertListEqual(dset_split._format_columns, ["col_1"])

def test_map(self):
with tempfile.TemporaryDirectory() as tmp_dir:
dsets = self._create_dummy_dataset_dict()
17 changes: 17 additions & 0 deletions tests/test_iterable_dataset.py
@@ -1009,6 +1009,23 @@ def test_iterable_dataset_remove_columns(dataset_with_several_columns):
assert all(c not in new_dataset.features for c in ["id", "filepath"])


def test_iterable_dataset_select_columns(dataset_with_several_columns):
new_dataset = dataset_with_several_columns.select_columns("id")
assert list(new_dataset) == [
{k: v for k, v in example.items() if k == "id"} for example in dataset_with_several_columns
]
assert new_dataset.features is None
new_dataset = dataset_with_several_columns.select_columns(["id", "filepath"])
assert list(new_dataset) == [
{k: v for k, v in example.items() if k in ("id", "filepath")} for example in dataset_with_several_columns
]
assert new_dataset.features is None
# the selected columns are kept in ds.features if it was not None
new_dataset = dataset_with_several_columns._resolve_features().select_columns(["id", "filepath"])
assert new_dataset.features is not None
assert all(c in new_dataset.features for c in ["id", "filepath"])


def test_iterable_dataset_cast_column():
ex_iterable = ExamplesIterable(generate_examples_fn, {"label": 10})
features = Features({"id": Value("int64"), "label": Value("int64")})