diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 3e2033ab..dfbf0dbb 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -1006,7 +1006,9 @@ def _load_experiment(self) -> Experiment: return self @staticmethod - def reload_from_file(restore_file: str) -> Experiment: + def reload_from_file( + restore_file: str, experiment_patch: Optional[Dict[str, Any]] = None + ) -> Experiment: """ Restores the experiment from the checkpoint file. @@ -1016,6 +1018,7 @@ def reload_from_file(restore_file: str) -> Experiment: Args: restore_file (str): The checkpoint file (.pt) of the experiment reload. + experiment_patch (Optional[Dict[str, Any]]): The patch to apply to the experiment config. Returns: The reloaded experiment. @@ -1036,6 +1039,11 @@ def reload_from_file(restore_file: str) -> Experiment: callbacks = pickle.load(f) task.config = task_config experiment_config.restore_file = restore_file + if experiment_patch is not None: + for key, value in experiment_patch.items(): + if not hasattr(experiment_config, key): + raise ValueError(f"Experiment config does not have attribute {key}") + setattr(experiment_config, key, value) experiment = Experiment( task=task, algorithm_config=algorithm_config,