@@ -260,6 +260,7 @@ def _forward_backward_compute(
260260 backward_phase : int ,
261261 micro_datasets = None ,
262262 combine_backward_event_to_wait = None ,
263+ pass_pp_stream = False ,
263264 ) -> None :
264265 if self .forward_only :
265266 self ._forward_compute (forward_phase , micro_datasets )
@@ -319,8 +320,12 @@ def _forward_backward_compute(
319320 backward_grads ,
320321 self .scaler ,
321322 combine_bw_event_to_wait = combine_backward_event_to_wait ,
322- pp_stream = self .pp_group .process_group .get_stream (
323- paddle .framework ._current_expected_place_ ()
323+ pp_stream = (
324+ self .pp_group .process_group .get_stream (
325+ paddle .framework ._current_expected_place_ ()
326+ )
327+ if pass_pp_stream
328+ else None
324329 ),
325330 )
326331 )
@@ -339,7 +344,9 @@ def _forward_backward_compute(
339344 backward_phase , backward_acc_id , input_grads = backward_input_grads
340345 )
341346
342- def _commit_and_wait_comm (self ) -> None :
347+ def _commit_and_wait_comm (
348+ self , p2p_overlap = False , use_outer_event_wait = False
349+ ) -> None :
343350 common_forward_ops_num = (
344351 len (self .comm_forward_ops )
345352 if self .comm_forward_ops is not None
@@ -355,18 +362,26 @@ def _commit_and_wait_comm(self) -> None:
355362 paddle .device .current_stream ().stream_base
356363 )
357364
358- use_stream_wait_event = self ._overlap_p2p_comm and deep_ep is not None
365+ use_stream_wait_event = (
366+ p2p_overlap and self ._overlap_p2p_comm and deep_ep is not None
367+ )
359368
360369 pp_raw_stream = self .pp_group .process_group .get_stream (
361370 paddle .framework ._current_expected_place_ ()
362371 )
372+ if use_outer_event_wait :
373+ self .pp_group .process_group .set_outer_wait (True )
363374
364375 if common_forward_ops_num > 0 :
365376 fwd_reqs = batch_isend_irecv (self .comm_forward_ops )
366377
367378 if not use_stream_wait_event :
368379 for req in fwd_reqs :
369380 req .wait ()
381+
382+ if use_outer_event_wait :
383+ self .pp_group .process_group .set_outer_wait (False )
384+
370385 if use_stream_wait_event :
371386 forward_event_to_wait = deep_ep .get_event_from_custom_stream (
372387 pp_raw_stream
@@ -524,29 +539,49 @@ def _forward_backward_pass(
524539 backward_phase : int ,
525540 micro_datasets = None ,
526541 recv0 : bool = True ,
542+ first_chunk = False ,
543+ last_chunk = False ,
544+ main_stage = False ,
545+ last_stage_and_first_chunk = False ,
527546 ) -> None :
528547 if recv0 :
529548 self ._recv_forward (forward_phase )
530549 self ._recv_backward (backward_phase )
531550
532- use_outer_wait = (
533- self ._overlap_p2p_comm
551+ need_send_forward = not (
552+ self .is_pipeline_first_stage () and forward_phase == 1
553+ ) or (self .is_pipeline_last_stage () and forward_phase == 0 )
554+ need_send_backward = not (
555+ self .is_pipeline_first_stage () and backward_phase == 0
556+ ) or (self .is_pipeline_last_stage () and backward_phase == 1 )
557+
558+ use_outer_event_wait = (
559+ main_stage
560+ and not first_chunk
561+ and self ._overlap_p2p_comm
534562 and deep_ep is not None
535- and (len ( self . comm_forward_ops ) > 0 )
563+ and (need_send_forward and need_send_backward )
536564 )
537565
538- if use_outer_wait :
539- self .pp_group .process_group .set_outer_wait (True )
566+ pass_pp_stream = (
567+ main_stage
568+ and not last_chunk
569+ and self ._overlap_p2p_comm
570+ and deep_ep is not None
571+ and (need_send_forward and need_send_backward )
572+ and (not last_stage_and_first_chunk )
573+ )
540574
541- combine_bw_wait_event = self ._commit_and_wait_comm ()
575+ combine_bw_wait_event = self ._commit_and_wait_comm (
576+ not last_chunk , use_outer_event_wait
577+ )
542578
543- if use_outer_wait :
544- self .pp_group .process_group .set_outer_wait (False )
545579 self ._forward_backward_compute (
546580 forward_phase ,
547581 backward_phase ,
548582 micro_datasets ,
549583 combine_backward_event_to_wait = combine_bw_wait_event ,
584+ pass_pp_stream = pass_pp_stream ,
550585 )
551586
552587 self ._send_forward (forward_phase )
@@ -663,7 +698,11 @@ def forward_backward_pipeline(
663698
664699 # Step 4 (Main step): nF0B1F1B0
665700 step_4 = self .accumulate_steps - num_ranks * 2 + rank + 1
701+ have_step5 = num_ranks - rank - 1 > 0
702+ # Support send/recv overlap in the passes below.
703+ # Send/recv overlap is only supported in the main step (Step 4).
666704 for i in range (step_4 ):
705+ is_last_chunk = i + 1 == step_4
667706 if i == 0 :
668707 if self .is_pipeline_last_stage ():
669708 # NOTE: We don't overlap these two passes to further reduce bubble size.
@@ -674,13 +713,48 @@ def forward_backward_pipeline(
674713 self ._backward_pass (1 , send = False )
675714 self ._send_forward (0 )
676715 self ._send_backward (1 )
716+
717+ self ._forward_backward_pass (
718+ 1 ,
719+ 0 ,
720+ micro_datasets ,
721+ first_chunk = True ,
722+ last_chunk = is_last_chunk ,
723+ main_stage = True ,
724+ )
677725 else :
678726 self ._forward_backward_pass (
679- 0 , 1 , micro_datasets , recv0 = False
727+ 0 ,
728+ 1 ,
729+ micro_datasets ,
730+ recv0 = False ,
731+ first_chunk = True ,
732+ main_stage = True ,
733+ )
734+
735+ self ._forward_backward_pass (
736+ 1 ,
737+ 0 ,
738+ micro_datasets ,
739+ last_chunk = is_last_chunk ,
740+ main_stage = True ,
680741 )
681742 else :
682- self ._forward_backward_pass (0 , 1 , micro_datasets )
683- self ._forward_backward_pass (1 , 0 , micro_datasets )
743+
744+ self ._forward_backward_pass (
745+ 0 ,
746+ 1 ,
747+ micro_datasets ,
748+ main_stage = True ,
749+ last_stage_and_first_chunk = self .is_pipeline_last_stage (),
750+ )
751+ self ._forward_backward_pass (
752+ 1 ,
753+ 0 ,
754+ micro_datasets ,
755+ last_chunk = is_last_chunk ,
756+ main_stage = True ,
757+ )
684758
685759 # Step 5: nB1F1B0
686760 step_5 = num_ranks - rank - 1
0 commit comments