Merged

33 commits
f779a9f
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 7, 2021
98bc5b2
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 9, 2021
97ece67
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 17, 2021
799b140
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 17, 2021
dc73462
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 19, 2021
9b8a222
feat: TRPO - addressing PR comments
cyprienc Sep 11, 2021
869dce9
refactor: TRPO - policier
cyprienc Sep 11, 2021
347dcc0
feat: using updated ActorCriticPolicy from SB3
cyprienc Sep 11, 2021
35d7256
Bump version for `get_distribution` support
araffin Sep 13, 2021
9cfcb54
Add basic test
araffin Sep 13, 2021
974174a
Reformat
araffin Sep 13, 2021
b6bd449
[ci skip] Fix changelog
araffin Sep 13, 2021
c88951c
fix: setting train mode for trpo
cyprienc Sep 13, 2021
1f7e99d
fix: batch_size type hint in trpo.py
cyprienc Sep 13, 2021
6540371
style: renaming variables + docstring in trpo.py
cyprienc Sep 15, 2021
3a26c05
Merge branch 'master' into master
araffin Sep 23, 2021
f003e88
Merge branch 'master' into master
araffin Sep 27, 2021
a33409e
Merge branch 'master' into master
araffin Sep 29, 2021
8ecf40e
Rename + cleanup
araffin Sep 29, 2021
45f4ea6
Move grad computation to separate method
araffin Sep 29, 2021
cc4b5ab
Remove grad norm clipping
araffin Sep 29, 2021
fc7a6c7
Remove n epochs and add sub-sampling
araffin Sep 29, 2021
66723ff
Update defaults
araffin Sep 29, 2021
63a263f
Merge branch 'master' into master
araffin Dec 1, 2021
bf583de
Merge branch 'master' into cyprienc/master
araffin Dec 10, 2021
e983348
Add Doc
araffin Dec 27, 2021
439d79b
Add more test and fixes for CNN
araffin Dec 27, 2021
d9483dc
Update doc + add benchmark
araffin Dec 28, 2021
fff84e4
Add tests + update doc
araffin Dec 28, 2021
95dddf4
Fix doc
araffin Dec 28, 2021
661fe15
Improve names for conjugate gradient
araffin Dec 29, 2021
a24e7c0
Update comments
araffin Dec 29, 2021
342fe53
Update changelog
araffin Dec 29, 2021
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@ See documentation for the full list of included features.
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)

**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
7 changes: 7 additions & 0 deletions docs/common/utils.rst
@@ -0,0 +1,7 @@
.. _utils:

Utils
=====

.. automodule:: sb3_contrib.common.utils
    :members:
3 changes: 2 additions & 1 deletion docs/guide/algos.rst
@@ -9,7 +9,8 @@ along with some useful characteristics: support for discrete/continuous actions,
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
TQC ✔️ ❌ ❌ ❌ ✔️
QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
TRPO ✔️ ✔️ ✔️ ✔️ ✔️
QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
============ =========== ============ ================= =============== ================


13 changes: 13 additions & 0 deletions docs/guide/examples.rst
@@ -44,3 +44,16 @@ Train a PPO with invalid action masking agent on a toy environment.
    model = MaskablePPO("MlpPolicy", env, verbose=1)
    model.learn(5000)
    model.save("ppo_mask")

TRPO
----

Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environment.

.. code-block:: python

    from sb3_contrib import TRPO

    model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1)
    model.learn(total_timesteps=100_000, log_interval=4)
    model.save("trpo_pendulum")
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -32,13 +32,15 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
    :caption: RL Algorithms

    modules/tqc
    modules/trpo
    modules/qrdqn
    modules/ppo_mask

.. toctree::
    :maxdepth: 1
    :caption: Common

    common/utils
    common/wrappers

.. toctree::
12 changes: 6 additions & 6 deletions docs/misc/changelog.rst
@@ -4,8 +4,9 @@ Changelog
==========


Release 1.3.1a6 (WIP)
Release 1.3.1a7 (WIP)
-------------------------------
**Add TRPO**

Breaking Changes:
^^^^^^^^^^^^^^^^^
@@ -15,6 +16,7 @@

New Features:
^^^^^^^^^^^^^
- Added ``TRPO`` (@cyprienc)
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)

Bug Fixes:
@@ -34,7 +36,7 @@ Documentation:
Release 1.3.0 (2021-10-23)
-------------------------------

**Invalid action masking for PPO**
**Add Invalid action masking for PPO**

.. warning::

@@ -52,6 +54,7 @@ New Features:
- Added ``MaskablePPO`` algorithm (@kronion)
- ``MaskablePPO`` Dictionary Observation support (@glmcdona)


Bug Fixes:
^^^^^^^^^^

@@ -75,9 +78,6 @@ Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 1.2.0

New Features:
^^^^^^^^^^^^^

Bug Fixes:
^^^^^^^^^^
- QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright)
@@ -221,4 +221,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
Contributors:
-------------

@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona
@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc
151 changes: 151 additions & 0 deletions docs/modules/trpo.rst
@@ -0,0 +1,151 @@
.. _trpo:

.. automodule:: sb3_contrib.trpo

TRPO
====

`Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
is an iterative approach for optimizing policies with guaranteed monotonic improvement.
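
At each update, TRPO (approximately) solves a surrogate objective under a KL trust-region
constraint. Notation follows the paper: :math:`\hat{A}_t` is an advantage estimate and
:math:`\delta` is the maximum allowed KL divergence (the trust-region size).

.. math::

    \max_{\theta} \; \mathbb{E}_t \left[ \frac{\pi_{\theta}(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)} \hat{A}_t \right]
    \quad \text{s.t.} \quad
    \mathbb{E}_t \left[ D_{\text{KL}}\left(\pi_{\theta_{\text{old}}}(\cdot \mid s_t) \,\|\, \pi_{\theta}(\cdot \mid s_t)\right) \right] \le \delta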

.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy
    CnnPolicy
    MultiInputPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1502.05477
- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ✔️ ✔️
Box ✔️ ✔️
MultiDiscrete ✔️ ✔️
MultiBinary ✔️ ✔️
Dict ❌ ✔️
============= ====== ===========


Example
-------

.. code-block:: python

    import gym
    import numpy as np

    from sb3_contrib import TRPO

    env = gym.make("Pendulum-v0")

    model = TRPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("trpo_pendulum")

    del model  # remove to demonstrate saving and loading

    model = TRPO.load("trpo_pendulum")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
            obs = env.reset()


Results
-------

Results on the MuJoCo benchmark (1M steps on ``-v3`` envs with MuJoCo v2.1.0) using 3 seeds.
The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/40>`_.


===================== ============
Environments TRPO
===================== ============
HalfCheetah 1803 +/- 46
Ant 3554 +/- 591
Hopper 3372 +/- 215
Walker2d 4502 +/- 234
Swimmer 359 +/- 2
===================== ============


How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone RL-Zoo and check out the branch ``feat/trpo``:

.. code-block:: bash

    git clone https://github.com/cyprienc/rl-baselines3-zoo
    cd rl-baselines3-zoo/

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

    python train.py --algo trpo --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000


Plot the results:

.. code-block:: bash

    python scripts/all_plots.py -a trpo -e HalfCheetah Ant Hopper Walker2d Swimmer -f logs/ -o logs/trpo_results
    python scripts/plot_from_file.py -i logs/trpo_results.pkl -latex -l TRPO


Parameters
----------

.. autoclass:: TRPO
    :members:
    :inherited-members:

.. _trpo_policies:

TRPO Policies
-------------

.. autoclass:: MlpPolicy
    :members:
    :inherited-members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy
    :members:
    :noindex:

.. autoclass:: CnnPolicy
    :members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy
    :members:
    :noindex:

.. autoclass:: MultiInputPolicy
    :members:

.. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy
    :members:
    :noindex:
1 change: 1 addition & 0 deletions sb3_contrib/__init__.py
@@ -3,6 +3,7 @@
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
from sb3_contrib.trpo import TRPO

# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
96 changes: 95 additions & 1 deletion sb3_contrib/common/utils.py
@@ -1,6 +1,7 @@
from typing import Optional
from typing import Callable, Optional, Sequence

import torch as th
from torch import nn


def quantile_huber_loss(
@@ -67,3 +68,96 @@ def quantile_huber_loss(
    else:
        loss = loss.mean()
    return loss


def conjugate_gradient_solver(
    matrix_vector_dot_fn: Callable[[th.Tensor], th.Tensor],
    b: th.Tensor,
    max_iter: int = 10,
    residual_tol: float = 1e-10,
) -> th.Tensor:
    """
    Finds an approximate solution to a set of linear equations Ax = b

    Sources:
     - https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py
     - https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L122

    Reference:
     - https://epubs.siam.org/doi/abs/10.1137/1.9781611971446.ch6

    :param matrix_vector_dot_fn:
        a function that right multiplies a matrix A by a vector v
    :param b:
        the right hand term in the set of linear equations Ax = b
    :param max_iter:
        the maximum number of iterations (default is 10)
    :param residual_tol:
        residual tolerance for early stopping of the solving (default is 1e-10)
    :return x:
        the approximate solution to the system of equations defined by `matrix_vector_dot_fn`
        and b
    """

    # The vector is not initialized at 0 because of instability issues when the gradient becomes small.
    # Small random Gaussian noise is used for the initialization instead.
    x = 1e-4 * th.randn_like(b)
    residual = b - matrix_vector_dot_fn(x)
    # Equivalent to th.linalg.norm(residual) ** 2 (L2 norm squared)
    residual_squared_norm = th.matmul(residual, residual)

    if residual_squared_norm < residual_tol:
        # If the gradient becomes extremely small,
        # the denominator in alpha will become zero,
        # leading to a division by zero
        return x

    p = residual.clone()

    for i in range(max_iter):
        # A @ p (matrix-vector multiplication)
        A_dot_p = matrix_vector_dot_fn(p)

        alpha = residual_squared_norm / p.dot(A_dot_p)
        x += alpha * p

        if i == max_iter - 1:
            return x

        residual -= alpha * A_dot_p
        new_residual_squared_norm = th.matmul(residual, residual)

        if new_residual_squared_norm < residual_tol:
            return x

        beta = new_residual_squared_norm / residual_squared_norm
        residual_squared_norm = new_residual_squared_norm
        p = residual + beta * p
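
For intuition, the same recursion can be exercised on a tiny symmetric positive-definite system. The sketch below is standalone NumPy; the ``cg_solve`` name and the 2x2 system are illustrative, not part of the library:

```python
import numpy as np

def cg_solve(matvec, b, max_iter=10, residual_tol=1e-10):
    # Mirror of the torch solver above: start from small noise, not zero,
    # to avoid a zero denominator in alpha when the gradient is tiny
    rng = np.random.default_rng(0)
    x = 1e-4 * rng.standard_normal(b.shape)
    residual = b - matvec(x)
    residual_squared_norm = residual @ residual
    if residual_squared_norm < residual_tol:
        return x
    p = residual.copy()
    for i in range(max_iter):
        A_dot_p = matvec(p)
        alpha = residual_squared_norm / (p @ A_dot_p)
        x = x + alpha * p
        if i == max_iter - 1:
            return x
        residual = residual - alpha * A_dot_p
        new_norm = residual @ residual
        if new_norm < residual_tol:
            return x
        p = residual + (new_norm / residual_squared_norm) * p
        residual_squared_norm = new_norm
    return x

A = np.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
b = np.array([1.0, 2.0])
x = cg_solve(lambda v: A @ v, b)
print(np.allclose(A @ x, b, atol=1e-4))  # True: CG recovers the solution
```

For a 2x2 system, conjugate gradient converges in at most two iterations (up to floating-point error), which is why the residual tolerance triggers well before ``max_iter``.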


def flat_grad(
    output,
    parameters: Sequence[nn.parameter.Parameter],
    create_graph: bool = False,
    retain_graph: bool = False,
) -> th.Tensor:
    """
    Returns the gradients of the given output w.r.t. the passed sequence of parameters,
    concatenated into a single flat tensor. Order of parameters is preserved.

    :param output: functional output to compute the gradient for
    :param parameters: sequence of ``Parameter``
    :param create_graph: If ``True``, graph of the derivative will be constructed,
        allowing to compute higher order derivative products. Default: ``False``.
    :param retain_graph: If ``False``, the graph used to compute the grad will be freed.
        Defaults to the value of ``create_graph``.
    :return: Tensor containing the flattened gradients
    """
    grads = th.autograd.grad(
        output,
        parameters,
        create_graph=create_graph,
        retain_graph=retain_graph,
        allow_unused=True,
    )
    return th.cat([th.ravel(grad) for grad in grads if grad is not None])
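
A minimal usage sketch of the flattening logic, inlined so it runs without the package installed (the ``nn.Linear`` model and shapes are illustrative):

```python
import torch as th
from torch import nn

# A tiny model with two parameter tensors: weight (3x2) and bias (3,)
model = nn.Linear(2, 3)
loss = model(th.randn(4, 2)).sum()

# Same logic as flat_grad: per-parameter grads, flattened and concatenated
grads = th.autograd.grad(loss, list(model.parameters()), allow_unused=True)
flat = th.cat([th.ravel(g) for g in grads if g is not None])

print(flat.shape)  # torch.Size([9]): all 6 + 3 = 9 gradient entries in one vector
```

A single flat gradient vector is what the conjugate gradient solver above expects as its right-hand side ``b``.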
2 changes: 2 additions & 0 deletions sb3_contrib/trpo/__init__.py
@@ -0,0 +1,2 @@
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.trpo.trpo import TRPO
16 changes: 16 additions & 0 deletions sb3_contrib/trpo/policies.py
@@ -0,0 +1,16 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for TRPO
from stable_baselines3.common.policies import (
    ActorCriticCnnPolicy,
    ActorCriticPolicy,
    MultiInputActorCriticPolicy,
    register_policy,
)

MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy

register_policy("MlpPolicy", ActorCriticPolicy)
register_policy("CnnPolicy", ActorCriticCnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)