diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 10dc59ecb5a..2d0d3b235a7 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1509,6 +1509,7 @@ def map( num_proc: Optional[int] = None, suffix_template: str = "_{rank:05d}_of_{num_proc:05d}", new_fingerprint: Optional[str] = None, + desc: Optional[str] = None, ) -> "Dataset": """Apply a function to all the elements in the table (individually or in batches) and update the table (if function does update examples). @@ -1554,6 +1555,7 @@ def map( rank=1 and num_proc=4, the resulting file would be "processed_00001_of_00004.arrow" for the default suffix. new_fingerprint (`Optional[str]`, default `None`): the new fingerprint of the dataset after transform. If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments. + desc (`Optional[str]`, defaults to `None`): Meaningful description to be displayed alongside the progress bar while mapping examples. """ assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0." 
@@ -1598,6 +1600,7 @@ def map( disable_nullable=disable_nullable, fn_kwargs=fn_kwargs, new_fingerprint=new_fingerprint, + desc=desc, ) else: @@ -1649,6 +1652,7 @@ def format_cache_file_name(cache_file_name, rank): fn_kwargs=fn_kwargs, rank=rank, offset=sum(len(s) for s in shards[:rank]), + desc=desc, ) for rank in range(num_proc) ] @@ -1662,7 +1666,7 @@ def format_cache_file_name(cache_file_name, rank): return result @transmit_format - @fingerprint_transform(inplace=False, ignore_kwargs=["load_from_cache_file", "cache_file_name"]) + @fingerprint_transform(inplace=False, ignore_kwargs=["load_from_cache_file", "cache_file_name", "desc"]) def _map_single( self, function: Optional[Callable] = None, @@ -1682,6 +1686,7 @@ def _map_single( new_fingerprint: Optional[str] = None, rank: Optional[int] = None, offset: int = 0, + desc: Optional[str] = None, ) -> "Dataset": """Apply a function to all the elements in the table (individually or in batches) and update the table (if function does update examples). @@ -1720,6 +1725,7 @@ def _map_single( If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments rank: (`Optional[int]`, defaults to `None`): If specified, this is the process rank when doing multiprocessing offset: (:obj:`int`, defaults to 0): If specified, this is an offset applied to the indices passed to `function` if `with_indices=True` + desc (`Optional[str]`, defaults to `None`): Meaningful description to be displayed alongside the progress bar while mapping examples. 
""" assert ( not keep_in_memory or cache_file_name is None @@ -1895,7 +1901,7 @@ def init_buffer_and_writer(): # Loop over single examples or batches and write to buffer/file if examples are to be updated pbar_iterable = input_dataset if not batched else range(0, len(input_dataset), batch_size) pbar_unit = "ex" if not batched else "ba" - pbar_desc = "#" + str(rank) if rank is not None else None + pbar_desc = (desc or "") + " #" + str(rank) if rank is not None else desc pbar = tqdm(pbar_iterable, disable=not_verbose, position=rank, unit=pbar_unit, desc=pbar_desc) if not batched: for i, example in enumerate(pbar):