diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 3d296823efe..f3826c93ad3 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1653,6 +1653,12 @@ def map( if fn_kwargs is None: fn_kwargs = {} + if num_proc is not None and num_proc > len(self): + num_proc = len(self) + logger.warning( + f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}." + ) + if num_proc is None or num_proc == 1: return self._map_single( function=function, @@ -1673,11 +1679,6 @@ def map( desc=desc, ) else: - if num_proc > len(self): - num_proc = len(self) - logger.warning( - f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}." - ) def format_cache_file_name(cache_file_name, rank): sep = cache_file_name.rindex(".") @@ -3256,7 +3257,11 @@ def concatenate_datasets( logger.info("Some of the datasets have disparate format. Resetting the format of the concatenated dataset.") # Concatenate tables - table = concat_tables([dset._data for dset in dsets if len(dset._data) > 0], axis=axis) + tables_to_concat = [dset._data for dset in dsets if len(dset._data) > 0] + # There might be no table with data left hence return first empty table + if not tables_to_concat: + return dsets[0] + table = concat_tables(tables_to_concat, axis=axis) if axis == 1: table = update_metadata_with_features(table, None)