@@ -207,6 +207,7 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
207207 rank_id ):
208208 src_rank_in_group = src if group is None else group .get_group_rank (src )
209209 if _in_legacy_dygraph ():
210+ assert use_calc_stream
210211 return _legacy_C_ops .partial_recv (tensor .detach (), 'use_calc_stream' ,
211212 use_calc_stream , 'ring_id' , ring_id ,
212213 'peer' , src_rank_in_group , 'num' ,
@@ -216,8 +217,11 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
216217 elif in_dygraph_mode ():
217218 group = paddle .distributed .collective ._get_default_group (
218219 ) if group is None else group
219- return group .process_group .recv_partial (tensor , src_rank_in_group ,
220+ task = group .process_group .recv_partial (tensor , src_rank_in_group ,
220221 nranks , rank_id )
222+ if use_calc_stream :
223+ task .wait ()
224+ return task
221225
222226
223227def recv_partial (tensor ,
@@ -238,7 +242,7 @@ def recv_partial(tensor,
238242 return _partial_recv_op (tensor , group , use_calc_stream , ring_id ,
239243 src_rank , nranks , rank_id )
240244 else :
241- if _in_legacy_dygraph ():
245+ if _in_legacy_dygraph () or use_calc_stream :
242246 recv_op = paddle .distributed .recv
243247 elif in_dygraph_mode ():
244248 recv_op = paddle .distributed .irecv
@@ -275,7 +279,11 @@ def allgather_partial(tensor,
275279 nranks , rank_id )
276280
277281
278- def _p2p_helper (tensor_send_next , tensor_send_prev , recv_prev , recv_next ):
282+ def _p2p_helper (tensor_send_next ,
283+ tensor_send_prev ,
284+ recv_prev ,
285+ recv_next ,
286+ sync_recv = True ):
279287 global _hcg
280288
281289 tensor_recv_prev = None
@@ -354,15 +362,15 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
354362 nranks = mp_degree ,
355363 rank_id = mp_rank ,
356364 group = _hcg .recv_prev_group ,
357- use_calc_stream = True ))
365+ use_calc_stream = sync_recv ))
358366 else :
359367 tasks .append (
360368 recv_partial (tensor_recv_prev ,
361369 src = 0 ,
362370 nranks = mp_degree ,
363371 rank_id = mp_rank ,
364372 group = _hcg .recv_prev_group ,
365- use_calc_stream = True ))
373+ use_calc_stream = sync_recv ))
366374
367375 if tensor_send_next is not None :
368376 if isinstance (tensor_send_next , tuple ):
@@ -394,7 +402,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
394402 nranks = mp_degree ,
395403 rank_id = mp_rank ,
396404 group = _hcg .recv_next_group ,
397- use_calc_stream = True ))
405+ use_calc_stream = sync_recv ))
398406
399407 else :
400408 tasks .append (
@@ -403,10 +411,10 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
403411 nranks = mp_degree ,
404412 rank_id = mp_rank ,
405413 group = _hcg .recv_next_group ,
406- use_calc_stream = True ))
414+ use_calc_stream = sync_recv ))
407415
408- if in_dygraph_mode ():
409- # wait isend/ irecv tasks in eager dygraph mode with new comm library
416+ if not sync_recv and in_dygraph_mode ():
417+ # wait irecv tasks in eager dygraph mode with new comm library
410418 for task in tasks :
411419 assert task is not None
412420 task .wait ()
@@ -443,7 +451,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
443451 return tensor_recv_prev , tensor_recv_next
444452
445453
446- def recv_forward (pp_first_stage ):
454+ def recv_forward (pp_first_stage , sync_recv = True ):
447455 if pp_first_stage :
448456 input_tensor = None
449457 else :
@@ -454,18 +462,20 @@ def recv_forward(pp_first_stage):
454462 input_tensor , _ = _p2p_helper (tensor_send_next = None ,
455463 tensor_send_prev = None ,
456464 recv_prev = True ,
457- recv_next = False )
465+ recv_next = False ,
466+ sync_recv = sync_recv )
458467 return input_tensor
459468
460469
461- def recv_backward (pp_last_stage ):
470+ def recv_backward (pp_last_stage , sync_recv = True ):
462471 if pp_last_stage :
463472 output_tensor_grad = None
464473 else :
465474 _ , output_tensor_grad = _p2p_helper (tensor_send_next = None ,
466475 tensor_send_prev = None ,
467476 recv_prev = False ,
468- recv_next = True )
477+ recv_next = True ,
478+ sync_recv = sync_recv )
469479 return output_tensor_grad
470480
471481
@@ -527,7 +537,8 @@ def send_forward_backward_recv_forward_backward(output_tensor,
527537 tensor_send_next = output_tensor ,
528538 tensor_send_prev = input_tensor_grad ,
529539 recv_prev = recv_prev ,
530- recv_next = recv_next )
540+ recv_next = recv_next ,
541+ sync_recv = False )
531542 return input_tensor , output_tensor_grad
532543
533544
@@ -544,7 +555,8 @@ def send_forward_recv_forward(output_tensor, recv_prev):
544555 input_tensor , _ = _p2p_helper (tensor_send_next = output_tensor ,
545556 tensor_send_prev = None ,
546557 recv_prev = recv_prev ,
547- recv_next = False )
558+ recv_next = False ,
559+ sync_recv = False )
548560
549561 return input_tensor
550562
@@ -553,5 +565,6 @@ def send_backward_recv_backward(input_tensor_grad, recv_next):
553565 _ , output_tensor_grad = _p2p_helper (tensor_send_next = None ,
554566 tensor_send_prev = input_tensor_grad ,
555567 recv_prev = False ,
556- recv_next = recv_next )
568+ recv_next = recv_next ,
569+ sync_recv = False )
557570 return output_tensor_grad
0 commit comments