@@ -1027,7 +1027,9 @@ def train_batch(

         return train_loss

-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
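For reference, a hedged sketch of how the reworked signature might be called once this patch lands; `pp_model` and `batch` are hypothetical names for a PipelineParallel-wrapped model and a prepared (inputs, labels) tuple, not identifiers from this diff:

    # Hypothetical call site, assuming a pipeline-parallel fleet setup elsewhere.
    logits = pp_model.eval_batch(
        batch,
        compute_loss=False,       # return raw last-stage outputs instead of a loss
        loss_fn_idx=0,
        return_host_tensor=True,  # offload those outputs to (pinned) host memory
    )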
@@ -1055,7 +1057,6 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor
         startup_steps = min(startup_steps, self.accumulate_steps)
         steady_steps = self.accumulate_steps - startup_steps

-        input_buffers = []
         output_buffers = []

         # convert to micro dataset
@@ -1076,9 +1077,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
-            self._offload_tensors(output_tensor)
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+            else:
+                self._offload_tensors(output_tensor)

-            input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)

         if steady_steps > 0:
@@ -1099,9 +1102,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
-            self._offload_tensors(output_tensor)
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+            else:
+                self._offload_tensors(output_tensor)

-            input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)

             if not last_iter:
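Both hunks above encode the same eval-time rule: a stage that is not the last pipeline stage has already sent `output_tensor` downstream, so its local reference can be released, while the last stage keeps its output (optionally offloaded to host) because that is what `eval_batch` ultimately returns. A minimal sketch of that rule with stand-in names (`is_last_stage`, `release`, and `offload_to_host` are illustrative, not the class's real helpers):

    def handle_eval_output(output_tensor, is_last_stage, release, offload_to_host):
        # Intermediate stages only relay activations; drop them once sent.
        if not is_last_stage:
            release(output_tensor)
        # The last stage owns the logits returned to the caller; keep them,
        # optionally parked in host memory to reduce device-memory pressure.
        else:
            offload_to_host(output_tensor)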
@@ -1437,10 +1442,16 @@ def _offload_tensors(self, output_tensor):
             return
         if isinstance(output_tensor, (tuple, list)):
             for t in output_tensor:
-                host_tensor = t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                host_tensor = (
+                    t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                )
                 host_tensor._share_buffer_to(t)
         else:
-            host_tensor = output_tensor.pin_memory() if hasattr(output_tensor, "pin_memory") else output_tensor.cpu()
+            host_tensor = (
+                output_tensor.pin_memory()
+                if hasattr(output_tensor, "pin_memory")
+                else output_tensor.cpu()
+            )
             host_tensor._share_buffer_to(output_tensor)

     def _release_output(self, output):
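The reformatted `_offload_tensors` above keeps the same offload pattern: copy each output to pinned host memory when `pin_memory()` is available (falling back to pageable CPU memory), then point the original tensor at that host buffer via `_share_buffer_to`. A standalone sketch of the pattern, assuming a GPU build of Paddle and a throwaway tensor `t`:

    import paddle

    t = paddle.randn([4, 8])  # stand-in for a last-stage output tensor
    host_t = t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
    host_t._share_buffer_to(t)  # t now aliases the host buffer; device memory can be reclaimed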
@@ -2827,7 +2838,9 @@ def backward_async_comm(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
@@ -2881,7 +2894,9 @@ def train_batch(

         return train_loss

-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
@@ -2899,7 +2914,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor
         ), f"loss function {loss_fn_idx} should exist to compute loss"
         self.loss_fn_idx = loss_fn_idx

-        train_loss_or_logits = self.forward_backward_pipeline(data, None, forward_only=True, compute_loss=compute_loss)
+        train_loss_or_logits = self.forward_backward_pipeline(
+            data, None, forward_only=True, compute_loss=compute_loss
+        )
         self._init_buffers()
         self._compute_loss = origin_compute_loss
         self._return_host_tensor = origin_return_host_tensor
@@ -2993,9 +3010,9 @@ def forward_backward_pipeline(
         if self.processed_steps < g_profile_pipeline_details_steps:
             get_sync_logger().info("start forward_backward_pipeline")
         if not compute_loss:
-            assert (
-                forward_only
-            ), "compute_loss can only be set to False when forward_only is set to True"
+            assert forward_only, (
+                "compute_loss can only be set to False when forward_only is set to True"
+            )

         # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
         assert self._using_cache, (
@@ -3011,15 +3028,10 @@ def forward_backward_pipeline(

         assert (
             self.accumulate_steps == self.num_stages
-<<<<<<< HEAD
-            or self.accumulate_steps % self.num_stages != 0
+            or self.accumulate_steps % self.num_stages == 0
         ), (
             f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
         )
-=======
-            or self.accumulate_steps % self.num_stages == 0
-        ), f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
->>>>>>> 4c472714c0 ([Distributed] fix eval batch & non-compute_loss in pipeline (#73479))

         self._backward_step_count = 0
         skip_steps = self.accumulate_steps - self.num_stages
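The surviving branch of the conflict above is the corrected check: `accumulate_steps` must equal `num_stages` or be an exact multiple of it (the discarded `!= 0` variant would have accepted any non-multiple). A small runnable sketch with illustrative values:

    accumulate_steps, num_stages = 8, 4  # hypothetical schedule; 8 % 4 == 0 passes, 7 would not
    assert (
        accumulate_steps == num_stages
        or accumulate_steps % num_stages == 0
    ), (
        f"accumulate_steps({accumulate_steps}) and num_stages({num_stages}) "
        "should be a multiple or accumulate_steps % num_stages == 0"
    )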
@@ -3147,7 +3159,9 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
@@ -3226,12 +3240,12 @@ def forward_backward_pipeline(
     ):
         self._reset_user_hooks_status()
         if not compute_loss:
-            assert (
-                forward_only
-            ), "compute_loss can only be set to False when forward_only is set to True"
-            assert (
-                self._using_cache
-            ), "cache should be enabled for pipeline with interleave"
+            assert forward_only, (
+                "compute_loss can only be set to False when forward_only is set to True"
+            )
+            assert self._using_cache, (
+                "cache should be enabled for pipeline with interleave"
+            )
         self.user_hooks_enabled = not forward_only
         if forward_only:
             return super().forward_backward_pipeline(
@@ -3501,7 +3515,9 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss_or_logits = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
@@ -3517,8 +3533,7 @@ def forward_backward_pipeline(
             get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-<<<<<<< HEAD
-        return train_loss
+        return train_loss_or_logits


 def tuple_to_dict_helper(input_tensor):
@@ -3571,6 +3586,3 @@ def convert_tensor_tuple_to_dict(input_tensor_tuple):
         input_tensor_dict[key] = tensor
         delattr(tensor, "key")
     return input_tensor_dict
-=======
-        return train_loss_or_logits
->>>>>>> 4c472714c0 ([Distributed] fix eval batch & non-compute_loss in pipeline (#73479))