Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions docs/guide/custom_policy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,31 @@ If your task requires even more granular control over the policy/value architect



.. TODO (see https://github.com/DLR-RM/stable-baselines3/issues/113)
.. Off-Policy Algorithms
.. ^^^^^^^^^^^^^^^^^^^^^
..
.. If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG`` or ``TD3``,
.. you can easily redefine the actor class for instance.
Off-Policy Algorithms
^^^^^^^^^^^^^^^^^^^^^

If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG`` or ``TD3``,
you can pass a dictionary of the following structure: ``dict(qf=[<critic network architecture>], pi=[<actor network architecture>])``.

For example, if you want a different architecture for the actor (aka ``pi``) and the critic (Q-function aka ``qf``) networks,
then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.

Otherwise, to have actor and critic that share the same network architecture,
you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each).


.. note::
Compared to their on-policy counterparts, no shared layers (other than the feature extractor)
between the actor and the critic are allowed (to prevent issues with target networks).


.. code-block:: python

from stable_baselines3 import SAC

# Custom actor architecture with two layers of 64 units each
# Custom critic architecture with two layers of 400 and 300 units
policy_kwargs = dict(net_arch=dict(pi=[64, 64], qf=[400, 300]))
# Create the agent
model = SAC("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs, verbose=1)
model.learn(5000)
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Allow custom actor/critic network architectures using ``net_arch=dict(qf=[400, 300], pi=[64, 64])`` for off-policy algorithms (SAC, TD3, DDPG)

Bug Fixes:
^^^^^^^^^^
Expand Down
40 changes: 40 additions & 0 deletions stable_baselines3/common/torch_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,43 @@ def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
shared_latent = self.shared_net(features)
return self.policy_net(shared_latent), self.value_net(shared_latent)


def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]:
"""
Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG).

The ``net_arch`` parameter allows to specify the amount and size of the hidden layers,
which can be different for the actor and the critic.
It is assumed to be a list of ints or a dict.

1. If it is a list, actor and critic networks will have the same architecture.
The architecture is represented by a list of integers (of arbitrary length (zero allowed))
each specifying the number of units per layer.
If the number of ints is zero, the network will be linear.
2. If it is a dict, it should have the following structure:
``dict(qf=[<critic network architecture>], pi=[<actor network architecture>])``.
where the network architecture is a list as described in 1.

For example, to have actor and critic that share the same network architecture,
you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each).

If you want a different architecture for the actor and the critic,
then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.

.. note::
Compared to their on-policy counterparts, no shared layers (other than the feature extractor)
between the actor and the critic are allowed (to prevent issues with target networks).

:param net_arch: The specification of the actor and critic networks.
See above for details on its formatting.
:return: The network architectures for the actor and the critic
"""
if isinstance(net_arch, list):
actor_arch, critic_arch = net_arch, net_arch
else:
assert isinstance(net_arch, dict), "Error: the net_arch can only contain be a list of ints or a dict"
assert "pi" in net_arch, "Error: no key 'pi' was provided in net_arch for the actor network"
assert "qf" in net_arch, "Error: no key 'qf' was provided in net_arch for the critic network"
actor_arch, critic_arch = net_arch["pi"], net_arch["qf"]
return actor_arch, critic_arch
22 changes: 21 additions & 1 deletion stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random
from collections import deque
from itertools import zip_longest
from typing import Callable, Iterable, Optional, Union

import gym
Expand Down Expand Up @@ -286,6 +287,24 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
return np.nan if len(arr) == 0 else np.mean(arr)


def zip_strict(*iterables: Iterable) -> Iterable:
r"""
``zip()`` function but enforces that iterables are of equal length.
Raises ``ValueError`` if iterables not of equal length.
Code inspired by Stackoverflow answer for question #32954486.

:param \*iterables: iterables to ``zip()``
"""
# As in Stackoverflow #32954486, use
# new object for "empty" in case we have
# Nones in iterable.
sentinel = object()
for combo in zip_longest(*iterables, fillvalue=sentinel):
if sentinel in combo:
raise ValueError("Iterables have different lengths")
yield combo


def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th.nn.Parameter], tau: float) -> None:
"""
Perform a Polyak average update on ``target_params`` using ``params``:
Expand All @@ -303,6 +322,7 @@ def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th.
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
"""
with th.no_grad():
for param, target_param in zip(params, target_params):
# zip does not raise an exception if length of parameters does not match.
for param, target_param in zip_strict(params, target_params):
target_param.data.mul_(1 - tau)
th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
22 changes: 15 additions & 7 deletions stable_baselines3/sac/policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import gym
import torch as th
Expand All @@ -7,7 +7,13 @@
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
FlattenExtractor,
NatureCNN,
create_mlp,
get_actor_critic_arch,
)

# CAP the standard deviation of the actor
LOG_STD_MAX = 2
Expand Down Expand Up @@ -220,7 +226,7 @@ def __init__(
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
Expand Down Expand Up @@ -250,6 +256,8 @@ def __init__(
else:
net_arch = []

actor_arch, critic_arch = get_actor_critic_arch(net_arch)

# Create shared features extractor
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
Expand All @@ -261,7 +269,7 @@ def __init__(
"action_space": self.action_space,
"features_extractor": self.features_extractor,
"features_dim": self.features_dim,
"net_arch": self.net_arch,
"net_arch": actor_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
}
Expand All @@ -275,7 +283,7 @@ def __init__(
}
self.actor_kwargs.update(sde_kwargs)
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update({"n_critics": n_critics})
self.critic_kwargs.update({"n_critics": n_critics, "net_arch": critic_arch})

self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
Expand All @@ -300,7 +308,7 @@ def _get_data(self) -> Dict[str, Any]:

data.update(
dict(
net_arch=self.net_args["net_arch"],
net_arch=self.net_arch,
activation_fn=self.net_args["activation_fn"],
use_sde=self.actor_kwargs["use_sde"],
log_std_init=self.actor_kwargs["log_std_init"],
Expand Down Expand Up @@ -374,7 +382,7 @@ def __init__(
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
Expand Down
25 changes: 17 additions & 8 deletions stable_baselines3/td3/policies.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Type, Union

import gym
import torch as th
from torch import nn

from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
FlattenExtractor,
NatureCNN,
create_mlp,
get_actor_critic_arch,
)


class Actor(BasePolicy):
Expand Down Expand Up @@ -101,7 +107,7 @@ def __init__(
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -127,6 +133,8 @@ def __init__(
else:
net_arch = []

actor_arch, critic_arch = get_actor_critic_arch(net_arch)

self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim

Expand All @@ -137,12 +145,13 @@ def __init__(
"action_space": self.action_space,
"features_extractor": self.features_extractor,
"features_dim": self.features_dim,
"net_arch": self.net_arch,
"net_arch": actor_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
}
self.actor_kwargs = self.net_args.copy()
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update({"n_critics": n_critics})
self.critic_kwargs.update({"n_critics": n_critics, "net_arch": critic_arch})
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None

Expand All @@ -163,7 +172,7 @@ def _get_data(self) -> Dict[str, Any]:

data.update(
dict(
net_arch=self.net_args["net_arch"],
net_arch=self.net_arch,
activation_fn=self.net_args["activation_fn"],
n_critics=self.critic_kwargs["n_critics"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
Expand All @@ -176,7 +185,7 @@ def _get_data(self) -> Dict[str, Any]:
return data

def make_actor(self) -> Actor:
return Actor(**self.net_args).to(self.device)
return Actor(**self.actor_kwargs).to(self.device)

def make_critic(self) -> ContinuousCritic:
return ContinuousCritic(**self.critic_kwargs).to(self.device)
Expand Down Expand Up @@ -217,7 +226,7 @@ def __init__(
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
net_arch: Optional[List[int]] = None,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
Expand Down
23 changes: 17 additions & 6 deletions tests/test_custom_policy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch as th

from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike


Expand All @@ -19,22 +19,33 @@
)
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_flexible_mlp(model_class, net_arch):
_ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
_ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(300)


@pytest.mark.parametrize("net_arch", [[4], [4, 4]])
@pytest.mark.parametrize("net_arch", [[], [4], [4, 4], dict(qf=[8], pi=[8, 4])])
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_custom_offpolicy(model_class, net_arch):
_ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=net_arch)).learn(1000)
_ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=net_arch), learning_starts=100).learn(300)


@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
@pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)])
def test_custom_optimizer(model_class, optimizer_kwargs):
kwargs = {}
if model_class in {DQN, SAC, TD3}:
kwargs = dict(learning_starts=100)
elif model_class in {A2C, PPO}:
kwargs = dict(n_steps=100)

policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
_ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000)
_ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs, **kwargs).learn(300)


def test_tf_like_rmsprop_optimizer():
policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
_ = A2C("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000)
_ = A2C("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(500)


def test_dqn_custom_policy():
policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
_ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300)
6 changes: 3 additions & 3 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,18 @@ def test_n_critics(n_critics):
model = SAC(
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1
)
model.learn(total_timesteps=1000)
model.learn(total_timesteps=500)


def test_dqn():
model = DQN(
"MlpPolicy",
"CartPole-v1",
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=500,
learning_starts=100,
buffer_size=500,
learning_rate=3e-4,
verbose=1,
create_eval_env=True,
)
model.learn(total_timesteps=1000, eval_freq=500)
model.learn(total_timesteps=500, eval_freq=250)
Loading