Skip to content

Commit 20d7dbc

Browse files
authored
fix dual pp bug (#74158)
* fix dual pp bug * fix dual pp bug * fix bug
1 parent 03f6d59 commit 20d7dbc

File tree

1 file changed

+89
-15
lines changed

1 file changed

+89
-15
lines changed

python/paddle/distributed/fleet/meta_parallel/dualpipev.py

Lines changed: 89 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def _forward_backward_compute(
260260
backward_phase: int,
261261
micro_datasets=None,
262262
combine_backward_event_to_wait=None,
263+
pass_pp_stream=False,
263264
) -> None:
264265
if self.forward_only:
265266
self._forward_compute(forward_phase, micro_datasets)
@@ -319,8 +320,12 @@ def _forward_backward_compute(
319320
backward_grads,
320321
self.scaler,
321322
combine_bw_event_to_wait=combine_backward_event_to_wait,
322-
pp_stream=self.pp_group.process_group.get_stream(
323-
paddle.framework._current_expected_place_()
323+
pp_stream=(
324+
self.pp_group.process_group.get_stream(
325+
paddle.framework._current_expected_place_()
326+
)
327+
if pass_pp_stream
328+
else None
324329
),
325330
)
326331
)
@@ -339,7 +344,9 @@ def _forward_backward_compute(
339344
backward_phase, backward_acc_id, input_grads=backward_input_grads
340345
)
341346

342-
def _commit_and_wait_comm(self) -> None:
347+
def _commit_and_wait_comm(
348+
self, p2p_overlap=False, use_outer_event_wait=False
349+
) -> None:
343350
common_forward_ops_num = (
344351
len(self.comm_forward_ops)
345352
if self.comm_forward_ops is not None
@@ -355,18 +362,26 @@ def _commit_and_wait_comm(self) -> None:
355362
paddle.device.current_stream().stream_base
356363
)
357364

358-
use_stream_wait_event = self._overlap_p2p_comm and deep_ep is not None
365+
use_stream_wait_event = (
366+
p2p_overlap and self._overlap_p2p_comm and deep_ep is not None
367+
)
359368

360369
pp_raw_stream = self.pp_group.process_group.get_stream(
361370
paddle.framework._current_expected_place_()
362371
)
372+
if use_outer_event_wait:
373+
self.pp_group.process_group.set_outer_wait(True)
363374

364375
if common_forward_ops_num > 0:
365376
fwd_reqs = batch_isend_irecv(self.comm_forward_ops)
366377

367378
if not use_stream_wait_event:
368379
for req in fwd_reqs:
369380
req.wait()
381+
382+
if use_outer_event_wait:
383+
self.pp_group.process_group.set_outer_wait(False)
384+
370385
if use_stream_wait_event:
371386
forward_event_to_wait = deep_ep.get_event_from_custom_stream(
372387
pp_raw_stream
@@ -524,29 +539,49 @@ def _forward_backward_pass(
524539
backward_phase: int,
525540
micro_datasets=None,
526541
recv0: bool = True,
542+
first_chunk=False,
543+
last_chunk=False,
544+
main_stage=False,
545+
last_stage_and_first_chunk=False,
527546
) -> None:
528547
if recv0:
529548
self._recv_forward(forward_phase)
530549
self._recv_backward(backward_phase)
531550

532-
use_outer_wait = (
533-
self._overlap_p2p_comm
551+
need_send_forward = not (
552+
self.is_pipeline_first_stage() and forward_phase == 1
553+
) or (self.is_pipeline_last_stage() and forward_phase == 0)
554+
need_send_backward = not (
555+
self.is_pipeline_first_stage() and backward_phase == 0
556+
) or (self.is_pipeline_last_stage() and backward_phase == 1)
557+
558+
use_outer_event_wait = (
559+
main_stage
560+
and not first_chunk
561+
and self._overlap_p2p_comm
534562
and deep_ep is not None
535-
and (len(self.comm_forward_ops) > 0)
563+
and (need_send_forward and need_send_backward)
536564
)
537565

538-
if use_outer_wait:
539-
self.pp_group.process_group.set_outer_wait(True)
566+
pass_pp_stream = (
567+
main_stage
568+
and not last_chunk
569+
and self._overlap_p2p_comm
570+
and deep_ep is not None
571+
and (need_send_forward and need_send_backward)
572+
and (not last_stage_and_first_chunk)
573+
)
540574

541-
combine_bw_wait_event = self._commit_and_wait_comm()
575+
combine_bw_wait_event = self._commit_and_wait_comm(
576+
not last_chunk, use_outer_event_wait
577+
)
542578

543-
if use_outer_wait:
544-
self.pp_group.process_group.set_outer_wait(False)
545579
self._forward_backward_compute(
546580
forward_phase,
547581
backward_phase,
548582
micro_datasets,
549583
combine_backward_event_to_wait=combine_bw_wait_event,
584+
pass_pp_stream=pass_pp_stream,
550585
)
551586

552587
self._send_forward(forward_phase)
@@ -663,7 +698,11 @@ def forward_backward_pipeline(
663698

664699
# Step 4 (Main step): nF0B1F1B0
665700
step_4 = self.accumulate_steps - num_ranks * 2 + rank + 1
701+
have_step5 = num_ranks - rank - 1 > 0
702+
# Pass first/last-chunk flags into each pass to support send/recv overlap.
703+
# Send/recv overlap is only supported in the main step (Step 4).
666704
for i in range(step_4):
705+
is_last_chunk = i + 1 == step_4
667706
if i == 0:
668707
if self.is_pipeline_last_stage():
669708
# NOTE: We don't overlap these two passes to further reduce bubble size.
@@ -674,13 +713,48 @@ def forward_backward_pipeline(
674713
self._backward_pass(1, send=False)
675714
self._send_forward(0)
676715
self._send_backward(1)
716+
717+
self._forward_backward_pass(
718+
1,
719+
0,
720+
micro_datasets,
721+
first_chunk=True,
722+
last_chunk=is_last_chunk,
723+
main_stage=True,
724+
)
677725
else:
678726
self._forward_backward_pass(
679-
0, 1, micro_datasets, recv0=False
727+
0,
728+
1,
729+
micro_datasets,
730+
recv0=False,
731+
first_chunk=True,
732+
main_stage=True,
733+
)
734+
735+
self._forward_backward_pass(
736+
1,
737+
0,
738+
micro_datasets,
739+
last_chunk=is_last_chunk,
740+
main_stage=True,
680741
)
681742
else:
682-
self._forward_backward_pass(0, 1, micro_datasets)
683-
self._forward_backward_pass(1, 0, micro_datasets)
743+
744+
self._forward_backward_pass(
745+
0,
746+
1,
747+
micro_datasets,
748+
main_stage=True,
749+
last_stage_and_first_chunk=self.is_pipeline_last_stage(),
750+
)
751+
self._forward_backward_pass(
752+
1,
753+
0,
754+
micro_datasets,
755+
last_chunk=is_last_chunk,
756+
main_stage=True,
757+
)
684758

685759
# Step 5: nB1F1B0
686760
step_5 = num_ranks - rank - 1

0 commit comments

Comments
 (0)