
Commit dd6e361

Implement HER (#120)
* Added working her version, Online sampling is missing.
* Updated test_her.
* Added first version of online her sampling. Still problems with tensor dimensions.
* Reformat
* Fixed tests
* Added some comments.
* Updated changelog.
* Add missing init file
* Fixed some small bugs.
* Reduced arguments for HER, small changes.
* Added getattr. Fixed bug for online sampling.
* Updated save/load functions. Small changes.
* Added her to init.
* Updated save method.
* Updated her ratio.
* Move obs_wrapper
* Added DQN test.
* Fix potential bug
* Offline and online her share same sample_goal function.
* Changed lists into arrays.
* Updated her test.
* Fix online sampling
* Fixed action bug. Updated time limit for episodes.
* Updated convert_dict method to take keys as arguments.
* Renamed obs dict wrapper.
* Seed bit flipping env
* Remove get_episode_dict
* Add fast online sampling version
* Added documentation.
* Vectorized reward computation
* Vectorized goal sampling
* Update time limit for episodes in online her sampling.
* Fix max episode length inference
* Bug fix for Fetch envs
* Fix for HER + gSDE
* Reformat (new black version)
* Added info dict to compute new reward. Check her_replay_buffer again.
* Fix info buffer
* Updated done flag.
* Fixes for gSDE
* Offline her version now uses HerReplayBuffer as episode storage.
* Fix num_timesteps computation
* Fix get torch params
* Vectorized version for offline sampling.
* Modified offline her sampling to use sample method of her_replay_buffer
* Updated HER tests.
* Updated documentation
* Cleanup docstrings
* Updated to review comments
* Fix pytype
* Update according to review comments.
* Removed random goal strategy. Updated sample transitions.
* Updated migration. Removed time signal removal.
* Update doc
* Fix potential load issue
* Add VecNormalize support for dict obs
* Updated saving/loading replay buffer for HER.
* Fix test memory usage
* Fixed save/load replay buffer.
* Fixed save/load replay buffer
* Fixed transition index after loading replay buffer in online sampling
* Better error handling
* Add tests for get_time_limit
* More tests for VecNormalize with dict obs
* Update doc
* Improve HER description
* Add test for sde support
* Add comments
* Add comments
* Remove check that was always valid
* Fix for terminal observation
* Updated buffer size in offline version and reset of HER buffer
* Reformat
* Update doc
* Remove np.empty + add doc
* Fix loading
* Updated loading replay buffer
* Separate online and offline sampling + bug fixes
* Update tensorboard log name
* Version bump
* Bug fix for special case

Co-authored-by: Antonin Raffin <[email protected]>
Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 15e94a6 commit dd6e361

34 files changed: +1899 −102 lines

Makefile

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ pytest:
 	./scripts/run_tests.sh
 
 type:
-	pytype
+	pytype -j auto
 
 lint:
 	# stop the build if there are Python syntax errors or undefined names

README.md

Lines changed: 1 addition & 12 deletions
@@ -35,20 +35,9 @@ These algorithms will make it easier for the research community and industry to
 | Type hints | :heavy_check_mark: |
 
 
-### Roadmap to V1.0
-
-Please look at the issue for more details.
-Planned features:
-
-- [ ] HER
-
 ### Planned features (v1.1+)
 
-- [ ] DQN extensions (prioritized replay, double q-learning, ...)
-- [ ] Support for `Tuple` and `Dict` observation spaces
-- [ ] Recurrent Policies
-- [ ] TRPO
-
+Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones).
 
 ## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)
 

docs/guide/examples.rst

Lines changed: 76 additions & 2 deletions
@@ -18,8 +18,7 @@ notebooks:
 - `Atari Games`_
 - `RL Baselines zoo`_
 - `PyBullet`_
-
-.. - `Hindsight Experience Replay`_
+- `Hindsight Experience Replay`_
 
 .. _Getting Started: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb
 .. _Training, Saving, Loading: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb
@@ -343,6 +342,81 @@ will compute a running average and standard deviation of input features (it can
     env.norm_reward = False
 
 
+Hindsight Experience Replay (HER)
+---------------------------------
+
+For this example, we are using `Highway-Env <https://github.com/eleurent/highway-env>`_ by `@eleurent <https://github.com/eleurent>`_.
+
+
+.. image:: ../_static/img/colab-badge.svg
+   :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
+
+
+.. figure:: https://raw.githubusercontent.com/eleurent/highway-env/gh-media/docs/media/parking-env.gif
+
+    The highway-parking-v0 environment.
+
+The parking env is a goal-conditioned continuous control task, in which the vehicle must park in a given space with the appropriate heading.
+
+.. note::
+
+  The hyperparameters in the following example were optimized for that environment.
+
+
+.. code-block:: python
+
+  import gym
+  import highway_env
+  import numpy as np
+
+  from stable_baselines3 import HER, SAC, DDPG, TD3
+  from stable_baselines3.common.noise import NormalActionNoise
+
+  env = gym.make("parking-v0")
+
+  # Create 4 artificial transitions per real transition
+  n_sampled_goal = 4
+
+  # SAC hyperparams:
+  model = HER(
+      "MlpPolicy",
+      env,
+      SAC,
+      n_sampled_goal=n_sampled_goal,
+      goal_selection_strategy="future",
+      # IMPORTANT: because the env is not wrapped with a TimeLimit wrapper
+      # we have to manually specify the max number of steps per episode
+      max_episode_length=100,
+      verbose=1,
+      buffer_size=int(1e6),
+      learning_rate=1e-3,
+      gamma=0.95,
+      batch_size=256,
+      online_sampling=True,
+      policy_kwargs=dict(net_arch=[256, 256, 256]),
+  )
+
+  model.learn(int(2e5))
+  model.save("her_sac_highway")
+
+  # Load saved model
+  model = HER.load("her_sac_highway", env=env)
+
+  obs = env.reset()
+
+  # Evaluate the agent
+  episode_reward = 0
+  for _ in range(100):
+      action, _ = model.predict(obs, deterministic=True)
+      obs, reward, done, info = env.step(action)
+      env.render()
+      episode_reward += reward
+      if done or info.get("is_success", False):
+          print("Reward:", episode_reward, "Success?", info.get("is_success", False))
+          episode_reward = 0.0
+          obs = env.reset()
+
+
 Record a Video
 --------------

docs/guide/migration.rst

Lines changed: 8 additions & 0 deletions
@@ -163,6 +163,14 @@ Despite this change, no change in performance should be expected.
 To match SB2 behavior, you need to explicitly pass ``deterministic=True``
 
 
+HER
+^^^
+
+The ``HER`` implementation now also supports online sampling of the new goals; this sampling is vectorized.
+The goal selection strategy ``RANDOM`` is no longer supported.
+``HER`` now supports the ``VecNormalize`` wrapper, but only when ``online_sampling=True``.
+For performance reasons, the maximum number of steps per episode must be specified (see the :ref:`HER <her>` documentation).
+
 
 New logger API
 --------------
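
As a minimal sketch of what these migration notes mean in practice (reusing the ``BitFlippingEnv`` example from the new ``docs/modules/her.rst`` added below; the exact hyperparameter values are illustrative only):

.. code-block:: python

  from stable_baselines3 import HER, SAC
  from stable_baselines3.common.bit_flipping_env import BitFlippingEnv

  # Goal-conditioned toy env shipped with SB3 (continuous=True so SAC can be used)
  env = BitFlippingEnv(n_bits=15, continuous=True, max_steps=15)

  model = HER(
      "MlpPolicy",
      env,
      SAC,                               # the wrapped off-policy algorithm
      n_sampled_goal=4,
      goal_selection_strategy="future",  # "random" is no longer accepted
      online_sampling=True,              # also required for VecNormalize support
      max_episode_length=15,             # must be given (or inferable from the env)
      verbose=1,
  )
  model.learn(1000)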

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ Main Features
    modules/a2c
    modules/ddpg
    modules/dqn
+   modules/her
    modules/ppo
    modules/sac
    modules/td3

docs/misc/changelog.rst

Lines changed: 7 additions & 3 deletions
@@ -4,7 +4,7 @@ Changelog
 ==========
 
 
-Pre-Release 0.10.0a0 (WIP)
+Pre-Release 0.10.0a1 (WIP)
 ------------------------------
 
 Breaking Changes:
@@ -14,11 +14,14 @@ Breaking Changes:
 New Features:
 ^^^^^^^^^^^^^
 - Allow custom actor/critic network architectures using ``net_arch=dict(qf=[400, 300], pi=[64, 64])`` for off-policy algorithms (SAC, TD3, DDPG)
+- Added Hindsight Experience Replay ``HER``. (@megan-klaiber)
+- ``VecNormalize`` now supports ``gym.spaces.Dict`` observation spaces
 - Support logging videos to Tensorboard (@SwamyDev)
 
 Bug Fixes:
 ^^^^^^^^^^
 - Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena)
+- Fixed potential issue when loading a different environment
 - Fix ignoring the exclude parameter when recording logs using json, csv or log as logging format (@SwamyDev)
 - Make ``make_vec_env`` support the ``env_kwargs`` argument when using an env ID str (@ManifoldFR)
 - Fix model creation initializing CUDA even when `device="cpu"` is provided
@@ -37,6 +40,7 @@ Others:
 Documentation:
 ^^^^^^^^^^^^^^
 - Added first draft of migration guide
+- Enabled doc for ``CnnPolicies``
 
 
 Pre-Release 0.9.0 (2020-10-03)
@@ -68,6 +72,7 @@ New Features:
 
 Bug Fixes:
 ^^^^^^^^^^
+- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed
 - Fixed a bug where the environment was reset twice when using ``evaluate_policy``
 - Fix logging of ``clip_fraction`` in PPO (@diditforlulz273)
 - Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., ``device="cuda:0"`` (@liorcohen5)
@@ -160,7 +165,6 @@ Documentation:
 - Fixed typo in custom policy doc (@RaphaelWag)
 
 
-
 Pre-Release 0.7.0 (2020-06-10)
 ------------------------------
 
@@ -461,4 +465,4 @@ And all the contributors:
 @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
 @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
-@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88
+@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber

docs/modules/a2c.rst

Lines changed: 20 additions & 1 deletion
@@ -11,7 +11,7 @@ It uses multiple workers to avoid the use of a replay buffer.
 
 
 .. warning::
-
+
   If you find training unstable or want to match performance of stable-baselines A2C, consider using
   ``RMSpropTFLike`` optimizer from ``stable_baselines3.common.sb2_compat.rmsprop_tf_like``.
   You can change optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))``.
@@ -79,3 +79,22 @@ Parameters
 .. autoclass:: A2C
   :members:
   :inherited-members:
+
+
+A2C Policies
+-------------
+
+.. autoclass:: MlpPolicy
+  :members:
+  :inherited-members:
+
+.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy
+  :members:
+  :noindex:
+
+.. autoclass:: CnnPolicy
+  :members:
+
+.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy
+  :members:
+  :noindex:

docs/modules/ddpg.rst

Lines changed: 5 additions & 3 deletions
@@ -98,7 +98,9 @@ DDPG Policies
   :members:
   :inherited-members:
 
+.. autoclass:: stable_baselines3.td3.policies.TD3Policy
+  :members:
+  :noindex:
 
-.. .. autoclass:: CnnPolicy
-..   :members:
-..   :inherited-members:
+.. autoclass:: CnnPolicy
+  :members:

docs/modules/dqn.rst

Lines changed: 4 additions & 0 deletions
@@ -90,5 +90,9 @@ DQN Policies
   :members:
   :inherited-members:
 
+.. autoclass:: stable_baselines3.dqn.policies.DQNPolicy
+  :members:
+  :noindex:
+
 .. autoclass:: CnnPolicy
   :members:

docs/modules/her.rst

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+.. _her:
+
+.. automodule:: stable_baselines3.her
+
+
+HER
+====
+
+`Hindsight Experience Replay (HER) <https://arxiv.org/abs/1707.01495>`_
+
+HER is an algorithm that works with off-policy methods (DQN, SAC, TD3 and DDPG for example).
+HER uses the fact that even if a desired goal was not achieved, another goal may have been achieved during a rollout.
+It creates "virtual" transitions by relabeling transitions (changing the desired goal) from past episodes.
+
+
+.. warning::
+
+  HER requires the environment to inherit from `gym.GoalEnv <https://github.com/openai/gym/blob/3394e245727c1ae6851b504a50ba77c73cd4c65b/gym/core.py#L160>`_
+
+
+.. warning::
+
+  For performance reasons, the maximum number of steps per episode must be specified.
+  In most cases, it will be inferred if you specify ``max_episode_steps`` when registering the environment
+  or if you use a ``gym.wrappers.TimeLimit`` (and ``env.spec`` is not None).
+  Otherwise, you can directly pass ``max_episode_length`` to the model constructor.
+
+
+.. warning::
+
+  ``HER`` supports the ``VecNormalize`` wrapper but only when ``online_sampling=True``
+
+
+Notes
+-----
+
+- Original paper: https://arxiv.org/abs/1707.01495
+- OpenAI paper: `Plappert et al. (2018)`_
+- OpenAI blog post: https://openai.com/blog/ingredients-for-robotics-research/
+
+
+.. _Plappert et al. (2018): https://arxiv.org/abs/1802.09464
+
+Can I use?
+----------
+
+Please refer to the documentation of the model used with HER (DQN, SAC, TD3 or DDPG) for that section.
+
+Example
+-------
+
+.. code-block:: python
+
+  from stable_baselines3 import HER, DDPG, DQN, SAC, TD3
+  from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
+  from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
+  from stable_baselines3.common.vec_env import DummyVecEnv
+  from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper
+
+  model_class = DQN  # works also with SAC, DDPG and TD3
+  N_BITS = 15
+
+  env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
+
+  # Available strategies (cf. paper): future, final, episode
+  goal_selection_strategy = 'future'  # equivalent to GoalSelectionStrategy.FUTURE
+
+  # If True, the HER transitions will be sampled online
+  online_sampling = True
+  # Time limit for the episodes
+  max_episode_length = N_BITS
+
+  # Initialize the model
+  model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, online_sampling=online_sampling,
+              verbose=1, max_episode_length=max_episode_length)
+  # Train the model
+  model.learn(1000)
+
+  model.save("./her_bit_env")
+  model = HER.load('./her_bit_env', env=env)
+
+  obs = env.reset()
+  for _ in range(100):
+      action, _ = model.model.predict(obs, deterministic=True)
+      obs, reward, done, _ = env.step(action)
+
+      if done:
+          obs = env.reset()
+
+
+Parameters
+----------
+
+.. autoclass:: HER
+  :members:
+
+Goal Selection Strategies
+-------------------------
+
+.. autoclass:: GoalSelectionStrategy
+  :members:
+  :inherited-members:
+  :undoc-members:
+
+
+Obs Dict Wrapper
+----------------
+
+.. autoclass:: ObsDictWrapper
+  :members:
+  :inherited-members:
+  :undoc-members:
+
+
+HER Replay Buffer
+-----------------
+
+.. autoclass:: HerReplayBuffer
+  :members:
+  :inherited-members:
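
Since the time-limit warning above is a common stumbling block, here is a minimal sketch of two ways the episode length can be provided, reusing ``BitFlippingEnv`` from the example above (the ``"BitFlippingHer-v0"`` id is purely illustrative and not registered by the library):

.. code-block:: python

  import gym

  from stable_baselines3 import HER, SAC
  from stable_baselines3.common.bit_flipping_env import BitFlippingEnv

  # Option 1: register the env with ``max_episode_steps`` so that ``gym.make()``
  # wraps it in a TimeLimit and the episode length can be inferred from ``env.spec``.
  gym.envs.registration.register(
      id="BitFlippingHer-v0",  # hypothetical id, chosen here for illustration
      entry_point=BitFlippingEnv,
      kwargs=dict(n_bits=15, continuous=True),
      max_episode_steps=15,
  )
  env = gym.make("BitFlippingHer-v0")
  model = HER("MlpPolicy", env, SAC, goal_selection_strategy="future", online_sampling=True)

  # Option 2: no registration and no TimeLimit wrapper,
  # so the limit must be passed explicitly to the constructor.
  env = BitFlippingEnv(n_bits=15, continuous=True, max_steps=15)
  model = HER("MlpPolicy", env, SAC, goal_selection_strategy="future",
              online_sampling=True, max_episode_length=15)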
