@@ -334,6 +334,7 @@ def __init__(
334334 batch_size : int = 1000 ,
335335 drop_last_batch : bool = False ,
336336 remove_columns : Optional [List [str ]] = None ,
337+ fn_kwargs : Optional [dict ] = None ,
337338 ):
338339 self .ex_iterable = ex_iterable
339340 self .function = function
@@ -343,6 +344,7 @@ def __init__(
343344 self .remove_columns = remove_columns
344345 self .with_indices = with_indices
345346 self .input_columns = input_columns
347+ self .fn_kwargs = fn_kwargs or {}
346348
347349 def __iter__ (self ):
348350 iterator = iter (self .ex_iterable )
@@ -363,7 +365,7 @@ def __iter__(self):
363365 if self .with_indices :
364366 function_args .append ([current_idx + i for i in range (len (key_examples_list ))])
365367 transformed_batch = dict (batch ) # this will be updated with the function output
366- transformed_batch .update (self .function (* function_args ))
368+ transformed_batch .update (self .function (* function_args , ** self . fn_kwargs ))
367369 # then remove the unwanted columns
368370 if self .remove_columns :
369371 for c in self .remove_columns :
@@ -396,7 +398,7 @@ def __iter__(self):
396398 if self .with_indices :
397399 function_args .append (current_idx )
398400 transformed_example = dict (example ) # this will be updated with the function output
399- transformed_example .update (self .function (* function_args ))
401+ transformed_example .update (self .function (* function_args , ** self . fn_kwargs ))
400402 # then we remove the unwanted columns
401403 if self .remove_columns :
402404 for c in self .remove_columns :
@@ -414,6 +416,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample
414416 batched = self .batched ,
415417 batch_size = self .batch_size ,
416418 remove_columns = self .remove_columns ,
419+ fn_kwargs = self .fn_kwargs ,
417420 )
418421
419422 def shard_data_sources (self , shard_idx : int ) -> "MappedExamplesIterable" :
@@ -426,6 +429,7 @@ def shard_data_sources(self, shard_idx: int) -> "MappedExamplesIterable":
426429 batched = self .batched ,
427430 batch_size = self .batch_size ,
428431 remove_columns = self .remove_columns ,
432+ fn_kwargs = self .fn_kwargs ,
429433 )
430434
431435 @property
@@ -759,6 +763,7 @@ def map(
759763 batch_size : int = 1000 ,
760764 drop_last_batch : bool = False ,
761765 remove_columns : Optional [Union [str , List [str ]]] = None ,
766+ fn_kwargs : Optional [dict ] = None ,
762767 ) -> "IterableDataset" :
763768 """
764769 Apply a function to all the examples in the iterable dataset (individually or in batches) and update them.
@@ -797,6 +802,7 @@ def map(
797802 remove_columns (`Optional[List[str]]`, defaults to `None`): Remove a selection of columns while doing the mapping.
798803 Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding
799804 columns with names in `remove_columns`, these columns will be kept.
805+ fn_kwargs (`Optional[Dict]`, defaults to `None`): Keyword arguments to be passed to `function`.
800806
801807 Example:
802808
@@ -821,6 +827,8 @@ def map(
821827 remove_columns = [remove_columns ]
822828 if function is None :
823829 function = lambda x : x # noqa: E731
830+ if fn_kwargs is None :
831+ fn_kwargs = {}
824832 info = self ._info .copy ()
825833 info .features = None
826834 ex_iterable = MappedExamplesIterable (
@@ -834,6 +842,7 @@ def map(
834842 batch_size = batch_size ,
835843 drop_last_batch = drop_last_batch ,
836844 remove_columns = remove_columns ,
845+ fn_kwargs = fn_kwargs ,
837846 )
838847 return iterable_dataset (
839848 ex_iterable = ex_iterable ,
0 commit comments