Skip to content

Commit acc993a

Browse files
committed
fix doc problem
1 parent 4acc87b commit acc993a

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

python/paddle/distributed/collective.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def get_group(id=0):
142142
Get group instance by group id.
143143
144144
Args:
145-
id (int): the group id
145+
id (int): the group id. Default value is 0.
146146
147147
Returns:
148148
Group: the group instance.
@@ -163,26 +163,24 @@ def get_group(id=0):
163163
def new_group(ranks=None, backend=None):
164164
"""
165165
166-
Creates a new distributed comminication group.
166+
Creates a new distributed communication group.
167167
168168
Args:
169-
ranks (list): The global ranks of group members, list as sorted.
169+
ranks (list): The global ranks of group members.
170170
backend (str): The backend used to create group, only nccl is supported now.
171171
172172
Returns:
173-
Group: The group instance. Nerver return None.
173+
Group: The group instance.
174174
175175
Examples:
176176
.. code-block:: python
177177
178-
import numpy as np
179178
import paddle
180179
181180
paddle.distributed.init_parallel_env()
182-
tindata = np.random.random([10, 1000]).astype('float32')
183-
tindata = paddle.to_tensor(tindata)
184-
gid = paddle.distributed.new_group([2,4,6])
185-
paddle.distributed.all_reduce(tindata, group=gid, use_calc_stream=False)
181+
tindata = paddle.randn(shape=[2, 3])
182+
gp = paddle.distributed.new_group([2,4,6])
183+
paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)
186184
187185
"""
188186

@@ -221,7 +219,7 @@ def new_group(ranks=None, backend=None):
221219
place = core.CUDAPlace(genv.device_id)
222220
core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id)
223221
else:
224-
assert False
222+
assert False, ("no cuda device found")
225223

226224
return gp
227225

@@ -234,22 +232,19 @@ def wait(tensor, group=None, use_calc_stream=True):
234232
Args:
235233
tensor (Tensor): The Tensor used before sync.
236234
group (Group): The Group instance to perform sync.
237-
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
238-
default to False.
235+
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
236+
Default to True.
239237
240238
Returns:
241239
None.
242240
243241
Examples:
244242
.. code-block:: python
245243
246-
247-
import numpy as np
248244
import paddle
249245
250246
paddle.distributed.init_parallel_env()
251-
tindata = np.random.random([10, 1000]).astype('float32')
252-
tindata = paddle.to_tensor(tindata)
247+
tindata = paddle.randn(shape=[2, 3])
253248
paddle.distributed.all_reduce(tindata, use_calc_stream=True)
254249
paddle.distributed.wait(tindata)
255250
@@ -306,8 +301,8 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
306301
should be float16, float32, float64, int32 or int64.
307302
src (int): The source rank.
308303
group (Group): The group instance return by new_group or None for global default group.
309-
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
310-
default to True.
304+
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
305+
Default to True.
311306
312307
Returns:
313308
None.
@@ -339,6 +334,7 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
339334

340335
ring_id = 0 if group is None else group.id
341336
gsrc = src if group is None else group.get_group_rank(src)
337+
assert gsrc >= 0, ("src rank out of group, need global rank")
342338

343339
if in_dygraph_mode():
344340
return core.ops.c_broadcast(tensor, tensor, 'root', gsrc,
@@ -370,10 +366,10 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
370366
Args:
371367
tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
372368
should be float16, float32, float64, int32 or int64.
373-
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used.
369+
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
374370
group (Group): The group instance return by new_group or None for global default group.
375-
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
376-
default to True.
371+
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
372+
Default to True.
377373
378374
Returns:
379375
None.
@@ -453,10 +449,10 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
453449
tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
454450
should be float16, float32, float64, int32 or int64.
455451
dst (int): The destination rank id.
456-
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used.
452+
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
457453
group (Group): The group instance return by new_group or None for global default group.
458-
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
459-
default to True.
454+
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
455+
Default to True.
460456
461457
Returns:
462458
None.
@@ -487,6 +483,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
487483

488484
ring_id = 0 if group is None else group.id
489485
gdst = dst if group is None else group.get_group_rank(dst)
486+
assert gdst >= 0, ("dst rank out of group, need global rank")
490487

491488
if in_dygraph_mode():
492489
if op == ReduceOp.SUM:
@@ -548,8 +545,8 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
548545
tensor (Tensor): The Tensor to send. Its data type
549546
should be float16, float32, float64, int32 or int64.
550547
group (Group): The group instance return by new_group or None for global default group.
551-
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
552-
default to True.
548+
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
549+
Default to True.
553550
554551
Returns:
555552
None.
@@ -624,11 +621,11 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
624621
tensor (Tensor): The output Tensor. Its data type
625622
should be float16, float32, float64, int32 or int64.
626623
tensor_list (list): A list of Tensors to scatter. Every element in the list must be a Tensor whose data type
627-
should be float16, float32, float64, int32 or int64.
628-
src (int): The source rank id.
624+
should be float16, float32, float64, int32 or int64. Default value is None.
625+
src (int): The source rank id. Default value is 0.
629626
group (Group): The group instance return by new_group or None for global default group.
630-
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
631-
default to True.
627+
use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False).
628+
Default to True.
632629
633630
Returns:
634631
None.
@@ -664,6 +661,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
664661

665662
ring_id = 0 if group is None else group.id
666663
gsrc = src if group is None else group.get_group_rank(src)
664+
assert gsrc >= 0, ("src rank out of group, need global rank")
667665
rank = _get_global_group().rank if group is None else group.rank
668666
nranks = _get_global_group().nranks if group is None else group.nranks
669667

0 commit comments

Comments
 (0)