@@ -432,6 +432,7 @@ def __init__(self, layers, hcg, strategy):
         self.loss_fn_idx = 0
 
         self._compute_loss = True
+        self._return_host_tensor = False
         self.callbacks = pipeline_parallel_callbacks_
 
         logger.info(
@@ -1026,13 +1027,18 @@ def train_batch(
 
         return train_loss
 
-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
 
         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor
 
         # store data id for micro_batch
         self.micro_batch_id = 0
@@ -1051,7 +1057,6 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         startup_steps = min(startup_steps, self.accumulate_steps)
         steady_steps = self.accumulate_steps - startup_steps
 
-        input_buffers = []
         output_buffers = []
 
         # convert to micro dataset
@@ -1072,8 +1077,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+            else:
+                self._offload_tensors(output_tensor)
 
-            input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
 
         if steady_steps > 0:
@@ -1094,8 +1102,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+            else:
+                self._offload_tensors(output_tensor)
 
-            input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
 
             if not last_iter:
@@ -1105,11 +1116,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 )
 
         if self._compute_loss:
-            self.train_loss = self._broadcast_final_loss()
+            train_loss = self._broadcast_final_loss()
         else:
-            self.train_loss = output_buffers
+            train_loss = output_buffers
 
-        return self.train_loss
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss
 
     def _maybe_loss_compute(
         self, output_tensor, micro_dataset, overlap_schedule_mode=False
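For context, a minimal sketch of how the extended signature might be used from an evaluation loop; `pp_model` and `eval_dataloader` are hypothetical names for a fleet-wrapped pipeline-parallel model and its data loader:

# Hypothetical evaluation loop over a pipeline-parallel model.
for batch in eval_dataloader:
    # Return raw per-micro-batch outputs, offloaded to pinned host memory,
    # instead of keeping them resident on the GPU.
    logits = pp_model.eval_batch(
        batch, compute_loss=False, return_host_tensor=True
    )
    # Or broadcast the aggregated loss, as before.
    loss = pp_model.eval_batch(batch, compute_loss=True)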
@@ -1424,6 +1437,23 @@ def _optimizer_step(self):
         if self.lr_scheduler:
             self.lr_scheduler.step()
 
+    def _offload_tensors(self, output_tensor):
+        if not self._return_host_tensor:
+            return
+        if isinstance(output_tensor, (tuple, list)):
+            for t in output_tensor:
+                host_tensor = (
+                    t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                )
+                host_tensor._share_buffer_to(t)
+        else:
+            host_tensor = (
+                output_tensor.pin_memory()
+                if hasattr(output_tensor, "pin_memory")
+                else output_tensor.cpu()
+            )
+            host_tensor._share_buffer_to(output_tensor)
+
     def _release_output(self, output):
         def can_free(t):
             return (
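The helper above moves each output to host memory (pinned when available) and then re-points the original tensor at the host buffer via the private `_share_buffer_to` API used in the diff, freeing device memory while keeping the Python object usable. A standalone sketch of the same pattern, assuming a GPU build of Paddle:

import paddle

def offload_to_host(t):
    # Copy to pinned host memory when supported, otherwise plain CPU memory,
    # then make the original tensor alias the host buffer.
    host = t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
    host._share_buffer_to(t)
    return t

x = paddle.randn([4, 8])   # allocated on the default (GPU) place
x = offload_to_host(x)
print(x.place)             # now reports a pinned/CPU place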
@@ -1694,10 +1724,12 @@ def _get_forward_input(self, virtual_pp_rank):
         assert hasattr(self, 'output_tensors')
         if not self._forward_only:
             assert hasattr(self, 'output_tensor_grads')
-        assert len(self.input_tensors[virtual_pp_rank]) == (
-            len(self.output_tensors[virtual_pp_rank]) + 1
-        )
-        input_tensor = self.input_tensors[virtual_pp_rank][-1]
+            assert len(self.input_tensors[virtual_pp_rank]) == (
+                len(self.output_tensors[virtual_pp_rank]) + 1
+            )
+            input_tensor = self.input_tensors[virtual_pp_rank][-1]
+        else:
+            input_tensor = self.input_tensors[virtual_pp_rank].pop()
         return input_tensor
 
     def _store_forward_outputs(
@@ -1712,11 +1744,17 @@ def _store_forward_outputs(
         self.schedule_chunks[virtual_pp_rank].append(schedule_chunk)
         if self.is_pipeline_last_stage():
             self.loss_fn_chunks.append(loss_fn_node)
-
-        if self._forward_only:
+            if self._forward_only:
+                # no need to store tensor for backward
+                if self._compute_loss:
+                    self.output_tensors[virtual_pp_rank].pop()
+                # save output_tensors for return value of eval batch
+                else:
+                    self._offload_tensors(output_tensor)
+        else:
             # no need to store tensor for backward
-            self.input_tensors[virtual_pp_rank].pop()
-            self.output_tensors[virtual_pp_rank].pop()
+            if self._forward_only:
+                self.output_tensors[virtual_pp_rank].pop()
 
     def _forward_step_helper(
         self,
@@ -2022,7 +2060,7 @@ def forward_backward_pipeline(
         # this strategy is inspired by:
         # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
         if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                 "compute_loss can only be set to False when forward_only is set to True"
             )
 
@@ -2669,7 +2707,7 @@ def backward_async_comm(
 
             # no steady steps, which only occurs when accumulate_step == num_stage
             if not steady_steps:
-                output_tensor_grad = p2p.recv_backward(
+                output_tensor_grad = self._p2p_helper.recv_backward(
                     self.is_pipeline_last_stage(),
                     batch_p2p_comm=self._use_batch_p2p_comm,
                 )
@@ -2800,12 +2838,14 @@ def backward_async_comm(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -2823,7 +2863,7 @@ def backward_async_comm(
             ), "p2p dynamic_cnt should equal to send_recv_meta_list"
             self._p2p_helper._dynamic_cnt = 0
 
-        return train_loss
+        return train_loss_or_logits
 
     def train_batch(
         self,
@@ -2854,13 +2894,18 @@ def train_batch(
 
         return train_loss
 
-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
 
         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor
 
         # check loss_fn_idx is valid and loss_fn exists
         assert (
@@ -2869,7 +2914,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         ), f"loss function {loss_fn_idx} should exist to compute loss"
         self.loss_fn_idx = loss_fn_idx
 
-        return self.forward_backward_pipeline(data, None, forward_only=True)
+        train_loss_or_logits = self.forward_backward_pipeline(
+            data, None, forward_only=True, compute_loss=compute_loss
+        )
+        self._init_buffers()
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss_or_logits
 
     def get_static_scheduler(self):
         return self.forward_backward_pipeline(
@@ -2959,7 +3010,7 @@ def forward_backward_pipeline(
         if self.processed_steps < g_profile_pipeline_details_steps:
             get_sync_logger().info("start forward_backward_pipeline")
         if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                 "compute_loss can only be set to False when forward_only is set to True"
             )
 
@@ -2977,7 +3028,7 @@ def forward_backward_pipeline(
 
         assert (
             self.accumulate_steps == self.num_stages
-            or self.accumulate_steps % self.num_stages != 0
+            or self.accumulate_steps % self.num_stages == 0
         ), (
             f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
         )
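The corrected predicate accepts accumulate_steps values that are exact multiples of num_stages, which is what the error message already promised; the old `!=` both rejected valid multiples and accepted non-multiples. A quick sanity check in plain Python with illustrative values:

def valid(accumulate_steps, num_stages):
    return accumulate_steps == num_stages or accumulate_steps % num_stages == 0

assert valid(4, 4)       # equal: accepted
assert valid(8, 4)       # exact multiple: accepted
assert not valid(6, 4)   # not a multiple: rejected (the old `!=` let this through)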
@@ -3108,12 +3159,14 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3124,7 +3177,7 @@ def forward_backward_pipeline(
             get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits
 
 
 class OffloadQueue(queue.Queue):
@@ -3187,7 +3240,7 @@ def forward_backward_pipeline(
     ):
         self._reset_user_hooks_status()
         if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                 "compute_loss can only be set to False when forward_only is set to True"
             )
         assert self._using_cache, (
@@ -3462,12 +3515,14 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3478,7 +3533,7 @@ def forward_backward_pipeline(
             get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits
 
 
 def tuple_to_dict_helper(input_tensor):