diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml index ee800697..7f77078e 100644 --- a/benchmarl/conf/experiment/base_experiment.yaml +++ b/benchmarl/conf/experiment/base_experiment.yaml @@ -101,8 +101,12 @@ evaluation_static: False # List of loggers to use, options are: wandb, csv, tensorboard, mflow loggers: [csv,wandb] -# Wandb project name +# Wandb project name (kept for backward compatibility) project_name: "benchmarl" +# Wandb extra kwargs passed to the WandbLogger (~superset of wandb.init kwargs) +# WandbLogger includes: offline, save_dir, project, video_fps +# wandb.init includes: entity, tags, notes, etc. +wandb_extra_kwargs: {} # Create a json folder as part of the output in the format of marl-eval create_json: True diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 5f61ce1a..fa2061ee 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -111,6 +111,7 @@ class ExperimentConfig: loggers: List[str] = MISSING project_name: str = MISSING + wandb_extra_kwargs: Dict[str, Any] = MISSING create_json: bool = MISSING save_folder: Optional[str] = MISSING @@ -611,7 +612,6 @@ def _setup_name(self): def _setup_logger(self): self.logger = Logger( - project_name=self.config.project_name, experiment_name=self.name, folder_name=str(self.folder_name), experiment_config=self.config, @@ -621,6 +621,8 @@ def _setup_logger(self): task_name=self.task_name, group_map=self.group_map, seed=self.seed, + project_name=self.config.project_name, + wandb_extra_kwargs=self.config.wandb_extra_kwargs, ) self.logger.log_hparams( critic_model_name=self.critic_model_name, diff --git a/benchmarl/experiment/logger.py b/benchmarl/experiment/logger.py index 418a868d..70ef84ee 100644 --- a/benchmarl/experiment/logger.py +++ b/benchmarl/experiment/logger.py @@ -10,7 +10,7 @@ from collections.abc import MutableMapping, Sequence from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import numpy as np import torch @@ -39,6 +39,7 @@ def __init__( group_map: Dict[str, List[str]], seed: int, project_name: str, + wandb_extra_kwargs: Dict[str, Any], ): self.experiment_config = experiment_config self.algorithm_name = algorithm_name @@ -62,6 +63,11 @@ def __init__( self.loggers: List[torchrl.record.loggers.Logger] = [] for logger_name in experiment_config.loggers: + wandb_project = wandb_extra_kwargs.get("project", project_name) + if wandb_project != project_name: + raise ValueError( + f"wandb_extra_kwargs.project ({wandb_project}) is different from the project_name ({project_name})" + ) self.loggers.append( get_logger( logger_type=logger_name, @@ -69,8 +75,9 @@ def __init__( experiment_name=experiment_name, wandb_kwargs={ "group": task_name, - "project": project_name, "id": experiment_name, + "project": project_name, + **wandb_extra_kwargs, }, ) )