
Commit 83e0f8f

bruno-hays (Bruno Hays) authored and lhoestq committed

Fix multiprocessing with spawn in iterable datasets (#6165)

* fixed remove_columns and rename_columns
* fixed rename_column, removed code duplication
* linting
* typo
* added pickle test
* fixed rename_column not being picklable
* linting
* added verification that the pickling process does not change the data

Co-authored-by: Bruno Hays <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>

1 parent 941af81 commit 83e0f8f
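For context: the spawn start method used by multiprocessing (and by PyTorch DataLoader workers on macOS and Windows) pickles every callable it sends to a worker, and functions defined inside another function's body cannot be pickled. The commit therefore moves the transform helpers to module level and binds their arguments with functools.partial. A minimal sketch of the underlying behavior, separate from the commit itself (add_suffix and make_closure are illustrative names):

```python
import pickle
from functools import partial


def add_suffix(example: dict, suffix: str) -> dict:
    # Module-level function: picklable, so it survives a spawn boundary.
    return {"text": example["text"] + suffix}


def make_closure(suffix: str):
    def _inner(example: dict) -> dict:  # local function: not picklable
        return {"text": example["text"] + suffix}

    return _inner


# Binding arguments with partial keeps the callable picklable.
fn = pickle.loads(pickle.dumps(partial(add_suffix, suffix="!")))
assert fn({"text": "hi"}) == {"text": "hi!"}

# The closure fails at pickling time, which is exactly what broke
# IterableDataset transforms under spawn-based multiprocessing.
try:
    pickle.dumps(make_closure("!"))
except (pickle.PicklingError, AttributeError) as err:
    print(f"not picklable: {err}")
```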

File tree: 2 files changed (+64, -49)


src/datasets/iterable_dataset.py (32 additions, 45 deletions)
@@ -5,6 +5,7 @@
 from collections import Counter
 from copy import deepcopy
 from dataclasses import dataclass
+from functools import partial
 from itertools import cycle, islice
 from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
 

@@ -30,6 +31,31 @@
 Key = Union[int, str]
 
 
+def identity_func(x):
+    return x
+
+
+def _rename_columns_fn(example: Dict, column_mapping: Dict[str, str]):
+    if any(col not in example for col in column_mapping):
+        raise ValueError(
+            f"Error when renaming {list(column_mapping)} to {list(column_mapping.values())}: columns {set(column_mapping) - set(example)} are not in the dataset."
+        )
+    if any(col in example for col in column_mapping.values()):
+        raise ValueError(
+            f"Error when renaming {list(column_mapping)} to {list(column_mapping.values())}: columns {set(example) - set(column_mapping.values())} are already in the dataset."
+        )
+    return {
+        new_column_name: example[original_column_name]
+        for original_column_name, new_column_name in column_mapping.items()
+    }
+
+
+def add_column_fn(example: Dict, idx: int, name: str, column: List[Dict]):
+    if name in example:
+        raise ValueError(f"Error when adding {name}: column {name} is already in the dataset.")
+    return {name: column[idx]}
+
+
 def _infer_features_from_batch(batch: Dict[str, list], try_features: Optional[Features] = None) -> Features:
     pa_table = pa.Table.from_pydict(batch)
     if try_features is not None:

@@ -1626,7 +1652,7 @@ def map(
         if isinstance(remove_columns, str):
             remove_columns = [remove_columns]
         if function is None:
-            function = lambda x: x  # noqa: E731
+            function = identity_func
         if fn_kwargs is None:
            fn_kwargs = {}
         ex_iterable = MappedExamplesIterable(

@@ -1899,13 +1925,7 @@ def add_column(self, name: str, column: Union[list, np.array]) -> "IterableDatas
         Returns:
             `IterableDataset`
         """
-
-        def add_column_fn(example, idx):
-            if name in example:
-                raise ValueError(f"Error when adding {name}: column {name} is already in the dataset.")
-            return {name: column[idx]}
-
-        return self.map(add_column_fn, with_indices=True)
+        return self.map(partial(add_column_fn, name=name, column=column), with_indices=True)
 
     def rename_column(self, original_column_name: str, new_column_name: str) -> "IterableDataset":
         """

@@ -1935,28 +1955,7 @@ def rename_column(self, original_column_name: str, new_column_name: str) -> "Ite
         'movie_review': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'}
         ```
         """
-
-        def rename_column_fn(example):
-            if original_column_name not in example:
-                raise ValueError(
-                    f"Error when renaming {original_column_name} to {new_column_name}: column {original_column_name} is not in the dataset."
-                )
-            if new_column_name in example:
-                raise ValueError(
-                    f"Error when renaming {original_column_name} to {new_column_name}: column {new_column_name} is already in the dataset."
-                )
-            return {new_column_name: example[original_column_name]}
-
-        original_features = self._info.features.copy() if self._info.features else None
-        ds_iterable = self.map(rename_column_fn, remove_columns=[original_column_name])
-        if original_features is not None:
-            ds_iterable._info.features = Features(
-                {
-                    new_column_name if col == original_column_name else col: feature
-                    for col, feature in original_features.items()
-                }
-            )
-        return ds_iterable
+        return self.rename_columns({original_column_name: new_column_name})
 
     def rename_columns(self, column_mapping: Dict[str, str]) -> "IterableDataset":
         """

@@ -1970,22 +1969,10 @@ def rename_columns(self, column_mapping: Dict[str, str]) -> "IterableDataset":
             `IterableDataset`: A copy of the dataset with renamed columns
         """
 
-        def rename_columns_fn(example):
-            if any(col not in example for col in column_mapping):
-                raise ValueError(
-                    f"Error when renaming {list(column_mapping)} to {list(column_mapping.values())}: columns {set(column_mapping) - set(example)} are not in the dataset."
-                )
-            if any(col in example for col in column_mapping.values()):
-                raise ValueError(
-                    f"Error when renaming {list(column_mapping)} to {list(column_mapping.values())}: columns {set(example) - set(column_mapping.values())} are already in the dataset."
-                )
-            return {
-                new_column_name: example[original_column_name]
-                for original_column_name, new_column_name in column_mapping.items()
-            }
-
         original_features = self._info.features.copy() if self._info.features else None
-        ds_iterable = self.map(rename_columns_fn, remove_columns=list(column_mapping))
+        ds_iterable = self.map(
+            partial(_rename_columns_fn, column_mapping=column_mapping), remove_columns=list(column_mapping)
+        )
         if original_features is not None:
             ds_iterable._info.features = Features(
                 {
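Taken together, these changes mean the callables an IterableDataset stores in its transform pipeline are module-level functions or functools.partial objects over them, both of which pickle cleanly. A hedged usage sketch of what this enables (the dataset name squad is only an example, and this assumes a datasets release containing this fix):

```python
import pickle

from datasets import load_dataset

# Streaming mode returns an IterableDataset.
ds = load_dataset("squad", split="train", streaming=True)
ds = ds.rename_column("title", "article_title")  # now backed by partial(_rename_columns_fn, ...)

# Before this commit, the closure captured inside rename_column made this raise;
# it is the same serialization a spawned worker process performs.
restored = pickle.loads(pickle.dumps(ds))
print(next(iter(restored)).keys())
```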

tests/test_iterable_dataset.py (32 additions, 4 deletions)
@@ -1,3 +1,4 @@
+import pickle
 from copy import deepcopy
 from itertools import chain, islice
 

@@ -1900,17 +1901,17 @@ def add_one_numpy(example):
     assert isinstance(next(dataset.iter(batch_size=3))["id"], list)
 
 
-@pytest.mark.parametrize("n_shards1, nshards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)])
-def test_interleave_dataset_with_sharding(n_shards1, nshards2, num_workers):
+@pytest.mark.parametrize("n_shards1, n_shards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)])
+def test_interleave_dataset_with_sharding(n_shards1, n_shards2, num_workers):
     from torch.utils.data import DataLoader
 
     ex_iterable1 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-1.txt" for i in range(n_shards1)]})
     dataset1 = IterableDataset(ex_iterable1).with_format("torch")
-    ex_iterable2 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-2.txt" for i in range(nshards2)]})
+    ex_iterable2 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-2.txt" for i in range(n_shards2)]})
     dataset2 = IterableDataset(ex_iterable2).with_format("torch")
 
     dataset_merged = interleave_datasets([dataset1, dataset2], stopping_strategy="first_exhausted")
-    assert dataset_merged.n_shards == min(n_shards1, nshards2)
+    assert dataset_merged.n_shards == min(n_shards1, n_shards2)
     dataloader = DataLoader(dataset_merged, batch_size=None, num_workers=num_workers)
     result = list(dataloader)
     expected_length = 2 * min(

@@ -1919,3 +1920,30 @@ def test_interleave_dataset_with_sharding(n_shards1, nshards2, num_workers):
     # some samples may be missing because the stopping strategy is applied per process
     assert expected_length - num_workers <= len(result) <= expected_length
     assert len(result) == len({str(x) for x in result})
+
+
+def filter_func(batch):
+    return batch["id"] == 4
+
+
+def map_func(batch):
+    batch["id"] *= 2
+    return batch
+
+
+def test_pickle_after_many_transforms(dataset_with_several_columns):
+    dataset = dataset_with_several_columns
+    dataset = dataset.remove_columns(["filepath"])
+    dataset = dataset.take(5)
+    dataset = dataset.map(map_func)
+    dataset = dataset.shuffle()
+    dataset = dataset.skip(1)
+    dataset = dataset.filter(filter_func)
+    dataset = dataset.add_column("additional_col", ["something"])
+    dataset = dataset.rename_column("metadata", "metadata1")
+    dataset = dataset.rename_columns({"id": "id1", "metadata1": "metadata2"})
+    dataset = dataset.select_columns(["id1", "additional_col"])
+
+    unpickled_dataset = pickle.loads(pickle.dumps(dataset))
+
+    assert list(unpickled_dataset) == list(dataset)
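The new test checks picklability directly, since that is the property a spawn-based DataLoader depends on. An end-to-end sketch of that scenario, not part of the commit (dataset name and worker settings are illustrative, and torch is assumed to be installed):

```python
from torch.utils.data import DataLoader

from datasets import load_dataset


def main():
    ds = load_dataset("squad", split="train", streaming=True).take(8)
    ds = ds.rename_columns({"title": "article_title"})

    # With multiprocessing_context="spawn", each worker receives a pickled copy
    # of `ds`, so every transform in its pipeline must be picklable.
    loader = DataLoader(
        ds.with_format("torch"), batch_size=None, num_workers=2, multiprocessing_context="spawn"
    )
    for example in loader:
        pass  # iterates without the pickling error this commit fixes


if __name__ == "__main__":  # required when using the spawn start method
    main()
```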
