Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f779a9f
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 7, 2021
98bc5b2
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 9, 2021
97ece67
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 17, 2021
799b140
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 17, 2021
dc73462
Feat: adding TRPO algorithm (WIP)
cyprienc Aug 19, 2021
9b8a222
feat: TRPO - addressing PR comments
cyprienc Sep 11, 2021
869dce9
refactor: TRPO - policier
cyprienc Sep 11, 2021
347dcc0
feat: using updated ActorCriticPolicy from SB3
cyprienc Sep 11, 2021
35d7256
Bump version for `get_distribution` support
araffin Sep 13, 2021
9cfcb54
Add basic test
araffin Sep 13, 2021
974174a
Reformat
araffin Sep 13, 2021
b6bd449
[ci skip] Fix changelog
araffin Sep 13, 2021
c88951c
fix: setting train mode for trpo
cyprienc Sep 13, 2021
1f7e99d
fix: batch_size type hint in trpo.py
cyprienc Sep 13, 2021
6540371
style: renaming variables + docstring in trpo.py
cyprienc Sep 15, 2021
3a26c05
Merge branch 'master' into master
araffin Sep 23, 2021
f003e88
Merge branch 'master' into master
araffin Sep 27, 2021
a33409e
Merge branch 'master' into master
araffin Sep 29, 2021
8ecf40e
Rename + cleanup
araffin Sep 29, 2021
45f4ea6
Move grad computation to separate method
araffin Sep 29, 2021
cc4b5ab
Remove grad norm clipping
araffin Sep 29, 2021
fc7a6c7
Remove n epochs and add sub-sampling
araffin Sep 29, 2021
66723ff
Update defaults
araffin Sep 29, 2021
63a263f
Merge branch 'master' into master
araffin Dec 1, 2021
bf583de
Merge branch 'master' into cyprienc/master
araffin Dec 10, 2021
e983348
Add Doc
araffin Dec 27, 2021
439d79b
Add more test and fixes for CNN
araffin Dec 27, 2021
d9483dc
Update doc + add benchmark
araffin Dec 28, 2021
fff84e4
Add tests + update doc
araffin Dec 28, 2021
95dddf4
Fix doc
araffin Dec 28, 2021
661fe15
Improve names for conjugate gradient
araffin Dec 29, 2021
a24e7c0
Update comments
araffin Dec 29, 2021
342fe53
Update changelog
araffin Dec 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
from sb3_contrib.trpo import TRPO


# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
Expand Down
28 changes: 28 additions & 0 deletions sb3_contrib/common/policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Policies: abstract base class and concrete implementations."""

from stable_baselines3.common.distributions import Distribution
from stable_baselines3.common.policies import ActorCriticCnnPolicy as _ActorCriticCnnPolicy
from stable_baselines3.common.policies import ActorCriticPolicy as _ActorCriticPolicy
from stable_baselines3.common.policies import MultiInputActorCriticPolicy as _MultiInputActorCriticPolicy


class ActorCriticPolicy(_ActorCriticPolicy):
    """
    Policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    Thin extension of SB3's ``ActorCriticPolicy`` that exposes the last
    action distribution via :meth:`get_distribution`.
    """

    def get_distribution(self) -> Distribution:
        """
        Get the current action distribution

        :return: Action distribution
        """
        # NOTE(review): ``action_dist`` is an attribute of the SB3 parent class;
        # presumably it is refreshed by the parent's forward/evaluate calls, so the
        # value returned here reflects the most recent forward pass and may be
        # stale before any pass has run — confirm against the pinned SB3 version.
        return self.action_dist


# Propagate get_distribution to the CNN policy.
# Inheriting only from our (MLP) ActorCriticPolicy would drop the CNN-specific
# defaults of SB3's ActorCriticCnnPolicy (its features extractor); inheriting
# from both keeps those defaults while adding get_distribution via the MRO.
class ActorCriticCnnPolicy(ActorCriticPolicy, _ActorCriticCnnPolicy):
    """
    CNN policy class for actor-critic algorithms
    (has both policy and value prediction), with ``get_distribution`` support.
    """


# Propagate get_distribution to the multi-input (Dict observation) policy.
# Inheriting only from our (MLP) ActorCriticPolicy would drop the defaults of
# SB3's MultiInputActorCriticPolicy (its combined features extractor);
# inheriting from both keeps those defaults while adding get_distribution.
class MultiInputActorCriticPolicy(ActorCriticPolicy, _MultiInputActorCriticPolicy):
    """
    Multi-input (Dict observation space) policy class for actor-critic
    algorithms, with ``get_distribution`` support.
    """
91 changes: 90 additions & 1 deletion sb3_contrib/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional
from typing import Optional, Sequence, Callable

import torch as th
from torch import nn


def quantile_huber_loss(
Expand Down Expand Up @@ -67,3 +68,91 @@ def quantile_huber_loss(
else:
loss = loss.mean()
return loss


# TODO: write regression tests
def conjugate_gradient_solver(
    matrix_vector_dot_func: Callable[[th.Tensor], th.Tensor],
    b: th.Tensor,
    max_iter: int = 10,
    residual_tol: float = 1e-10,
) -> th.Tensor:
    """
    Finds an approximate solution to a set of linear equations Ax = b

    Source: https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py

    :param matrix_vector_dot_func:
        a function that right multiplies a matrix A by a vector v
    :param b:
        the right hand term in the set of linear equations Ax = b
        (expected to be a flat 1D tensor, since the residual dot products
        use ``th.matmul(r, r)``)
    :param max_iter:
        the maximum number of iterations (default is 10)
    :param residual_tol:
        residual tolerance for early stopping of the solving (default is 1e-10)
    :return x:
        the approximate solution to the system of equations defined by
        ``matrix_vector_dot_func`` and b
    """

    # The vector is not initialized at 0 because of the instability issues when the gradient becomes small.
    # A small random gaussian noise is used for the initialization.
    x = 1e-4 * th.randn_like(b)
    r = b - matrix_vector_dot_func(x)
    r_dot = th.matmul(r, r)

    if r_dot < residual_tol:
        # If the gradient becomes extremely small
        # The denominator in alpha will become zero
        # Leading to a division by zero
        return x

    p = r.clone()

    for i in range(max_iter):
        Avp = matrix_vector_dot_func(p)

        alpha = r_dot / p.dot(Avp)
        x += alpha * p

        # Last iteration: skip the (now useless) residual update
        if i == max_iter - 1:
            return x

        r -= alpha * Avp
        new_r_dot = th.matmul(r, r)

        if new_r_dot < residual_tol:
            return x

        beta = new_r_dot / r_dot
        r_dot = new_r_dot
        p = r + beta * p

    # Only reached when max_iter <= 0: return the (noisy) initial guess
    # instead of implicitly returning None, which would crash callers.
    return x


# TODO: test
def flat_grad(
    output,
    parameters: Sequence[nn.parameter.Parameter],
    create_graph: bool = False,
    retain_graph: bool = False,
) -> th.Tensor:
    """
    Compute the gradients of ``output`` w.r.t. ``parameters`` and
    concatenate them into one flat tensor.
    Order of parameters is preserved.

    :param output: functional output to compute the gradient for
    :param parameters: sequence of ``Parameter``
    :param create_graph: If ``True``, graph of the derivative will be constructed,
        allowing to compute higher order derivative products. Default: ``False``.
    :param retain_graph: If ``False``, the graph used to compute the grad will be freed.
        Defaults to the value of ``create_graph``.
    :return: Tensor containing the flattened gradients
    """
    all_grads = th.autograd.grad(
        output,
        parameters,
        create_graph=create_graph,
        retain_graph=retain_graph,
        allow_unused=True,
    )
    # Parameters unused in the computation graph yield None gradients;
    # drop those before flattening.
    flattened = [grad.reshape(-1) for grad in all_grads if grad is not None]
    return th.cat(flattened)
2 changes: 2 additions & 0 deletions sb3_contrib/trpo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.trpo.trpo import TRPO
13 changes: 13 additions & 0 deletions sb3_contrib/trpo/policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for TRPO
from sb3_contrib.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.policies import register_policy


# Aliases following the SB3 naming convention for policy classes
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy

# Register the policies under their string names so algorithms can be
# instantiated with e.g. TRPO("MlpPolicy", ...); use the aliases
# consistently (they are the same objects as the classes above).
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)
Loading