Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions benchmarl/conf/experiment/base_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,11 @@ evaluation_static: False

# List of loggers to use, options are: wandb, csv, tensorboard, mflow
loggers: [csv,wandb]
# Wandb project name
project_name: "benchmarl"
# Wandb extra kwargs passed to the WandbLogger (~superset of wandb.init kwargs)
# WandbLogger includes: offline, save_dir, project, video_fps
# wandb.init includes: entity, tags, notes, etc.
wandb_extra_kwargs:
project: "benchmarl"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the only problem is that this is backwards compatibility breaking and I would like to avoid that.

What about we keep the project name out and have wandb_extra_kwargs: {}?

Then is someone provides the project in wandb_extra_kwargs and it is different from project_name we tell them through an error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get you, sounds good.

# Create a json folder as part of the output in the format of marl-eval
create_json: True

Expand Down
4 changes: 2 additions & 2 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class ExperimentConfig:
evaluation_static: bool = MISSING

loggers: List[str] = MISSING
project_name: str = MISSING
wandb_extra_kwargs: Dict[str, Any] = MISSING
create_json: bool = MISSING

save_folder: Optional[str] = MISSING
Expand Down Expand Up @@ -611,7 +611,6 @@ def _setup_name(self):

def _setup_logger(self):
self.logger = Logger(
project_name=self.config.project_name,
experiment_name=self.name,
folder_name=str(self.folder_name),
experiment_config=self.config,
Expand All @@ -621,6 +620,7 @@ def _setup_logger(self):
task_name=self.task_name,
group_map=self.group_map,
seed=self.seed,
wandb_extra_kwargs=self.config.wandb_extra_kwargs,
)
self.logger.log_hparams(
critic_model_name=self.critic_model_name,
Expand Down
6 changes: 3 additions & 3 deletions benchmarl/experiment/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import MutableMapping, Sequence
from pathlib import Path

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -38,7 +38,7 @@ def __init__(
model_name: str,
group_map: Dict[str, List[str]],
seed: int,
project_name: str,
wandb_extra_kwargs: Dict[str, Any],
):
self.experiment_config = experiment_config
self.algorithm_name = algorithm_name
Expand Down Expand Up @@ -69,8 +69,8 @@ def __init__(
experiment_name=experiment_name,
wandb_kwargs={
"group": task_name,
"project": project_name,
"id": experiment_name,
**wandb_extra_kwargs,
},
)
)
Expand Down