diff --git a/search_params.py b/search_params.py index 5efc5c47..7acec100 100644 --- a/search_params.py +++ b/search_params.py @@ -2,7 +2,6 @@ import json import logging import os -import time from datetime import datetime from pathlib import Path @@ -234,6 +233,7 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val): logging.info(f"Re-training with best config: \n{best_config}") trainer = TorchTrainer(config=best_config, **data) trainer.train() + best_model_path = trainer.checkpoint_callback.last_model_path else: # If not merging training and validation data, load the best result from tune experiments. logging.info(f"Loading best model with best config: \n{best_config}") @@ -241,15 +241,14 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val): best_checkpoint = os.path.join(best_log_dir, "best_model.ckpt") last_checkpoint = os.path.join(best_log_dir, "last.ckpt") trainer._setup_model(checkpoint_path=best_checkpoint) - os.popen(f"cp {best_checkpoint} {os.path.join(checkpoint_dir, 'best_model.ckpt')}") + best_model_path = os.path.join(checkpoint_dir, 'best_model.ckpt') + os.popen(f"cp {best_checkpoint} {best_model_path}") os.popen(f"cp {last_checkpoint} {os.path.join(checkpoint_dir, 'last.ckpt')}") if "test" in data["datasets"]: test_results = trainer.test() logging.info(f"Test results after re-training: {test_results}") - logging.info( - f"Best model saved to {trainer.checkpoint_callback.best_model_path or trainer.checkpoint_callback.last_model_path}." - ) + logging.info(f"Best model saved to {best_model_path}.") def main(): @@ -341,7 +340,7 @@ def main(): # Save best model after parameter search. best_config = analysis.get_best_config(f"val_{config.val_metric}", config.mode, scope="all") best_log_dir = analysis.get_best_logdir(f"val_{config.val_metric}", config.mode, scope="all") - retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not args.no_merge_train_val) + retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not config.no_merge_train_val) if __name__ == "__main__":