Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 28 additions & 30 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_group(id=0):
Get group instance by group id.

Args:
id (int): the group id
id (int): the group id. Default value is 0.

Returns:
Group: the group instance.
Expand All @@ -163,26 +163,24 @@ def get_group(id=0):
def new_group(ranks=None, backend=None):
"""

Creates a new distributed comminication group.
Creates a new distributed communication group.

Args:
ranks (list): The global ranks of group members, list as sorted.
ranks (list): The global ranks of group members.
backend (str): The backend used to create group, only nccl is supported now.

Returns:
Group: The group instance. Nerver return None.
Group: The group instance.

Examples:
.. code-block:: python

import numpy as np
import paddle

paddle.distributed.init_parallel_env()
tindata = np.random.random([10, 1000]).astype('float32')
tindata = paddle.to_tensor(tindata)
gid = paddle.distributed.new_group([2,4,6])
paddle.distributed.all_reduce(tindata, group=gid, use_calc_stream=False)
tindata = paddle.randn(shape=[2, 3])
gp = paddle.distributed.new_group([2,4,6])
paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)

"""

Expand Down Expand Up @@ -221,7 +219,7 @@ def new_group(ranks=None, backend=None):
place = core.CUDAPlace(genv.device_id)
core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id)
else:
assert False
assert False, ("no cuda device found")

return gp

Expand All @@ -234,22 +232,19 @@ def wait(tensor, group=None, use_calc_stream=True):
Args:
tensor (Tensor): The Tensor used before sync.
group (Group): The Group instance to perform sync.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
default to False.
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
Defaults to True.

Returns:
None.

Examples:
.. code-block:: python


import numpy as np
import paddle

paddle.distributed.init_parallel_env()
tindata = np.random.random([10, 1000]).astype('float32')
tindata = paddle.to_tensor(tindata)
tindata = paddle.randn(shape=[2, 3])
paddle.distributed.all_reduce(tindata, use_calc_stream=True)
paddle.distributed.wait(tindata)

Expand Down Expand Up @@ -306,8 +301,8 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
should be float16, float32, float64, int32 or int64.
src (int): The source rank.
group (Group): The group instance returned by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
default to True.
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
Defaults to True.

Returns:
None.
Expand Down Expand Up @@ -339,6 +334,7 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):

ring_id = 0 if group is None else group.id
gsrc = src if group is None else group.get_group_rank(src)
assert gsrc >= 0, ("src rank out of group, need global rank")

if in_dygraph_mode():
return core.ops.c_broadcast(tensor, tensor, 'root', gsrc,
Expand Down Expand Up @@ -370,10 +366,10 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
Args:
tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
should be float16, float32, float64, int32 or int64.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
group (Group): The group instance returned by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
default to True.
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
Defaults to True.

Returns:
None.
Expand Down Expand Up @@ -453,10 +449,10 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
should be float16, float32, float64, int32 or int64.
dst (int): The destination rank id.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
group (Group): The group instance returned by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
default to True.
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
Defaults to True.

Returns:
None.
Expand Down Expand Up @@ -487,6 +483,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):

ring_id = 0 if group is None else group.id
gdst = dst if group is None else group.get_group_rank(dst)
assert gdst >= 0, ("dst rank out of group, need global rank")

if in_dygraph_mode():
if op == ReduceOp.SUM:
Expand Down Expand Up @@ -548,8 +545,8 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32 or int64.
group (Group): The group instance returned by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
default to True.
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
Defaults to True.

Returns:
None.
Expand Down Expand Up @@ -624,11 +621,11 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
tensor (Tensor): The output Tensor. Its data type
should be float16, float32, float64, int32 or int64.
tensor_list (list): A list of Tensors to scatter. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
src (int): The source rank id.
should be float16, float32, float64, int32 or int64. Default value is None.
src (int): The source rank id. Default value is 0.
group (Group): The group instance returned by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
default to True.
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
Defaults to True.

Returns:
None.
Expand Down Expand Up @@ -664,6 +661,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):

ring_id = 0 if group is None else group.id
gsrc = src if group is None else group.get_group_rank(src)
assert gsrc >= 0, ("src rank out of group, need global rank")
rank = _get_global_group().rank if group is None else group.rank
nranks = _get_global_group().nranks if group is None else group.nranks

Expand Down