10 changes: 8 additions & 2 deletions src/datasets/arrow_dataset.py
@@ -1509,6 +1509,7 @@ def map(
num_proc: Optional[int] = None,
suffix_template: str = "_{rank:05d}_of_{num_proc:05d}",
new_fingerprint: Optional[str] = None,
+ desc: Optional[str] = None,
) -> "Dataset":
"""Apply a function to all the elements in the table (individually or in batches)
and update the table (if function does update examples).
@@ -1554,6 +1555,7 @@ def map(
rank=1 and num_proc=4, the resulting file would be "processed_00001_of_00004.arrow" for the default suffix.
new_fingerprint (`Optional[str]`, default `None`): the new fingerprint of the dataset after transform.
If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.
+ desc (`Optional[str]`, defaults to `None`): Meaningful description to be displayed alongside the progress bar while mapping examples.
"""
assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0."

@@ -1598,6 +1600,7 @@ def map(
disable_nullable=disable_nullable,
fn_kwargs=fn_kwargs,
new_fingerprint=new_fingerprint,
+ desc=desc,
)
else:

@@ -1649,6 +1652,7 @@ def format_cache_file_name(cache_file_name, rank):
fn_kwargs=fn_kwargs,
rank=rank,
offset=sum(len(s) for s in shards[:rank]),
+ desc=desc,
)
for rank in range(num_proc)
]
@@ -1662,7 +1666,7 @@ def format_cache_file_name(cache_file_name, rank):
return result

@transmit_format
@fingerprint_transform(inplace=False, ignore_kwargs=["load_from_cache_file", "cache_file_name"])
@fingerprint_transform(inplace=False, ignore_kwargs=["load_from_cache_file", "cache_file_name", "desc"])
def _map_single(
self,
function: Optional[Callable] = None,
@@ -1682,6 +1686,7 @@ def _map_single(
new_fingerprint: Optional[str] = None,
rank: Optional[int] = None,
offset: int = 0,
+ desc: Optional[str] = None,
) -> "Dataset":
"""Apply a function to all the elements in the table (individually or in batches)
and update the table (if function does update examples).
@@ -1720,6 +1725,7 @@ def _map_single(
If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments
rank: (`Optional[int]`, defaults to `None`): If specified, this is the process rank when doing multiprocessing
offset: (:obj:`int`, defaults to 0): If specified, this is an offset applied to the indices passed to `function` if `with_indices=True`
+ desc (`Optional[str]`, defaults to `None`): Meaningful description to be displayed alongside the progress bar while mapping examples.
"""
assert (
not keep_in_memory or cache_file_name is None
@@ -1895,7 +1901,7 @@ def init_buffer_and_writer():
# Loop over single examples or batches and write to buffer/file if examples are to be updated
pbar_iterable = input_dataset if not batched else range(0, len(input_dataset), batch_size)
pbar_unit = "ex" if not batched else "ba"
pbar_desc = "#" + str(rank) if rank is not None else None
pbar_desc = (desc or "") + " #" + str(rank) if rank is not None else desc
pbar = tqdm(pbar_iterable, disable=not_verbose, position=rank, unit=pbar_unit, desc=pbar_desc)
if not batched:
for i, example in enumerate(pbar):
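
For context, a minimal usage sketch of the new `desc` argument (the dataset and mapped functions below are illustrative, not part of this PR):

```python
from datasets import Dataset

# Toy dataset; contents are only for illustration.
ds = Dataset.from_dict({"text": ["foo", "bar", "baz"]})

# Single process: the progress bar is labeled with the given description.
ds = ds.map(lambda ex: {"length": len(ex["text"])}, desc="Computing lengths")

# Multiprocessing: each shard's bar shows the description plus its rank,
# e.g. "Computing lengths #0", "Computing lengths #1" (see pbar_desc in the diff above).
ds = ds.map(lambda ex: {"upper": ex["text"].upper()}, num_proc=2, desc="Computing lengths")
```

Note that `desc` is added to `ignore_kwargs` of `fingerprint_transform`, so changing only the description does not change the dataset fingerprint and previously cached results are still reused.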