24 changes: 24 additions & 0 deletions docs/source/processing.rst
@@ -670,3 +670,27 @@ To disable caching you can run:
>>> set_caching_enabled(False)

You can also query the current status of the caching with :func:`datasets.is_caching_enabled`:
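
.. code-block::

    >>> from datasets import is_caching_enabled
    >>> is_caching_enabled()  # False at this point, since caching was disabled just above
    False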

Mapping in a distributed setting
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In a distributed setting, you can use caching together with a :func:`torch.distributed.barrier` to make sure that only the main process performs the mapping, while the other processes load its results. This avoids duplicating the work across all processes, or worse, requesting more CPUs than your system can handle. For example:

.. code-block::

    >>> from datasets import Dataset
    >>> import torch.distributed
    >>>
    >>> dataset1 = Dataset.from_dict({"a": [0, 1, 2]})
    >>>
    >>> # training_args.local_rank is assumed to be already defined, e.g. by
    >>> # transformers.TrainingArguments, and the torch.distributed process group
    >>> # must be initialized before barrier() can be called
    >>> if training_args.local_rank > 0:
    ...     print("Waiting for main process to perform the mapping")
    ...     torch.distributed.barrier()
    >>>
    >>> dataset2 = dataset1.map(lambda x: {"a": x["a"] + 1})
    >>>
    >>> if training_args.local_rank == 0:
    ...     print("Loading results from main process")
    ...     torch.distributed.barrier()

When a process encounters a barrier, it stops until all the other processes have reached it. The non-main processes reach the barrier first, before the mapping, and wait there. The main process creates the cache for the processed dataset and then reaches the barrier, at which point the other processes resume and load the cache instead of performing the processing themselves.
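
If you are not using ``transformers.TrainingArguments``, the same pattern can be written with the local rank read from the environment. The sketch below shows only one possible setup: it assumes the script is launched with ``torchrun`` (which sets the ``LOCAL_RANK``, ``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` environment variables) and that the process group is initialized with the ``gloo`` backend:

.. code-block::

    >>> import os
    >>> import torch.distributed
    >>> from datasets import Dataset
    >>>
    >>> # torchrun provides LOCAL_RANK (as well as RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT)
    >>> local_rank = int(os.environ["LOCAL_RANK"])
    >>> torch.distributed.init_process_group(backend="gloo")
    >>>
    >>> dataset1 = Dataset.from_dict({"a": [0, 1, 2]})
    >>>
    >>> if local_rank > 0:
    ...     # non-main processes wait here until the main process has written the cache
    ...     torch.distributed.barrier()
    >>>
    >>> dataset2 = dataset1.map(lambda x: {"a": x["a"] + 1})
    >>>
    >>> if local_rank == 0:
    ...     # the main process reaches the barrier last, releasing the waiting processes
    ...     torch.distributed.barrier()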