@@ -142,7 +142,7 @@ def get_group(id=0):
142142 Get group instance by group id.
143143
144144 Args:
145- id (int): the group id
145+ id (int): the group id. Default value is 0.
146146
147147 Returns:
148148 Group: the group instance.
@@ -163,26 +163,24 @@ def get_group(id=0):
163163def new_group (ranks = None , backend = None ):
164164 """
165165
166- Creates a new distributed comminication group.
166+ Creates a new distributed communication group.
167167
168168 Args:
169- ranks (list): The global ranks of group members, list as sorted .
169+ ranks (list): The global ranks of group members.
170170 backend (str): The backend used to create group, only nccl is supported now.
171171
172172 Returns:
173- Group: The group instance. Nerver return None.
173+ Group: The group instance.
174174
175175 Examples:
176176 .. code-block:: python
177177
178- import numpy as np
179178 import paddle
180179
181180 paddle.distributed.init_parallel_env()
182- tindata = np.random.random([10, 1000]).astype('float32')
183- tindata = paddle.to_tensor(tindata)
184- gid = paddle.distributed.new_group([2,4,6])
185- paddle.distributed.all_reduce(tindata, group=gid, use_calc_stream=False)
181+ tindata = paddle.randn(shape=[2, 3])
182+ gp = paddle.distributed.new_group([2,4,6])
183+ paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)
186184
187185 """
188186
@@ -221,7 +219,7 @@ def new_group(ranks=None, backend=None):
221219 place = core .CUDAPlace (genv .device_id )
222220 core .NCCLParallelContext (strategy , place ).init_with_ring_id (ring_id )
223221 else :
224- assert False
222+ assert False , ( "no cuda device found" )
225223
226224 return gp
227225
@@ -234,22 +232,19 @@ def wait(tensor, group=None, use_calc_stream=True):
234232 Args:
235233 tensor (Tensor): The Tensor used before sync.
236234 group (Group): The Group instance to perform sync.
237- use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
238- default to False .
235+ use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
236+ Default to True .
239237
240238 Returns:
241239 None.
242240
243241 Examples:
244242 .. code-block:: python
245243
246-
247- import numpy as np
248244 import paddle
249245
250246 paddle.distributed.init_parallel_env()
251- tindata = np.random.random([10, 1000]).astype('float32')
252- tindata = paddle.to_tensor(tindata)
247+ tindata = paddle.randn(shape=[2, 3])
253248 paddle.distributed.all_reduce(tindata, use_calc_stream=True)
254249 paddle.distributed.wait(tindata)
255250
@@ -306,8 +301,8 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
306301 should be float16, float32, float64, int32 or int64.
307302 src (int): The source rank.
308303 group (Group): The group instance return by new_group or None for global default group.
309- use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
310- default to True.
304+ use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
305+ Default to True.
311306
312307 Returns:
313308 None.
@@ -339,6 +334,7 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
339334
340335 ring_id = 0 if group is None else group .id
341336 gsrc = src if group is None else group .get_group_rank (src )
337+ assert gsrc >= 0 , ("src rank out of group, need global rank" )
342338
343339 if in_dygraph_mode ():
344340 return core .ops .c_broadcast (tensor , tensor , 'root' , gsrc ,
@@ -370,10 +366,10 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
370366 Args:
371367 tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
372368 should be float16, float32, float64, int32 or int64.
373- op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used.
369+ op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
374370 group (Group): The group instance return by new_group or None for global default group.
375- use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
376- default to True.
371+ use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
372+ Default to True.
377373
378374 Returns:
379375 None.
@@ -453,10 +449,10 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
453449 tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
454450 should be float16, float32, float64, int32 or int64.
455451 dst (int): The destination rank id.
456- op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used.
452+ op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
457453 group (Group): The group instance return by new_group or None for global default group.
458- use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
459- default to True.
454+ use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
455+ Default to True.
460456
461457 Returns:
462458 None.
@@ -487,6 +483,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
487483
488484 ring_id = 0 if group is None else group .id
489485 gdst = dst if group is None else group .get_group_rank (dst )
486+ assert gdst >= 0 , ("dst rank out of group, need global rank" )
490487
491488 if in_dygraph_mode ():
492489 if op == ReduceOp .SUM :
@@ -548,8 +545,8 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
548545 tensor (Tensor): The Tensor to send. Its data type
549546 should be float16, float32, float64, int32 or int64.
550547 group (Group): The group instance return by new_group or None for global default group.
551- use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
552- default to True.
548+ use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
549+ Default to True.
553550
554551 Returns:
555552 None.
@@ -624,11 +621,11 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
624621 tensor (Tensor): The output Tensor. Its data type
625622 should be float16, float32, float64, int32 or int64.
626623 tensor_list (list): A list of Tensors to scatter. Every element in the list must be a Tensor whose data type
627- should be float16, float32, float64, int32 or int64.
628- src (int): The source rank id.
624+ should be float16, float32, float64, int32 or int64. Default value is None.
625+ src (int): The source rank id. Default value is 0.
629626 group (Group): The group instance return by new_group or None for global default group.
630- use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
631- default to True.
627+ use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
628+ Default to True.
632629
633630 Returns:
634631 None.
@@ -664,6 +661,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
664661
665662 ring_id = 0 if group is None else group .id
666663 gsrc = src if group is None else group .get_group_rank (src )
664+ assert gsrc >= 0 , ("src rank out of group, need global rank" )
667665 rank = _get_global_group ().rank if group is None else group .rank
668666 nranks = _get_global_group ().nranks if group is None else group .nranks
669667
0 commit comments