Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion benchmarl/conf/experiment/base_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,12 @@ evaluation_static: False

# List of loggers to use, options are: wandb, csv, tensorboard, mflow
loggers: [csv,wandb]
# Wandb project name
# Wandb project name (kept for backward compatibility)
project_name: "benchmarl"
# Wandb extra kwargs passed to the WandbLogger (~superset of wandb.init kwargs)
# WandbLogger includes: offline, save_dir, project, video_fps
# wandb.init includes: entity, tags, notes, etc.
wandb_extra_kwargs: {}
# Create a json folder as part of the output in the format of marl-eval
create_json: True

Expand Down
4 changes: 3 additions & 1 deletion benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class ExperimentConfig:

loggers: List[str] = MISSING
project_name: str = MISSING
wandb_extra_kwargs: Dict[str, Any] = MISSING
create_json: bool = MISSING

save_folder: Optional[str] = MISSING
Expand Down Expand Up @@ -611,7 +612,6 @@ def _setup_name(self):

def _setup_logger(self):
self.logger = Logger(
project_name=self.config.project_name,
experiment_name=self.name,
folder_name=str(self.folder_name),
experiment_config=self.config,
Expand All @@ -621,6 +621,8 @@ def _setup_logger(self):
task_name=self.task_name,
group_map=self.group_map,
seed=self.seed,
project_name=self.config.project_name,
wandb_extra_kwargs=self.config.wandb_extra_kwargs,
)
self.logger.log_hparams(
critic_model_name=self.critic_model_name,
Expand Down
11 changes: 9 additions & 2 deletions benchmarl/experiment/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import MutableMapping, Sequence
from pathlib import Path

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(
group_map: Dict[str, List[str]],
seed: int,
project_name: str,
wandb_extra_kwargs: Dict[str, Any],
):
self.experiment_config = experiment_config
self.algorithm_name = algorithm_name
Expand All @@ -62,15 +63,21 @@ def __init__(

self.loggers: List[torchrl.record.loggers.Logger] = []
for logger_name in experiment_config.loggers:
wandb_project = wandb_extra_kwargs.get("project", project_name)
if wandb_project != project_name:
raise ValueError(
f"wandb_extra_kwargs.project ({wandb_project}) is different from the project_name ({project_name})"
)
self.loggers.append(
get_logger(
logger_type=logger_name,
logger_name=folder_name,
experiment_name=experiment_name,
wandb_kwargs={
"group": task_name,
"project": project_name,
"id": experiment_name,
"project": project_name,
**wandb_extra_kwargs,
},
)
)
Expand Down
7 changes: 4 additions & 3 deletions benchmarl/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def load_task_config_from_hydra(cfg: DictConfig, task_name: str) -> TaskClass:
cfg_dict_checked = OmegaConf.to_object(cfg)
if is_dataclass(cfg_dict_checked):
cfg_dict_checked = cfg_dict_checked.__dict__
cfg_dict_checked = _type_check_task_config(
environment_name, inner_task_name, cfg_dict_checked
) # Only needed for the warning
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to run the check anyway no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I commited a modification linked to another upcoming request 🙈

cfg_dict_checked = _type_check_task_config(
environment_name, inner_task_name, cfg_dict_checked
) # Only needed for the warning
return task_config_registry[task_name].get_task(cfg_dict_checked)


Expand Down
Loading