diff --git a/docs/source/concepts/configuring.rst b/docs/source/concepts/configuring.rst index aa584b9d..8933161d 100644 --- a/docs/source/concepts/configuring.rst +++ b/docs/source/concepts/configuring.rst @@ -104,7 +104,7 @@ You can change it like so: .. code-block:: console - python benchmarl/run.py task=vmas/balance algorithm=mappo model=layers/mlp model=layers/mlp model.layer_class="torch.nn.Linear" "model.num_cells=[32,32]" model.activation_class="torch.nn.ReLU" + python benchmarl/run.py task=vmas/balance algorithm=mappo model=layers/mlp model.layer_class="torch.nn.Linear" "model.num_cells=[32,32]" model.activation_class="torch.nn.ReLU" Available models and their configs can be found at `benchmarl/conf/model/layers `__. diff --git a/examples/extending/model/custom_model.py b/examples/extending/model/custom_model.py index 9c573166..e4689ec7 100644 --- a/examples/extending/model/custom_model.py +++ b/examples/extending/model/custom_model.py @@ -21,7 +21,7 @@ class CustomModel(Model): def __init__( self, custom_param: int, - activation_function: Type[nn.Module], + activation_class: Type[nn.Module], **kwargs, ): # Models in BenchMARL are instantiated per agent group. @@ -34,7 +34,7 @@ def __init__( # You can create your custom attributes self.custom_param = custom_param - self.activation_function = activation_function + self.activation_function = activation_class # And access some of the ones already available to your module _ = self.input_spec # Like its input_spec @@ -166,7 +166,7 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: class CustomModelConfig(ModelConfig): # The config parameters for this class, these will be loaded from yaml custom_param: int = MISSING - activation_function: Type[nn.Module] = MISSING + activation_class: Type[nn.Module] = MISSING @staticmethod def associated_class(): diff --git a/examples/extending/model/custommodel.yaml b/examples/extending/model/custommodel.yaml index ed641488..547ece39 100644 --- a/examples/extending/model/custommodel.yaml +++ b/examples/extending/model/custommodel.yaml @@ -1,4 +1,4 @@ name: custom_model custom_param: 3 -activation_function: torch.nn.Tanh +activation_class: torch.nn.Tanh