Commit 59be198

cypriencaraffin and Antonin Raffin authored
Add Trust Region Policy Optimization (TRPO) (#40)
* Feat: adding TRPO algorithm (WIP)
  WIP - Trust Region Policy Optimization. Currently the Hessian vector product is not working (see inline comments for more detail).
* Feat: adding TRPO algorithm (WIP)
  Adding no_grad block for the line search. Additional assert in the conjugate solver to help debugging.
* Feat: adding TRPO algorithm (WIP)
  - Adding ActorCriticPolicy.get_distribution
  - Using the Distribution object to compute the KL divergence
  - Checking for objective improvement in the line search
  - Moving magic numbers to instance variables
* Feat: adding TRPO algorithm (WIP)
  Improving numerical stability of the conjugate gradient algorithm. Critic updates.
* Feat: adding TRPO algorithm (WIP)
  Changes around the alpha of the line search. Adding TRPO to __init__ files.
* feat: TRPO - addressing PR comments
  - renaming cg_solver to conjugate_gradient_solver and renaming parameter Avp_fun to matrix_vector_dot_func + docstring
  - extra comments + better variable names in trpo.py
  - defining a method for the Hessian vector product instead of an inline function
  - fix registering correct policies for TRPO and using correct policy base in constructor
* refactor: TRPO - policies
  Refactoring sb3_contrib.common.policies to reuse as much code as possible from SB3.
* feat: using updated ActorCriticPolicy from SB3
  get_distribution will be added directly to the SB3 version of ActorCriticPolicy; this commit reflects this.
* Bump version for `get_distribution` support
* Add basic test
* Reformat
* [ci skip] Fix changelog
* fix: setting train mode for trpo
* fix: batch_size type hint in trpo.py
* style: renaming variables + docstring in trpo.py
* Rename + cleanup
* Move grad computation to separate method
* Remove grad norm clipping
* Remove n epochs and add sub-sampling
* Update defaults
* Add Doc
* Add more test and fixes for CNN
* Update doc + add benchmark
* Add tests + update doc
* Fix doc
* Improve names for conjugate gradient
* Update comments
* Update changelog

Co-authored-by: Antonin Raffin <[email protected]>
1 parent b44689b commit 59be198

File tree

19 files changed (+809 / -27 lines)

README.md

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ See documentation for the full list of included features.
 - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
 - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
 - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
+- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)

 **Gym Wrappers**:
 - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)

docs/common/utils.rst

Lines changed: 7 additions & 0 deletions (new file)

@@ -0,0 +1,7 @@
.. _utils:

Utils
=====

.. automodule:: sb3_contrib.common.utils
  :members:

docs/guide/algos.rst

Lines changed: 2 additions & 1 deletion

@@ -9,7 +9,8 @@ along with some useful characteristics: support for discrete/continuous actions,
 Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
 ============ =========== ============ ================= =============== ================
 TQC          ✔️           ❌           ❌                 ❌               ✔️
-QR-DQN       ❌           ✔️           ❌                 ❌               ✔️
+TRPO         ✔️           ✔️           ✔️                 ✔️               ✔️
+QR-DQN       ❌           ✔️           ❌                 ❌               ✔️
 ============ =========== ============ ================= =============== ================

docs/guide/examples.rst

Lines changed: 13 additions & 0 deletions

@@ -44,3 +44,16 @@ Train a PPO with invalid action masking agent on a toy environment.
     model = MaskablePPO("MlpPolicy", env, verbose=1)
     model.learn(5000)
     model.save("qrdqn_cartpole")
+
+TRPO
+----
+
+Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environment.
+
+.. code-block:: python
+
+  from sb3_contrib import TRPO
+
+  model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1)
+  model.learn(total_timesteps=100_000, log_interval=4)
+  model.save("trpo_pendulum")

docs/index.rst

Lines changed: 2 additions & 0 deletions

@@ -32,13 +32,15 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
   :caption: RL Algorithms

   modules/tqc
+  modules/trpo
   modules/qrdqn
   modules/ppo_mask

 .. toctree::
   :maxdepth: 1
   :caption: Common

+  common/utils
   common/wrappers

 .. toctree::

docs/misc/changelog.rst

Lines changed: 6 additions & 6 deletions

@@ -4,8 +4,9 @@ Changelog
 ==========


-Release 1.3.1a6 (WIP)
+Release 1.3.1a7 (WIP)
 -------------------------------
+**Add TRPO**

 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
@@ -15,6 +16,7 @@ Breaking Changes:

 New Features:
 ^^^^^^^^^^^^^
+- Added ``TRPO`` (@cyprienc)
 - Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)

 Bug Fixes:
@@ -34,7 +36,7 @@ Documentation:
 Release 1.3.0 (2021-10-23)
 -------------------------------

-**Invalid action masking for PPO**
+**Add Invalid action masking for PPO**

 .. warning::

@@ -52,6 +54,7 @@ New Features:
 - Added ``MaskablePPO`` algorithm (@kronion)
 - ``MaskablePPO`` Dictionary Observation support (@glmcdona)

+
 Bug Fixes:
 ^^^^^^^^^^

@@ -75,9 +78,6 @@ Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Upgraded to Stable-Baselines3 >= 1.2.0

-New Features:
-^^^^^^^^^^^^^
-
 Bug Fixes:
 ^^^^^^^^^^
 - QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright)
@@ -221,4 +221,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
 Contributors:
 -------------

-@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona
+@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc

docs/modules/trpo.rst

Lines changed: 151 additions & 0 deletions (new file)

@@ -0,0 +1,151 @@
.. _trpo:

.. automodule:: sb3_contrib.trpo

TRPO
====

`Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
is an iterative approach for optimizing policies with guaranteed monotonic improvement.

.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy
    CnnPolicy
    MultiInputPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1502.05477
- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔️      ✔️
Box           ✔️      ✔️
MultiDiscrete ✔️      ✔️
MultiBinary   ✔️      ✔️
Dict          ❌      ✔️
============= ====== ===========


Example
-------

.. code-block:: python

  import gym
  import numpy as np

  from sb3_contrib import TRPO

  env = gym.make("Pendulum-v0")

  model = TRPO("MlpPolicy", env, verbose=1)
  model.learn(total_timesteps=10000, log_interval=4)
  model.save("trpo_pendulum")

  del model  # remove to demonstrate saving and loading

  model = TRPO.load("trpo_pendulum")

  obs = env.reset()
  while True:
      action, _states = model.predict(obs, deterministic=True)
      obs, reward, done, info = env.step(action)
      env.render()
      if done:
          obs = env.reset()


Results
-------

Result on the MuJoCo benchmark (1M steps on ``-v3`` envs with MuJoCo v2.1.0) using 3 seeds.
The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/40>`_.


===================== ============
Environments          TRPO
===================== ============
HalfCheetah           1803 +/- 46
Ant                   3554 +/- 591
Hopper                3372 +/- 215
Walker2d              4502 +/- 234
Swimmer               359 +/- 2
===================== ============


How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone RL-Zoo and checkout the branch ``feat/trpo``:

.. code-block:: bash

  git clone https://github.com/cyprienc/rl-baselines3-zoo
  cd rl-baselines3-zoo/

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

  python train.py --algo trpo --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000

Plot the results:

.. code-block:: bash

  python scripts/all_plots.py -a trpo -e HalfCheetah Ant Hopper Walker2d Swimmer -f logs/ -o logs/trpo_results
  python scripts/plot_from_file.py -i logs/trpo_results.pkl -latex -l TRPO


Parameters
----------

.. autoclass:: TRPO
  :members:
  :inherited-members:

.. _trpo_policies:

TRPO Policies
-------------

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy
  :members:
  :noindex:

.. autoclass:: CnnPolicy
  :members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy
  :members:
  :noindex:

.. autoclass:: MultiInputPolicy
  :members:

.. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy
  :members:
  :noindex:
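For quick reference (this note is not part of the diff), the update described in the TRPO documentation above solves, at each iteration, the constrained problem from the paper, with delta denoting the trust-region size:

    \max_{\theta} \;\; \mathbb{E}_{s, a \sim \pi_{\theta_{\mathrm{old}}}}
        \left[ \frac{\pi_{\theta}(a \mid s)}{\pi_{\theta_{\mathrm{old}}}(a \mid s)} \, A^{\pi_{\theta_{\mathrm{old}}}}(s, a) \right]
    \quad \text{subject to} \quad
    \mathbb{E}_{s} \left[ D_{\mathrm{KL}}\big( \pi_{\theta_{\mathrm{old}}}(\cdot \mid s) \,\|\, \pi_{\theta}(\cdot \mid s) \big) \right] \le \delta

In practice the KL constraint is handled through a quadratic approximation, which is why the commit adds a conjugate gradient solver and gradient-flattening helper in ``sb3_contrib/common/utils.py`` below.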

sb3_contrib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from sb3_contrib.ppo_mask import MaskablePPO
44
from sb3_contrib.qrdqn import QRDQN
55
from sb3_contrib.tqc import TQC
6+
from sb3_contrib.trpo import TRPO
67

78
# Read version from file
89
version_file = os.path.join(os.path.dirname(__file__), "version.txt")

sb3_contrib/common/utils.py

Lines changed: 95 additions & 1 deletion

@@ -1,6 +1,7 @@
-from typing import Optional
+from typing import Callable, Optional, Sequence

 import torch as th
+from torch import nn


 def quantile_huber_loss(
@@ -67,3 +68,96 @@ def quantile_huber_loss(
     else:
         loss = loss.mean()
     return loss
+
+
+def conjugate_gradient_solver(
+    matrix_vector_dot_fn: Callable[[th.Tensor], th.Tensor],
+    b,
+    max_iter=10,
+    residual_tol=1e-10,
+) -> th.Tensor:
+    """
+    Finds an approximate solution to a set of linear equations Ax = b
+
+    Sources:
+     - https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py
+     - https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L122
+
+    Reference:
+     - https://epubs.siam.org/doi/abs/10.1137/1.9781611971446.ch6
+
+    :param matrix_vector_dot_fn:
+        a function that right multiplies a matrix A by a vector v
+    :param b:
+        the right hand term in the set of linear equations Ax = b
+    :param max_iter:
+        the maximum number of iterations (default is 10)
+    :param residual_tol:
+        residual tolerance for early stopping of the solving (default is 1e-10)
+    :return x:
+        the approximate solution to the system of equations defined by `matrix_vector_dot_fn`
+        and b
+    """
+
+    # The vector is not initialized at 0 because of the instability issues when the gradient becomes small.
+    # A small random gaussian noise is used for the initialization.
+    x = 1e-4 * th.randn_like(b)
+    residual = b - matrix_vector_dot_fn(x)
+    # Equivalent to th.linalg.norm(residual) ** 2 (L2 norm squared)
+    residual_squared_norm = th.matmul(residual, residual)
+
+    if residual_squared_norm < residual_tol:
+        # If the gradient becomes extremely small
+        # The denominator in alpha will become zero
+        # Leading to a division by zero
+        return x
+
+    p = residual.clone()
+
+    for i in range(max_iter):
+        # A @ p (matrix vector multiplication)
+        A_dot_p = matrix_vector_dot_fn(p)
+
+        alpha = residual_squared_norm / p.dot(A_dot_p)
+        x += alpha * p
+
+        if i == max_iter - 1:
+            return x
+
+        residual -= alpha * A_dot_p
+        new_residual_squared_norm = th.matmul(residual, residual)
+
+        if new_residual_squared_norm < residual_tol:
+            return x
+
+        beta = new_residual_squared_norm / residual_squared_norm
+        residual_squared_norm = new_residual_squared_norm
+        p = residual + beta * p
+
+
+def flat_grad(
+    output,
+    parameters: Sequence[nn.parameter.Parameter],
+    create_graph: bool = False,
+    retain_graph: bool = False,
+) -> th.Tensor:
+    """
+    Returns the gradients of the passed sequence of parameters into a flat gradient.
+    Order of parameters is preserved.
+
+    :param output: functional output to compute the gradient for
+    :param parameters: sequence of ``Parameter``
+    :param retain_graph: If ``False``, the graph used to compute the grad will be freed.
+        Defaults to the value of ``create_graph``.
+    :param create_graph: If ``True``, graph of the derivative will be constructed,
+        allowing to compute higher order derivative products. Default: ``False``.
+    :return: Tensor containing the flattened gradients
+    """
+    grads = th.autograd.grad(
+        output,
+        parameters,
+        create_graph=create_graph,
+        retain_graph=retain_graph,
+        allow_unused=True,
+    )
+    return th.cat([th.ravel(grad) for grad in grads if grad is not None])
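The two helpers added above are the building blocks of the TRPO update: ``flat_grad`` flattens gradients taken over a sequence of parameters, and ``conjugate_gradient_solver`` approximately solves ``Hx = g`` while only ever touching ``H`` through matrix-vector products. A minimal sketch (not part of the commit) of how they combine into a Hessian-vector product and a search direction; the toy network, the placeholder objectives and the ``damping`` constant are illustrative assumptions, not the commit's TRPO code:

  import torch as th
  from torch import nn

  from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad

  # Toy policy network standing in for the actor (illustrative only)
  policy = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))
  params = list(policy.parameters())
  obs = th.randn(64, 4)

  # Placeholder scalars standing in for the surrogate objective and the KL divergence
  surrogate_objective = policy(obs).mean()
  kl_div = (policy(obs) ** 2).mean()

  # Flattened policy gradient g
  g = flat_grad(surrogate_objective, params)

  # Gradient of the KL term, kept differentiable so it can be differentiated a second time
  kl_grad = flat_grad(kl_div, params, create_graph=True)

  def hessian_vector_product(vector: th.Tensor, damping: float = 0.1) -> th.Tensor:
      # H @ vector via double backprop; damping improves conditioning
      grad_kl_dot_vector = (kl_grad * vector).sum()
      hvp = flat_grad(grad_kl_dot_vector, params, retain_graph=True)
      return hvp + damping * vector

  # Approximately solve H x = g: x is the (unscaled) search direction of a TRPO-style step
  search_direction = conjugate_gradient_solver(hessian_vector_product, g, max_iter=10)
  print(search_direction.shape)  # one entry per flattened policy parameter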

sb3_contrib/trpo/__init__.py

Lines changed: 2 additions & 0 deletions (new file)

@@ -0,0 +1,2 @@
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.trpo.trpo import TRPO
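Several entries in the squashed commit message above concern the backtracking line search (checking for objective improvement, tuning its shrink factor alpha). As a reader aid, here is a minimal, self-contained sketch of that idea; the helper name, the toy one-dimensional problem and every constant are illustrative assumptions, not code from this commit:

  from typing import Callable, Tuple

  import torch as th

  def backtracking_line_search(
      initial_params: th.Tensor,
      step: th.Tensor,
      evaluate: Callable[[th.Tensor], Tuple[float, float]],
      max_kl: float = 0.01,
      alpha: float = 0.8,
      max_backtracks: int = 10,
  ) -> th.Tensor:
      """Shrink the step by ``alpha`` until the objective improves and the KL constraint holds."""
      objective_before, _ = evaluate(initial_params)
      step_size = 1.0
      for _ in range(max_backtracks):
          candidate = initial_params + step_size * step
          objective, kl = evaluate(candidate)
          if objective > objective_before and kl <= max_kl:
              return candidate
          step_size *= alpha
      # No acceptable step found: keep the old parameters
      return initial_params

  # Toy usage: maximize -(x - 3)^2 with a "KL" proxy equal to the squared step length
  params = th.zeros(1)
  direction = th.tensor([4.0])  # deliberately too large so backtracking has to shrink it

  def evaluate(p: th.Tensor) -> Tuple[float, float]:
      objective = -float((p - 3.0) ** 2)
      kl = float(((p - params) ** 2).sum())
      return objective, kl

  new_params = backtracking_line_search(params, direction, evaluate, max_kl=0.5)
  print(new_params)

In TRPO itself the step comes from the conjugate gradient solution sketched earlier, and the constraint is the KL divergence between the old and updated policy distributions.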
