
Commit d950f5a

tianhaodongbd authored and Your Name committed
[Distributed] fix eval batch && codestyle in PipelineParallel (#73978)
1 parent b8e4bc3 commit d950f5a

File tree

1 file changed: +45 / -33 lines changed

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 45 additions & 33 deletions
@@ -1027,7 +1027,9 @@ def train_batch(

         return train_loss

-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
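The reflowed signature does not change behavior; `eval_batch` still takes the same four arguments. A minimal caller-side sketch (the names `pp_model` and `eval_data` are hypothetical placeholders for a pipeline-parallel model and its input batch):

```python
# Hypothetical usage sketch; `pp_model` stands for a model already wrapped
# for pipeline parallelism (e.g. via paddle.distributed.fleet).
logits = pp_model.eval_batch(
    eval_data,
    compute_loss=False,       # return raw outputs instead of a loss
    loss_fn_idx=0,            # index of the registered loss function
    return_host_tensor=True,  # offload returned tensors to host memory
)
```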
@@ -1055,7 +1057,6 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor
         startup_steps = min(startup_steps, self.accumulate_steps)
         steady_steps = self.accumulate_steps - startup_steps

-        input_buffers = []
         output_buffers = []

         # convert to micro dataset
@@ -1076,9 +1077,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
-            self._offload_tensors(output_tensor)
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+            else:
+                self._offload_tensors(output_tensor)

-            input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)

         if steady_steps > 0:
@@ -1099,9 +1102,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
-            self._offload_tensors(output_tensor)
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+            else:
+                self._offload_tensors(output_tensor)

-            input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)

             if not last_iter:
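Both the warmup and steady loops now apply the same rule: only the last pipeline stage keeps its forward outputs (offloaded to host memory), while earlier stages release theirs once they have been sent downstream. The decision, restated outside the diff context:

```python
# Restatement of the per-stage output handling introduced above
# (method names as they appear in the diff).
if not self.is_pipeline_last_stage():
    # Intermediate stage: the output has already been sent to the next
    # stage, so its device memory can be released.
    self._release_output(output_tensor)
else:
    # Last stage: these are the logits the caller wants; offload them
    # to (pinned) host memory instead of releasing them.
    self._offload_tensors(output_tensor)
```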
@@ -1437,10 +1442,16 @@ def _offload_tensors(self, output_tensor):
             return
         if isinstance(output_tensor, (tuple, list)):
             for t in output_tensor:
-                host_tensor = t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                host_tensor = (
+                    t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                )
                 host_tensor._share_buffer_to(t)
         else:
-            host_tensor = output_tensor.pin_memory() if hasattr(output_tensor, "pin_memory") else output_tensor.cpu()
+            host_tensor = (
+                output_tensor.pin_memory()
+                if hasattr(output_tensor, "pin_memory")
+                else output_tensor.cpu()
+            )
             host_tensor._share_buffer_to(output_tensor)

     def _release_output(self, output):
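`_offload_tensors` keeps its fallback intact: prefer pinned host memory, which allows faster, asynchronous device-to-host copies, and fall back to an ordinary CPU copy when the tensor has no `pin_memory` method. A self-contained sketch of the pattern (the helper name `to_host` is hypothetical; `_share_buffer_to`, which rebinds the original tensor to the host copy's storage, is a Paddle-internal call and is omitted here):

```python
import paddle

def to_host(t):
    # Hypothetical helper mirroring the fallback used by _offload_tensors:
    # pinned host memory when available, plain CPU copy otherwise.
    return t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()

x = paddle.randn([4, 8])
host_x = to_host(x)  # a host-side copy of x
```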
@@ -2827,7 +2838,9 @@ def backward_async_comm(
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
-                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
@@ -2881,7 +2894,9 @@ def train_batch(

         return train_loss

-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
@@ -2899,7 +2914,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor
         ), f"loss function {loss_fn_idx} should exist to compute loss"
         self.loss_fn_idx = loss_fn_idx

-        train_loss_or_logits = self.forward_backward_pipeline(data, None, forward_only=True, compute_loss=compute_loss)
+        train_loss_or_logits = self.forward_backward_pipeline(
+            data, None, forward_only=True, compute_loss=compute_loss
+        )
         self._init_buffers()
         self._compute_loss = origin_compute_loss
         self._return_host_tensor = origin_return_host_tensor
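The restores of `self._compute_loss` and `self._return_host_tensor` visible here imply the usual save/override/restore pattern around the forward-only pipeline run, so an eval pass cannot leak its settings into later train steps. A sketch of that shape (the save lines are not shown in this hunk and are inferred from the restores; the ordering is an assumption):

```python
# Inferred shape of eval_batch around the call above; attribute and
# method names are taken from the diff.
origin_compute_loss = self._compute_loss
origin_return_host_tensor = self._return_host_tensor
self._compute_loss = compute_loss
self._return_host_tensor = return_host_tensor

train_loss_or_logits = self.forward_backward_pipeline(
    data, None, forward_only=True, compute_loss=compute_loss
)

self._init_buffers()
self._compute_loss = origin_compute_loss
self._return_host_tensor = origin_return_host_tensor
```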
@@ -2993,9 +3010,9 @@ def forward_backward_pipeline(
         if self.processed_steps < g_profile_pipeline_details_steps:
             get_sync_logger().info("start forward_backward_pipeline")
         if not compute_loss:
-            assert (
-                forward_only
-            ), "compute_loss can only be set to False when forward_only is set to True"
+            assert forward_only, (
+                "compute_loss can only be set to False when forward_only is set to True"
+            )

         # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
         assert self._using_cache, (
@@ -3011,15 +3028,10 @@

         assert (
             self.accumulate_steps == self.num_stages
-<<<<<<< HEAD
-            or self.accumulate_steps % self.num_stages != 0
+            or self.accumulate_steps % self.num_stages == 0
         ), (
             f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
         )
-=======
-            or self.accumulate_steps % self.num_stages == 0
-        ), f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
->>>>>>> 4c472714c0 ([Distributed] fix eval batch & non-compute_loss in pipeline (#73479))

         self._backward_step_count = 0
         skip_steps = self.accumulate_steps - self.num_stages
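Besides dropping the leftover conflict markers, the resolution keeps the correct branch of the condition: `accumulate_steps % num_stages == 0` (divisibility), not the conflicted `!= 0`. A standalone restatement of which configurations the corrected assertion accepts:

```python
def valid_schedule(accumulate_steps: int, num_stages: int) -> bool:
    # Restatement of the corrected assertion: the micro-batch count must
    # equal the stage count or be a multiple of it.
    return (
        accumulate_steps == num_stages
        or accumulate_steps % num_stages == 0
    )

assert valid_schedule(4, 4)      # equal: accepted
assert valid_schedule(8, 4)      # multiple: accepted
assert not valid_schedule(6, 4)  # 6 % 4 == 2: rejected
```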
@@ -3147,7 +3159,9 @@
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
-                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
@@ -3226,12 +3240,12 @@
     ):
         self._reset_user_hooks_status()
         if not compute_loss:
-            assert (
-                forward_only
-            ), "compute_loss can only be set to False when forward_only is set to True"
-            assert (
-                self._using_cache
-            ), "cache should be enabled for pipeline with interleave"
+            assert forward_only, (
+                "compute_loss can only be set to False when forward_only is set to True"
+            )
+            assert self._using_cache, (
+                "cache should be enabled for pipeline with interleave"
+            )
         self.user_hooks_enabled = not forward_only
         if forward_only:
             return super().forward_backward_pipeline(
@@ -3501,7 +3515,9 @@
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
-                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
@@ -3517,8 +3533,7 @@
             get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-<<<<<<< HEAD
-        return train_loss
+        return train_loss_or_logits


 def tuple_to_dict_helper(input_tensor):
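With the conflict resolved in favor of `train_loss_or_logits`, the interleaved scheduler now matches the base class: the single return value is a loss when `compute_loss=True` and the last stage's raw outputs when it is `False`. A hypothetical caller-side view (`pp_model` and `batch` are placeholders):

```python
# Hypothetical names; `pp_model` and `batch` are placeholders.
loss = pp_model.eval_batch(batch, compute_loss=True)     # aggregated loss
logits = pp_model.eval_batch(batch, compute_loss=False)  # raw outputs
```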
@@ -3571,6 +3586,3 @@ def convert_tensor_tuple_to_dict(input_tensor_tuple):
         input_tensor_dict[key] = tensor
         delattr(tensor, "key")
     return input_tensor_dict
-=======
-        return train_loss_or_logits
->>>>>>> 4c472714c0 ([Distributed] fix eval batch & non-compute_loss in pipeline (#73479))
