@@ -23,6 +23,7 @@ def __init__(
2323 save_gdrive ,
2424 avg_loss_smoothing ,
2525 is_gpu_used ,
26+ custom_callbacks
2627 ):
2728 self .training_bar = None
2829 self .model = model
@@ -39,6 +40,7 @@ def __init__(
3940 self .steps = 0
4041 self .current_loss = None
4142 self .prev_avg_loss = None
43+ self .custom_callbacks = custom_callbacks
4244
4345 @property
4446 def save_every_check (self ):
@@ -54,17 +56,22 @@ def on_train_begin(self, args, state, control, **kwargs):
5456 file = sys .stdout ,
5557 )
5658
59+ self .custom_callbacks .get ('on_train_begin' , lambda : None )()
60+
5761 def on_train_end (self , args , state , control , ** kwargs ):
5862 if state .is_local_process_zero :
5963 self .training_bar .close ()
6064 self .training_bar = None
6165
66+ self .custom_callbacks .get ('on_train_end' , lambda : None )()
67+
6268 def on_evaluate (self , args , state , control , metrics , ** kwargs ):
69+ print (f'''on_evaluate called: train_loss={ metrics .get ('train_loss' )} ''' )
6370 if state .is_local_process_zero :
6471 self .current_loss = float (metrics .get ("train_loss" ))
6572
6673 def on_step_end (self , args , state , control , ** kwargs ):
67-
74+ print ( f'on_step_end called \n args= { args } \n state= { state } \n control= { control } \n kwargs= { kwargs } \n =====' )
6875 if state .is_local_process_zero :
6976 self .steps += 1
7077 avg_loss = 0
@@ -100,6 +107,8 @@ def on_step_end(self, args, state, control, **kwargs):
100107 if self .current_loss :
101108 self .training_bar .set_description (desc )
102109
110+ self .custom_callbacks .get ('on_step_end' , lambda steps , max , curr , avg , trainer : None )(self .steps , state .max_steps , self .current_loss , avg_loss , self .trainer )
111+
103112 if self .save_every > 0 and self .steps % self .save_every == 0 :
104113 self .save_pytorch_model ()
105114
@@ -138,6 +147,8 @@ def generate_sample_text(self):
138147
139148 self .training_bar .write ("=" * 10 )
140149
150+ self .custom_callbacks .get ('on_sample_text_generated' , lambda texts : None )(gen_texts )
151+
141152 def save_pytorch_model (self ):
142153 # only runs on state.is_local_process_zero
143154 self .training_bar .write (
@@ -154,6 +165,8 @@ def save_pytorch_model(self):
154165 os .path .join ("/content/drive/My Drive/" , self .run_id , pt_file ),
155166 )
156167
168+ self .custom_callbacks .get ('on_model_saved' , lambda curr , out : None )(self .steps , self .output_dir )
169+
157170 def average_loss (self , current_loss , prev_avg_loss , smoothing ):
158171 if prev_avg_loss is None :
159172 return current_loss
0 commit comments