src/datasets/arrow_dataset.py (17 changes: 11 additions & 6 deletions)

@@ -1653,6 +1653,12 @@ def map(
         if fn_kwargs is None:
             fn_kwargs = {}
 
+        if num_proc is not None and num_proc > len(self):
+            num_proc = len(self)
+            logger.warning(
+                f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}."
+            )
+
Review comment on lines +1656 to +1661:

    Good catch to avoid multiprocessing with only 1 process, thanks.
         if num_proc is None or num_proc == 1:
             return self._map_single(
                 function=function,

@@ -1673,11 +1679,6 @@ def map(
                 desc=desc,
             )
         else:
-            if num_proc > len(self):
-                num_proc = len(self)
-                logger.warning(
-                    f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}."
-                )
 
             def format_cache_file_name(cache_file_name, rank):
                 sep = cache_file_name.rindex(".")
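
Effect of moving this check: the cap on num_proc now runs before the single-process branch, so a num_proc that collapses to 1 takes the _map_single path instead of spawning a multiprocessing pool. A minimal sketch of the reordered control flow (map_dispatch is a hypothetical stand-in for Dataset.map, not the library's actual code):

    import logging

    logger = logging.getLogger(__name__)

    def map_dispatch(dataset_len, num_proc=None):
        # Moved check: cap num_proc at the dataset size first.
        if num_proc is not None and num_proc > dataset_len:
            num_proc = dataset_len
            logger.warning(
                f"num_proc must be <= {dataset_len}. Reducing num_proc to {num_proc} "
                f"for dataset of size {dataset_len}."
            )
        # The cap above may have reduced num_proc to 1, so this branch
        # now also catches that case and avoids multiprocessing entirely.
        if num_proc is None or num_proc == 1:
            return "single-process path (_map_single)"
        return f"multiprocessing pool with {num_proc} workers"

    # A 1-row dataset mapped with num_proc=8 is capped to 1 and
    # takes the single-process path.
    assert map_dispatch(1, num_proc=8) == "single-process path (_map_single)"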
@@ -3256,7 +3257,11 @@ def concatenate_datasets(
         logger.info("Some of the datasets have disparate format. Resetting the format of the concatenated dataset.")
 
     # Concatenate tables
-    table = concat_tables([dset._data for dset in dsets if len(dset._data) > 0], axis=axis)
+    tables_to_concat = [dset._data for dset in dsets if len(dset._data) > 0]
+    # There might be no table with data left, hence return the first (empty) table
+    if not tables_to_concat:
+        return dsets[0]
+    table = concat_tables(tables_to_concat, axis=axis)
     if axis == 1:
         table = update_metadata_with_features(table, None)
 
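With this guard, concatenating datasets that are all empty returns the first (empty) dataset instead of calling concat_tables on an empty list. A hedged usage sketch with the public datasets API, assuming the post-fix behavior shown in this diff:

    from datasets import Dataset, concatenate_datasets

    empty_a = Dataset.from_dict({"x": []})
    empty_b = Dataset.from_dict({"x": []})

    # Both zero-row tables are filtered out by len(dset._data) > 0, so the
    # new early return hands back the first empty dataset unchanged.
    combined = concatenate_datasets([empty_a, empty_b])
    assert len(combined) == 0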