@@ -14,7 +14,7 @@ class ParallelBackendConfig:
1414
1515
1616@experimental
17- def parallel_map (function , iterable , num_proc , types , disable_tqdm , desc , single_map_nested_func ):
17+ def parallel_map (function , iterable , num_proc , batched , batch_size , types , disable_tqdm , desc , single_map_nested_func ):
1818 """
1919 **Experimental.** Apply a function to iterable elements in parallel, where the implementation uses either
2020 multiprocessing.Pool or joblib for parallelization.
@@ -32,21 +32,25 @@ def parallel_map(function, iterable, num_proc, types, disable_tqdm, desc, single
3232 """
3333 if ParallelBackendConfig .backend_name is None :
3434 return _map_with_multiprocessing_pool (
35- function , iterable , num_proc , types , disable_tqdm , desc , single_map_nested_func
35+ function , iterable , num_proc , batched , batch_size , types , disable_tqdm , desc , single_map_nested_func
3636 )
3737
38- return _map_with_joblib (function , iterable , num_proc , types , disable_tqdm , desc , single_map_nested_func )
38+ return _map_with_joblib (
39+ function , iterable , num_proc , batched , batch_size , types , disable_tqdm , desc , single_map_nested_func
40+ )
3941
4042
41- def _map_with_multiprocessing_pool (function , iterable , num_proc , types , disable_tqdm , desc , single_map_nested_func ):
43+ def _map_with_multiprocessing_pool (
44+ function , iterable , num_proc , batched , batch_size , types , disable_tqdm , desc , single_map_nested_func
45+ ):
4246 num_proc = num_proc if num_proc <= len (iterable ) else len (iterable )
4347 split_kwds = [] # We organize the splits ourselves (contiguous splits)
4448 for index in range (num_proc ):
4549 div = len (iterable ) // num_proc
4650 mod = len (iterable ) % num_proc
4751 start = div * index + min (index , mod )
4852 end = start + div + (1 if index < mod else 0 )
49- split_kwds .append ((function , iterable [start :end ], types , index , disable_tqdm , desc ))
53+ split_kwds .append ((function , iterable [start :end ], batched , batch_size , types , index , disable_tqdm , desc ))
5054
5155 if len (iterable ) != sum (len (i [1 ]) for i in split_kwds ):
5256 raise ValueError (
@@ -70,14 +74,17 @@ def _map_with_multiprocessing_pool(function, iterable, num_proc, types, disable_
7074 return mapped
7175
7276
73- def _map_with_joblib (function , iterable , num_proc , types , disable_tqdm , desc , single_map_nested_func ):
77+ def _map_with_joblib (
78+ function , iterable , num_proc , batched , batch_size , types , disable_tqdm , desc , single_map_nested_func
79+ ):
7480 # progress bar is not yet supported for _map_with_joblib, because tqdm cannot be applied accurately to joblib;
7581 # doing so requires monkey-patching joblib internal classes, which are subject to change
7682 import joblib
7783
7884 with joblib .parallel_backend (ParallelBackendConfig .backend_name , n_jobs = num_proc ):
7985 return joblib .Parallel ()(
80- joblib .delayed (single_map_nested_func )((function , obj , types , None , True , None )) for obj in iterable
86+ joblib .delayed (single_map_nested_func )((function , obj , batched , batch_size , types , None , True , None ))
87+ for obj in iterable
8188 )
8289
8390
0 commit comments