Skip to content

Commit 418b3bc

Browse files
authored
Add fn_kwargs param to IterableDataset.map (#4975)
1 parent 341b555 commit 418b3bc

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

src/datasets/iterable_dataset.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def __init__(
334334
batch_size: int = 1000,
335335
drop_last_batch: bool = False,
336336
remove_columns: Optional[List[str]] = None,
337+
fn_kwargs: Optional[dict] = None,
337338
):
338339
self.ex_iterable = ex_iterable
339340
self.function = function
@@ -343,6 +344,7 @@ def __init__(
343344
self.remove_columns = remove_columns
344345
self.with_indices = with_indices
345346
self.input_columns = input_columns
347+
self.fn_kwargs = fn_kwargs or {}
346348

347349
def __iter__(self):
348350
iterator = iter(self.ex_iterable)
@@ -363,7 +365,7 @@ def __iter__(self):
363365
if self.with_indices:
364366
function_args.append([current_idx + i for i in range(len(key_examples_list))])
365367
transformed_batch = dict(batch) # this will be updated with the function output
366-
transformed_batch.update(self.function(*function_args))
368+
transformed_batch.update(self.function(*function_args, **self.fn_kwargs))
367369
# then remove the unwanted columns
368370
if self.remove_columns:
369371
for c in self.remove_columns:
@@ -396,7 +398,7 @@ def __iter__(self):
396398
if self.with_indices:
397399
function_args.append(current_idx)
398400
transformed_example = dict(example) # this will be updated with the function output
399-
transformed_example.update(self.function(*function_args))
401+
transformed_example.update(self.function(*function_args, **self.fn_kwargs))
400402
# then we remove the unwanted columns
401403
if self.remove_columns:
402404
for c in self.remove_columns:
@@ -414,6 +416,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample
414416
batched=self.batched,
415417
batch_size=self.batch_size,
416418
remove_columns=self.remove_columns,
419+
fn_kwargs=self.fn_kwargs,
417420
)
418421

419422
def shard_data_sources(self, shard_idx: int) -> "MappedExamplesIterable":
@@ -426,6 +429,7 @@ def shard_data_sources(self, shard_idx: int) -> "MappedExamplesIterable":
426429
batched=self.batched,
427430
batch_size=self.batch_size,
428431
remove_columns=self.remove_columns,
432+
fn_kwargs=self.fn_kwargs,
429433
)
430434

431435
@property
@@ -759,6 +763,7 @@ def map(
759763
batch_size: int = 1000,
760764
drop_last_batch: bool = False,
761765
remove_columns: Optional[Union[str, List[str]]] = None,
766+
fn_kwargs: Optional[dict] = None,
762767
) -> "IterableDataset":
763768
"""
764769
Apply a function to all the examples in the iterable dataset (individually or in batches) and update them.
@@ -797,6 +802,7 @@ def map(
797802
remove_columns (`Optional[List[str]]`, defaults to `None`): Remove a selection of columns while doing the mapping.
798803
Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding
799804
columns with names in `remove_columns`, these columns will be kept.
805+
fn_kwargs (:obj:`Dict`, optional, default `None`): Keyword arguments to be passed to `function`.
800806
801807
Example:
802808
@@ -821,6 +827,8 @@ def map(
821827
remove_columns = [remove_columns]
822828
if function is None:
823829
function = lambda x: x # noqa: E731
830+
if fn_kwargs is None:
831+
fn_kwargs = {}
824832
info = self._info.copy()
825833
info.features = None
826834
ex_iterable = MappedExamplesIterable(
@@ -834,6 +842,7 @@ def map(
834842
batch_size=batch_size,
835843
drop_last_batch=drop_last_batch,
836844
remove_columns=remove_columns,
845+
fn_kwargs=fn_kwargs,
837846
)
838847
return iterable_dataset(
839848
ex_iterable=ex_iterable,

tests/test_iterable_dataset.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,39 @@ def test_mapped_examples_iterable_remove_columns(n, func, batch_size, remove_col
328328
assert list(x for _, x in ex_iterable) == expected
329329

330330

331+
@pytest.mark.parametrize(
332+
"n, func, batch_size, fn_kwargs",
333+
[
334+
(3, lambda x, y=0: {"id+y": x["id"] + y}, None, None),
335+
(3, lambda x, y=0: {"id+y": x["id"] + y}, None, {"y": 3}),
336+
(25, lambda x, y=0: {"id+y": [i + y for i in x["id"]]}, 10, {"y": 3}),
337+
],
338+
)
339+
def test_mapped_examples_iterable_fn_kwargs(n, func, batch_size, fn_kwargs):
340+
base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
341+
ex_iterable = MappedExamplesIterable(
342+
base_ex_iterable, func, batched=batch_size is not None, batch_size=batch_size, fn_kwargs=fn_kwargs
343+
)
344+
all_examples = [x for _, x in generate_examples_fn(n=n)]
345+
if fn_kwargs is None:
346+
fn_kwargs = {}
347+
if batch_size is None:
348+
expected = [{**x, **func(x, **fn_kwargs)} for x in all_examples]
349+
else:
350+
# For batched map we have to format the examples as a batch (i.e. in one single dictionary) to pass the batch to the function
351+
all_transformed_examples = []
352+
for batch_offset in range(0, len(all_examples), batch_size):
353+
examples = all_examples[batch_offset : batch_offset + batch_size]
354+
batch = _examples_to_batch(examples)
355+
transformed_batch = func(batch, **fn_kwargs)
356+
all_transformed_examples.extend(_batch_to_examples(transformed_batch))
357+
expected = _examples_to_batch(all_examples)
358+
expected.update(_examples_to_batch(all_transformed_examples))
359+
expected = list(_batch_to_examples(expected))
360+
assert next(iter(ex_iterable))[1] == expected[0]
361+
assert list(x for _, x in ex_iterable) == expected
362+
363+
331364
@pytest.mark.parametrize(
332365
"n, func, batch_size, input_columns",
333366
[

0 commit comments

Comments
 (0)