Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion benchmarl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import argparse
from pathlib import Path

from experiment import Experiment

from benchmarl.hydra_config import reload_experiment_from_file

if __name__ == "__main__":
Expand All @@ -17,5 +19,8 @@
)
args = parser.parse_args()
checkpoint_file = str(Path(args.checkpoint_file).resolve())
experiment = reload_experiment_from_file(checkpoint_file)
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: for readability, I'd suggest adapting the hydra_config util so that you can do:

if hydra_config.checkpoint_from_hydra_run(checkpoint_file):
    experiment = hydra_config.reload_experiment_from_file()
else:
    # Assume experiment can be loaded directly. 
    experiment = Experiment.reload_from_file(checkpoint_file)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense in the medium term to maintain both hydra-based and non-hydra based saving/loading?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense in the medium term to maintain both hydra-based and non-hydra based saving/loading?

I would think so: the hydra one is more interpretable, so it should be the first choice whenever it is available.

experiment = reload_experiment_from_file(checkpoint_file)
except ValueError:
experiment = Experiment.reload_from_file(checkpoint_file)
experiment.evaluate()
60 changes: 54 additions & 6 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import importlib

import os
import pickle
import shutil
import time
import warnings
Expand Down Expand Up @@ -582,12 +583,16 @@ def _setup_name(self):
self.name = Path(self.config.restore_file).parent.parent.resolve().name
self.folder_name = save_folder / self.name

if (
len(self.config.loggers)
or self.config.checkpoint_interval > 0
or self.config.create_json
):
self.folder_name.mkdir(parents=False, exist_ok=True)
self.folder_name.mkdir(parents=False, exist_ok=True)
with open(self.folder_name / "config.pkl", "wb") as f:
pickle.dump(self.task, f)
pickle.dump(self.task.config if self.task.config is not None else {}, f)
pickle.dump(self.algorithm_config, f)
pickle.dump(self.model_config, f)
pickle.dump(self.seed, f)
pickle.dump(self.config, f)
pickle.dump(self.critic_model_config, f)
pickle.dump(self.callbacks, f)

def _setup_logger(self):
self.logger = Logger(
Expand Down Expand Up @@ -957,3 +962,46 @@ def _load_experiment(self) -> Experiment:
)
self.load_state_dict(loaded_dict)
return self

@staticmethod
def reload_from_file(restore_file: str) -> Experiment:
"""
Restores the experiment from the checkpoint file.

If expects the same folder structure created when an experiment is run.
The checkpoint file (``restore_file``) is in the checkpoints directory and a config.pkl file is
present a level above at restore_file/../../config.pkl

Args:
restore_file (str): The checkpoint file (.pt) of the experiment reload.

Returns:
The reloaded experiment.

"""
experiment_folder = Path(restore_file).parent.parent.resolve()
config_file = experiment_folder / "config.pkl"
if not os.path.exists(config_file):
raise ValueError("config.pkl file not found in experiment folder.")
with open(config_file, "rb") as f:
task = pickle.load(f)
task_config = pickle.load(f)
algorithm_config = pickle.load(f)
model_config = pickle.load(f)
seed = pickle.load(f)
experiment_config = pickle.load(f)
critic_model_config = pickle.load(f)
callbacks = pickle.load(f)
task.config = task_config
experiment_config.restore_file = restore_file
experiment = Experiment(
task=task,
algorithm_config=algorithm_config,
model_config=model_config,
seed=seed,
config=experiment_config,
callbacks=callbacks,
critic_model_config=critic_model_config,
)
print(f"\nReloaded experiment {experiment.name} from {restore_file}.")
return experiment
7 changes: 6 additions & 1 deletion benchmarl/resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import argparse
from pathlib import Path

from experiment import Experiment

from benchmarl.hydra_config import reload_experiment_from_file

if __name__ == "__main__":
Expand All @@ -18,5 +20,8 @@
args = parser.parse_args()
checkpoint_file = str(Path(args.checkpoint_file).resolve())

experiment = reload_experiment_from_file(checkpoint_file)
try:
experiment = reload_experiment_from_file(checkpoint_file)
except ValueError:
experiment = Experiment.reload_from_file(checkpoint_file)
experiment.run()
4 changes: 4 additions & 0 deletions examples/checkpointing/reload_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
# Let's do 3 more iters
experiment_config.max_n_iters += 3

# We can also change part of the configuration (algorithm, task), for example to evaluate on a new task.
experiment = Experiment(
algorithm_config=algorithm_config,
model_config=model_config,
Expand All @@ -57,3 +58,6 @@
task=task,
)
experiment.run()

# We can also evaluate
experiment.evaluate()
54 changes: 54 additions & 0 deletions examples/checkpointing/resume_experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
from pathlib import Path

from benchmarl.algorithms import MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

if __name__ == "__main__":

    # Start from the experiment configuration loaded from yaml
    experiment_config = ExperimentConfig.get_from_yaml()

    # Results are saved in the folder that contains this script
    script_folder = os.path.dirname(os.path.realpath(__file__))
    experiment_config.save_folder = Path(script_folder)

    # Setting the interval to the frames collected per batch gives one checkpoint per iteration
    frames_per_batch = experiment_config.on_policy_collected_frames_per_batch
    experiment_config.checkpoint_interval = frames_per_batch

    # Train for 3 iterations only
    experiment_config.max_n_iters = 3

    task = VmasTask.BALANCE.get_from_yaml()
    algorithm_config = MappoConfig.get_from_yaml()
    model_config = MlpConfig.get_from_yaml()
    critic_model_config = MlpConfig.get_from_yaml()

    experiment = Experiment(
        task=task,
        algorithm_config=algorithm_config,
        model_config=model_config,
        critic_model_config=critic_model_config,
        seed=0,
        config=experiment_config,
    )
    experiment.run()

    # Now we tell it where to resume from: the first checkpoint that was written
    first_checkpoint = (
        experiment.folder_name
        / "checkpoints"
        / f"checkpoint_{experiment_config.checkpoint_interval}.pt"
    )
    restored_experiment = Experiment.reload_from_file(first_checkpoint)

    # We keep the same configuration and continue running
    restored_experiment.run()

    # The restored experiment can also be evaluated
    restored_experiment.evaluate()
Loading