Skip to content

Commit 752d97d

Browse files
authored
Add on_fit_epoch_end callback (#5232)
* Add `on_fit_epoch_end` callback * Add results to train * Update __init__.py
1 parent 13f7275 commit 752d97d

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,10 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
423423
plots=True,
424424
callbacks=callbacks,
425425
compute_loss=compute_loss) # val best model with plots
426+
if is_coco:
427+
callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
426428

427-
callbacks.run('on_train_end', last, best, plots, epoch)
429+
callbacks.run('on_train_end', last, best, plots, epoch, results)
428430
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
429431

430432
torch.cuda.empty_cache()

utils/loggers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
131131
if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
132132
self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
133133

134-
def on_train_end(self, last, best, plots, epoch):
134+
def on_train_end(self, last, best, plots, epoch, results):
135135
# Callback runs on training end
136136
if plots:
137137
plot_results(file=self.save_dir / 'results.csv') # save results.png

0 commit comments

Comments
 (0)