@@ -432,6 +432,7 @@ def __init__(self, layers, hcg, strategy):
         self.loss_fn_idx = 0
 
         self._compute_loss = True
+        self._return_host_tensor = False
         self.callbacks = pipeline_parallel_callbacks_
 
         logger.info(
@@ -1026,13 +1027,16 @@ def train_batch(
 
         return train_loss
 
-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
 
         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor
 
         # store data id for micro_batch
         self.micro_batch_id = 0
@@ -1072,6 +1076,7 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
+            self._offload_tensors(output_tensor)
 
             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
@@ -1094,6 +1099,7 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
+            self._offload_tensors(output_tensor)
 
             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
@@ -1105,11 +1111,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
             )
 
         if self._compute_loss:
-            self.train_loss = self._broadcast_final_loss()
+            train_loss = self._broadcast_final_loss()
         else:
-            self.train_loss = output_buffers
+            train_loss = output_buffers
 
-        return self.train_loss
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss
 
     def _maybe_loss_compute(
         self, output_tensor, micro_dataset, overlap_schedule_mode=False
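
Note: the hunks above extend eval_batch so a caller can skip loss computation and pull the raw pipeline outputs back to host memory, with the eval-time flags restored before returning. A minimal caller-side sketch; pp_model and eval_data are illustrative placeholders (an initialized pipeline-parallel model and a batch in the format train_batch accepts), not names from this diff:

    # Hypothetical driver code built on the extended signature.
    def predict_on_host(pp_model, eval_data):
        # compute_loss=False returns the raw outputs instead of a broadcast
        # loss; return_host_tensor=True moves them to pinned/CPU memory so
        # device buffers can be freed between micro-batches.
        return pp_model.eval_batch(
            eval_data,
            compute_loss=False,
            return_host_tensor=True,
        )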
@@ -1424,6 +1432,17 @@ def _optimizer_step(self):
         if self.lr_scheduler:
             self.lr_scheduler.step()
 
+    def _offload_tensors(self, output_tensor):
+        if not self._return_host_tensor:
+            return
+        if isinstance(output_tensor, (tuple, list)):
+            for t in output_tensor:
+                host_tensor = t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                host_tensor._share_buffer_to(t)
+        else:
+            host_tensor = output_tensor.pin_memory() if hasattr(output_tensor, "pin_memory") else output_tensor.cpu()
+            host_tensor._share_buffer_to(output_tensor)
+
     def _release_output(self, output):
         def can_free(t):
             return (
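
Note: _offload_tensors is a no-op unless _return_host_tensor is set; it copies each output to pinned host memory when pin_memory() is available (falling back to a plain CPU copy) and then uses the private _share_buffer_to call so the tensor object already held in the output buffers ends up referencing the host copy. A standalone sketch of just the host-copy fallback; to_host is an illustrative helper, not part of the diff:

    import paddle

    def to_host(t):
        # Prefer pinned host memory when the tensor exposes pin_memory();
        # otherwise fall back to an ordinary CPU copy.
        return t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()

    x = paddle.randn([2, 3])
    print(to_host(x).place)  # pinned or CPU place, depending on the build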
@@ -1694,10 +1713,12 @@ def _get_forward_input(self, virtual_pp_rank):
         assert hasattr(self, 'output_tensors')
         if not self._forward_only:
             assert hasattr(self, 'output_tensor_grads')
-        assert len(self.input_tensors[virtual_pp_rank]) == (
-            len(self.output_tensors[virtual_pp_rank]) + 1
-        )
-        input_tensor = self.input_tensors[virtual_pp_rank][-1]
+            assert len(self.input_tensors[virtual_pp_rank]) == (
+                len(self.output_tensors[virtual_pp_rank]) + 1
+            )
+            input_tensor = self.input_tensors[virtual_pp_rank][-1]
+        else:
+            input_tensor = self.input_tensors[virtual_pp_rank].pop()
         return input_tensor
 
     def _store_forward_outputs(
@@ -1712,11 +1733,17 @@ def _store_forward_outputs(
         self.schedule_chunks[virtual_pp_rank].append(schedule_chunk)
         if self.is_pipeline_last_stage():
             self.loss_fn_chunks.append(loss_fn_node)
-
-        if self._forward_only:
+            if self._forward_only:
+                # no need to store tensor for backward
+                if self._compute_loss:
+                    self.output_tensors[virtual_pp_rank].pop()
+                # save output_tensors for return value of eval batch
+                else:
+                    self._offload_tensors(output_tensor)
+        else:
             # no need to store tensor for backward
-            self.input_tensors[virtual_pp_rank].pop()
-            self.output_tensors[virtual_pp_rank].pop()
+            if self._forward_only:
+                self.output_tensors[virtual_pp_rank].pop()
 
     def _forward_step_helper(
         self,
@@ -2022,7 +2049,7 @@ def forward_backward_pipeline(
         # this strategy is inspired by:
         # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
         if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                 "compute_loss can only be set to False when forward_only is set to True"
             )
 
@@ -2669,7 +2696,7 @@ def backward_async_comm(
 
         # no steady steps, which only occurs when accumulate_step == num_stage
         if not steady_steps:
-            output_tensor_grad = p2p.recv_backward(
+            output_tensor_grad = self._p2p_helper.recv_backward(
                 self.is_pipeline_last_stage(),
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
@@ -2800,12 +2827,12 @@ def backward_async_comm(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -2823,7 +2850,7 @@ def backward_async_comm(
             ), "p2p dynamic_cnt should equal to send_recv_meta_list"
             self._p2p_helper._dynamic_cnt = 0
 
-        return train_loss
+        return train_loss_or_logits
 
     def train_batch(
         self,
@@ -2854,13 +2881,16 @@ def train_batch(
 
         return train_loss
 
-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)
 
         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor
 
         # check loss_fn_idx is valid and loss_fn exists
         assert (
@@ -2869,7 +2899,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         ), f"loss function {loss_fn_idx} should exist to compute loss"
         self.loss_fn_idx = loss_fn_idx
 
-        return self.forward_backward_pipeline(data, None, forward_only=True)
+        train_loss_or_logits = self.forward_backward_pipeline(data, None, forward_only=True, compute_loss=compute_loss)
+        self._init_buffers()
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss_or_logits
 
     def get_static_scheduler(self):
         return self.forward_backward_pipeline(
@@ -2959,9 +2993,9 @@ def forward_backward_pipeline(
         if self.processed_steps < g_profile_pipeline_details_steps:
             get_sync_logger().info("start forward_backward_pipeline")
         if not compute_loss:
-            assert not forward_only, (
-                "compute_loss can only be set to False when forward_only is set to True"
-            )
+            assert (
+                forward_only
+            ), "compute_loss can only be set to False when forward_only is set to True"
 
         # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
         assert self._using_cache, (
@@ -2977,10 +3011,15 @@ def forward_backward_pipeline(
29773011
29783012 assert (
29793013 self .accumulate_steps == self .num_stages
3014+ << < << << HEAD
29803015 or self .accumulate_steps % self .num_stages != 0
29813016 ), (
29823017 f"accumulate_steps({ self .accumulate_steps } ) and num_stages({ self .num_stages } ) should be a multiple or accumulate_steps % num_stages == 0"
29833018 )
3019+ == == == =
3020+ or self .accumulate_steps % self .num_stages == 0
3021+ ), f"accumulate_steps({ self .accumulate_steps } ) and num_stages({ self .num_stages } ) should be a multiple or accumulate_steps % num_stages == 0"
3022+ >> >> >> > 4 c472714c0 ([Distributed ] fix eval batch & non - compute_loss in pipeline (#73479))
29843023
29853024 self ._backward_step_count = 0
29863025 skip_steps = self .accumulate_steps - self .num_stages
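
Note: the assertion above accepts an accumulate_steps that either equals the stage count or divides evenly across the stages. A small illustrative check; the helper name is not from the source:

    def check_accumulate_steps(accumulate_steps, num_stages):
        # Equal, or an exact multiple of the number of pipeline stages.
        return accumulate_steps == num_stages or accumulate_steps % num_stages == 0

    print(check_accumulate_steps(8, 4))  # True: two micro-batches per stage
    print(check_accumulate_steps(4, 4))  # True: one micro-batch per stage
    print(check_accumulate_steps(6, 4))  # False: uneven split across stages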
@@ -3108,12 +3147,12 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3124,7 +3163,7 @@ def forward_backward_pipeline(
             get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits
 
 
 class OffloadQueue(queue.Queue):
@@ -3187,12 +3226,12 @@ def forward_backward_pipeline(
     ):
         self._reset_user_hooks_status()
         if not compute_loss:
-            assert not forward_only, (
-                "compute_loss can only be set to False when forward_only is set to True"
-            )
-        assert self._using_cache, (
-            "cache should be enabled for pipeline with interleave"
-        )
+            assert (
+                forward_only
+            ), "compute_loss can only be set to False when forward_only is set to True"
+        assert (
+            self._using_cache
+        ), "cache should be enabled for pipeline with interleave"
         self.user_hooks_enabled = not forward_only
         if forward_only:
             return super().forward_backward_pipeline(
@@ -3462,12 +3501,12 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3478,6 +3517,6 @@ def forward_backward_pipeline(
             get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits
 
 