
Commit b8e4bc3

SylarTiaNII authored and Your Name committed
[Distributed] fix eval batch & non-compute_loss in pipeline (#73479)
1 parent 42153e0 commit b8e4bc3

1 file changed: +79 -36 lines changed

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 79 additions & 36 deletions
@@ -432,6 +432,7 @@ def __init__(self, layers, hcg, strategy):
         self.loss_fn_idx = 0
 
         self._compute_loss = True
+        self._return_host_tensor = False
         self.callbacks = pipeline_parallel_callbacks_
 
         logger.info(
@@ -1026,13 +1027,16 @@ def train_batch(
 
         return train_loss
 
-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
 
         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor
 
         # store data id for micro_batch
         self.micro_batch_id = 0
@@ -1072,6 +1076,7 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
+            self._offload_tensors(output_tensor)
 
             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
@@ -1094,6 +1099,7 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
+            self._offload_tensors(output_tensor)
 
             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
@@ -1105,11 +1111,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
             )
 
         if self._compute_loss:
-            self.train_loss = self._broadcast_final_loss()
+            train_loss = self._broadcast_final_loss()
         else:
-            self.train_loss = output_buffers
+            train_loss = output_buffers
 
-        return self.train_loss
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss
 
     def _maybe_loss_compute(
         self, output_tensor, micro_dataset, overlap_schedule_mode=False
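
For orientation, here is a hedged usage sketch of the extended `eval_batch` signature; `pp_model` and `eval_loader` are hypothetical stand-ins (an initialized pipeline-parallel model wrapper and an iterable of batches) and are not part of this diff:

```python
# Hypothetical evaluation loop; only the eval_batch keyword arguments below
# come from this commit, everything else is illustrative.
def run_eval(pp_model, eval_loader):
    outputs = []
    for batch in eval_loader:
        # compute_loss=False: the last stage returns its raw outputs (logits)
        # return_host_tensor=True: outputs are offloaded to host memory so
        # they do not pile up on the GPU across micro-batches
        logits = pp_model.eval_batch(
            batch, compute_loss=False, return_host_tensor=True
        )
        outputs.append(logits)
    return outputs
```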
@@ -1424,6 +1432,17 @@ def _optimizer_step(self):
         if self.lr_scheduler:
             self.lr_scheduler.step()
 
+    def _offload_tensors(self, output_tensor):
+        if not self._return_host_tensor:
+            return
+        if isinstance(output_tensor, (tuple, list)):
+            for t in output_tensor:
+                host_tensor = t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                host_tensor._share_buffer_to(t)
+        else:
+            host_tensor = output_tensor.pin_memory() if hasattr(output_tensor, "pin_memory") else output_tensor.cpu()
+            host_tensor._share_buffer_to(output_tensor)
+
     def _release_output(self, output):
         def can_free(t):
             return (
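
As a standalone illustration of the offload pattern `_offload_tensors` applies per output tensor (pinned host memory when available, a plain CPU copy otherwise), here is a minimal sketch; `offload_to_host` is a hypothetical helper, not part of the patch:

```python
import paddle

def offload_to_host(t):
    # Same fallback as the patch: prefer pinned host memory, else a CPU copy.
    return t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()

x = paddle.randn([2, 3])       # placed on GPU when Paddle is built with CUDA
host_x = offload_to_host(x)    # host-side copy of the same values
print(host_x.place)
# The patch additionally calls host_tensor._share_buffer_to(t) so the original
# tensor object is rebound to the host buffer in place, releasing device memory.
```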
@@ -1694,10 +1713,12 @@ def _get_forward_input(self, virtual_pp_rank):
         assert hasattr(self, 'output_tensors')
         if not self._forward_only:
             assert hasattr(self, 'output_tensor_grads')
-        assert len(self.input_tensors[virtual_pp_rank]) == (
-            len(self.output_tensors[virtual_pp_rank]) + 1
-        )
-        input_tensor = self.input_tensors[virtual_pp_rank][-1]
+            assert len(self.input_tensors[virtual_pp_rank]) == (
+                len(self.output_tensors[virtual_pp_rank]) + 1
+            )
+            input_tensor = self.input_tensors[virtual_pp_rank][-1]
+        else:
+            input_tensor = self.input_tensors[virtual_pp_rank].pop()
         return input_tensor
 
     def _store_forward_outputs(
@@ -1712,11 +1733,17 @@ def _store_forward_outputs(
         self.schedule_chunks[virtual_pp_rank].append(schedule_chunk)
         if self.is_pipeline_last_stage():
             self.loss_fn_chunks.append(loss_fn_node)
-
-        if self._forward_only:
+            if self._forward_only:
+                # no need to store tensor for backward
+                if self._compute_loss:
+                    self.output_tensors[virtual_pp_rank].pop()
+                # save output_tensors for return value of eval batch
+                else:
+                    self._offload_tensors(output_tensor)
+        else:
             # no need to store tensor for backward
-            self.input_tensors[virtual_pp_rank].pop()
-            self.output_tensors[virtual_pp_rank].pop()
+            if self._forward_only:
+                self.output_tensors[virtual_pp_rank].pop()
 
     def _forward_step_helper(
         self,
@@ -2022,7 +2049,7 @@ def forward_backward_pipeline(
         # this strategy is inspired by:
         # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
         if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                 "compute_loss can only be set to False when forward_only is set to True"
            )
 
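The flipped condition is easier to read in isolation; a minimal sketch of the intended guard (the helper name is illustrative only):

```python
def check_eval_flags(compute_loss: bool, forward_only: bool) -> None:
    # compute_loss=False only makes sense for forward-only (evaluation) runs,
    # so the guard must require forward_only=True rather than forbid it.
    if not compute_loss:
        assert forward_only, (
            "compute_loss can only be set to False when forward_only is set to True"
        )

check_eval_flags(compute_loss=False, forward_only=True)   # passes
# check_eval_flags(compute_loss=False, forward_only=False)  # AssertionError
```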
@@ -2669,7 +2696,7 @@ def backward_async_comm(
 
             # no steady steps, which only occurs when accumulate_step == num_stage
             if not steady_steps:
-                output_tensor_grad = p2p.recv_backward(
+                output_tensor_grad = self._p2p_helper.recv_backward(
                     self.is_pipeline_last_stage(),
                     batch_p2p_comm=self._use_batch_p2p_comm,
                 )
@@ -2800,12 +2827,12 @@ def backward_async_comm(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -2823,7 +2850,7 @@ def backward_async_comm(
         ), "p2p dynamic_cnt should equal to send_recv_meta_list"
         self._p2p_helper._dynamic_cnt = 0
 
-        return train_loss
+        return train_loss_or_logits
 
     def train_batch(
         self,
@@ -2854,13 +2881,16 @@ def train_batch(
 
         return train_loss
 
-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
 
         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor
 
         # check loss_fn_idx is valid and loss_fn exists
         assert (
@@ -2869,7 +2899,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         ), f"loss function {loss_fn_idx} should exist to compute loss"
         self.loss_fn_idx = loss_fn_idx
 
-        return self.forward_backward_pipeline(data, None, forward_only=True)
+        train_loss_or_logits = self.forward_backward_pipeline(data, None, forward_only=True, compute_loss=compute_loss)
+        self._init_buffers()
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss_or_logits
 
     def get_static_scheduler(self):
         return self.forward_backward_pipeline(
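
A hedged sketch of the two return modes of the updated interleaved `eval_batch`; `pp_model` and `batch` are hypothetical, and the comments summarize behaviour introduced by this commit:

```python
# Loss mode: the loss is computed with the selected loss_fn and broadcast
# from the last pipeline stage to the other pipeline ranks.
loss = pp_model.eval_batch(batch, compute_loss=True, loss_fn_idx=0)

# Logits mode: the last stage's outputs are returned without a loss call;
# with return_host_tensor=True they are offloaded to host memory. eval_batch
# now also restores _compute_loss / _return_host_tensor and re-initializes
# its buffers afterwards, so a following train_batch is unaffected.
logits = pp_model.eval_batch(
    batch, compute_loss=False, return_host_tensor=True
)
```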
@@ -2959,9 +2993,9 @@ def forward_backward_pipeline(
         if self.processed_steps < g_profile_pipeline_details_steps:
             get_sync_logger().info("start forward_backward_pipeline")
         if not compute_loss:
-            assert not forward_only, (
-                "compute_loss can only be set to False when forward_only is set to True"
-            )
+            assert (
+                forward_only
+            ), "compute_loss can only be set to False when forward_only is set to True"
 
         # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
         assert self._using_cache, (
@@ -2977,10 +3011,15 @@ def forward_backward_pipeline(
 
         assert (
             self.accumulate_steps == self.num_stages
+<<<<<<< HEAD
             or self.accumulate_steps % self.num_stages != 0
         ), (
             f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
         )
+=======
+            or self.accumulate_steps % self.num_stages == 0
+        ), f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
+>>>>>>> 4c472714c0 ([Distributed] fix eval batch & non-compute_loss in pipeline (#73479))
 
         self._backward_step_count = 0
         skip_steps = self.accumulate_steps - self.num_stages
@@ -3108,12 +3147,12 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3124,7 +3163,7 @@ def forward_backward_pipeline(
             get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits
 
 
 class OffloadQueue(queue.Queue):
@@ -3187,12 +3226,12 @@ def forward_backward_pipeline(
     ):
         self._reset_user_hooks_status()
         if not compute_loss:
-            assert not forward_only, (
-                "compute_loss can only be set to False when forward_only is set to True"
-            )
-        assert self._using_cache, (
-            "cache should be enabled for pipeline with interleave"
-        )
+            assert (
+                forward_only
+            ), "compute_loss can only be set to False when forward_only is set to True"
+        assert (
+            self._using_cache
+        ), "cache should be enabled for pipeline with interleave"
         self.user_hooks_enabled = not forward_only
         if forward_only:
             return super().forward_backward_pipeline(
@@ -3462,12 +3501,12 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3478,6 +3517,7 @@ def forward_backward_pipeline(
             get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
+<<<<<<< HEAD
         return train_loss
 
 
@@ -3531,3 +3571,6 @@ def convert_tensor_tuple_to_dict(input_tensor_tuple):
         input_tensor_dict[key] = tensor
         delattr(tensor, "key")
     return input_tensor_dict
+=======
+        return train_loss_or_logits
+>>>>>>> 4c472714c0 ([Distributed] fix eval batch & non-compute_loss in pipeline (#73479))
