diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index fe9f54641e7..37f0a5063c6 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -5,6 +5,7 @@ from collections import Counter from copy import deepcopy from dataclasses import dataclass +from functools import partial from itertools import cycle, islice from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union @@ -30,6 +31,31 @@ Key = Union[int, str] +def identity_func(x): + return x + + +def _rename_columns_fn(example: Dict, column_mapping: Dict[str, str]): + if any(col not in example for col in column_mapping): + raise ValueError( + f"Error when renaming {list(column_mapping)} to {list(column_mapping.values())}: columns {set(column_mapping) - set(example)} are not in the dataset." + ) + if any(col in example for col in column_mapping.values()): + raise ValueError( + f"Error when renaming {list(column_mapping)} to {list(column_mapping.values())}: columns {set(example) - set(column_mapping.values())} are already in the dataset." + ) + return { + new_column_name: example[original_column_name] + for original_column_name, new_column_name in column_mapping.items() + } + + +def add_column_fn(example: Dict, idx: int, name: str, column: List[Dict]): + if name in example: + raise ValueError(f"Error when adding {name}: column {name} is already in the dataset.") + return {name: column[idx]} + + def _infer_features_from_batch(batch: Dict[str, list], try_features: Optional[Features] = None) -> Features: pa_table = pa.Table.from_pydict(batch) if try_features is not None: @@ -1626,7 +1652,7 @@ def map( if isinstance(remove_columns, str): remove_columns = [remove_columns] if function is None: - function = lambda x: x # noqa: E731 + function = identity_func if fn_kwargs is None: fn_kwargs = {} ex_iterable = MappedExamplesIterable( @@ -1899,13 +1925,7 @@ def add_column(self, name: str, column: Union[list, np.array]) -> "IterableDatas Returns: `IterableDataset` """ - - def add_column_fn(example, idx): - if name in example: - raise ValueError(f"Error when adding {name}: column {name} is already in the dataset.") - return {name: column[idx]} - - return self.map(add_column_fn, with_indices=True) + return self.map(partial(add_column_fn, name=name, column=column), with_indices=True) def rename_column(self, original_column_name: str, new_column_name: str) -> "IterableDataset": """ @@ -1935,28 +1955,7 @@ def rename_column(self, original_column_name: str, new_column_name: str) -> "Ite 'movie_review': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'} ``` """ - - def rename_column_fn(example): - if original_column_name not in example: - raise ValueError( - f"Error when renaming {original_column_name} to {new_column_name}: column {original_column_name} is not in the dataset." - ) - if new_column_name in example: - raise ValueError( - f"Error when renaming {original_column_name} to {new_column_name}: column {new_column_name} is already in the dataset." - ) - return {new_column_name: example[original_column_name]} - - original_features = self._info.features.copy() if self._info.features else None - ds_iterable = self.map(rename_column_fn, remove_columns=[original_column_name]) - if original_features is not None: - ds_iterable._info.features = Features( - { - new_column_name if col == original_column_name else col: feature - for col, feature in original_features.items() - } - ) - return ds_iterable + return self.rename_columns({original_column_name: new_column_name}) def rename_columns(self, column_mapping: Dict[str, str]) -> "IterableDataset": """ @@ -1970,22 +1969,10 @@ def rename_columns(self, column_mapping: Dict[str, str]) -> "IterableDataset": `IterableDataset`: A copy of the dataset with renamed columns """ - def rename_columns_fn(example): - if any(col not in example for col in column_mapping): - raise ValueError( - f"Error when renaming {list(column_mapping)} to {list(column_mapping.values())}: columns {set(column_mapping) - set(example)} are not in the dataset." - ) - if any(col in example for col in column_mapping.values()): - raise ValueError( - f"Error when renaming {list(column_mapping)} to {list(column_mapping.values())}: columns {set(example) - set(column_mapping.values())} are already in the dataset." - ) - return { - new_column_name: example[original_column_name] - for original_column_name, new_column_name in column_mapping.items() - } - original_features = self._info.features.copy() if self._info.features else None - ds_iterable = self.map(rename_columns_fn, remove_columns=list(column_mapping)) + ds_iterable = self.map( + partial(_rename_columns_fn, column_mapping=column_mapping), remove_columns=list(column_mapping) + ) if original_features is not None: ds_iterable._info.features = Features( { diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index c839c08cebe..7d32265b87f 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1,3 +1,4 @@ +import pickle from copy import deepcopy from itertools import chain, islice @@ -1900,17 +1901,17 @@ def add_one_numpy(example): assert isinstance(next(dataset.iter(batch_size=3))["id"], list) -@pytest.mark.parametrize("n_shards1, nshards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)]) -def test_interleave_dataset_with_sharding(n_shards1, nshards2, num_workers): +@pytest.mark.parametrize("n_shards1, n_shards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)]) +def test_interleave_dataset_with_sharding(n_shards1, n_shards2, num_workers): from torch.utils.data import DataLoader ex_iterable1 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-1.txt" for i in range(n_shards1)]}) dataset1 = IterableDataset(ex_iterable1).with_format("torch") - ex_iterable2 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-2.txt" for i in range(nshards2)]}) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-2.txt" for i in range(n_shards2)]}) dataset2 = IterableDataset(ex_iterable2).with_format("torch") dataset_merged = interleave_datasets([dataset1, dataset2], stopping_strategy="first_exhausted") - assert dataset_merged.n_shards == min(n_shards1, nshards2) + assert dataset_merged.n_shards == min(n_shards1, n_shards2) dataloader = DataLoader(dataset_merged, batch_size=None, num_workers=num_workers) result = list(dataloader) expected_length = 2 * min( @@ -1919,3 +1920,30 @@ def test_interleave_dataset_with_sharding(n_shards1, nshards2, num_workers): # some samples may be missing because the stopping strategy is applied per process assert expected_length - num_workers <= len(result) <= expected_length assert len(result) == len({str(x) for x in result}) + + +def filter_func(batch): + return batch["id"] == 4 + + +def map_func(batch): + batch["id"] *= 2 + return batch + + +def test_pickle_after_many_transforms(dataset_with_several_columns): + dataset = dataset_with_several_columns + dataset = dataset.remove_columns(["filepath"]) + dataset = dataset.take(5) + dataset = dataset.map(map_func) + dataset = dataset.shuffle() + dataset = dataset.skip(1) + dataset = dataset.filter(filter_func) + dataset = dataset.add_column("additional_col", ["something"]) + dataset = dataset.rename_column("metadata", "metadata1") + dataset = dataset.rename_columns({"id": "id1", "metadata1": "metadata2"}) + dataset = dataset.select_columns(["id1", "additional_col"]) + + unpickled_dataset = pickle.loads(pickle.dumps(dataset)) + + assert list(unpickled_dataset) == list(dataset)