Commit 652a6d0

araffin, AdamGleave, JacopoPan, jkterry1, Miffyli authored
Refactor HER (#351)
* Start refactoring HER
* Fixes
* Additional fixes
* Faster tests
* WIP: HER as a custom replay buffer
* New replay only version (working with DQN)
* Add support for all off-policy algorithms
* Fix saving/loading
* Remove ObsDictWrapper and add VecNormalize tests with dict
* Stable-Baselines3 v1.0 (#354)
* Bump version and update doc
* Fix name
* Apply suggestions from code review
  Co-authored-by: Adam Gleave <[email protected]>
* Update docs/index.rst
  Co-authored-by: Adam Gleave <[email protected]>
* Update wording for RL zoo
  Co-authored-by: Adam Gleave <[email protected]>
* Add gym-pybullet-drones project (#358)
* Update projects.rst: added gym-pybullet-drones
* Update projects.rst: longer title underline
* Update changelog
  Co-authored-by: Antonin Raffin <[email protected]>
* Include SuperSuit in projects (#359)
* include supersuit
* longer title underline
* Update changelog.rst
* Fix default arguments + add bugbear (#363)
* Fix potential bug + add bug bear
* Remove unused variables
* Minor: version bump
* Add code of conduct + update doc (#373)
* Add code of conduct
* Fix DQN doc example
* Update doc (channel-last/first)
* Apply suggestions from code review
  Co-authored-by: Anssi <[email protected]>
* Apply suggestions from code review
  Co-authored-by: Adam Gleave <[email protected]>
  Co-authored-by: Anssi <[email protected]>
  Co-authored-by: Adam Gleave <[email protected]>
* Make installation command compatible with ZSH (#376)
* Add quotes
* Add Zsh bracket info
* Add clarify pip installation line
* Make note bold
* Add Zsh pip installation note
* Add handle timeouts param
* Fixes
* Fixes (buffer size, extend test)
* Fix `max_episode_length` redefinition
* Fix potential issue
* Add some docs on dict obs
* Fix performance bug
* Fix slowdown
* Add package to install (#378)
* Add package to install
* Update docs packages installation command
  Co-authored-by: Antonin RAFFIN <[email protected]>
* Fix backward compat + add test
* Fix VecEnv detection
* Update doc
* Fix vec env check
* Support for `VecMonitor` for gym3-style environments (#311)
* add vectorized monitor
* auto format of the code
* add documentation and VecExtractDictObs
* refactor and add test cases
* add test cases and format
* avoid circular import and fix doc
* fix type
* fix type
* oops
* Update stable_baselines3/common/monitor.py
  Co-authored-by: Antonin RAFFIN <[email protected]>
* Update stable_baselines3/common/monitor.py
  Co-authored-by: Antonin RAFFIN <[email protected]>
* add test cases
* update changelog
* fix mutable argument
* quick fix
* Apply suggestions from code review
* fix terminal observation for gym3 envs
* delete comment
* Update doc and bump version
* Add warning when already using `Monitor` wrapper
* Update vecmonitor tests
* Fixes
  Co-authored-by: Antonin RAFFIN <[email protected]>
* Reformat
* Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``, it was not optimized anymore (#392)
* Fix ent coef loading bug
* Add test
* Add comment
* Reuse save path
* Add test for GAE + rename `RolloutBuffer.dones` for clarification (#375)
* Fix return computation + add test for GAE
* Rename `last_dones` to `episode_starts` for clarification
* Revert advantage
* Cleanup test
* Rename variable
* Clarify return computation
* Clarify docs
* Add multi-episode rollout test
* Reformat
  Co-authored-by: Anssi "Miffyli" Kanervisto <[email protected]>
* Fixed saving of `A2C` and `PPO` policy when using gSDE (#401)
* Improve doc and replay buffer loading
* Add support for images
* Fix doc
* Update Procgen doc
* Update changelog
* Update docstrings

Co-authored-by: Adam Gleave <[email protected]>
Co-authored-by: Jacopo Panerati <[email protected]>
Co-authored-by: Justin Terry <[email protected]>
Co-authored-by: Anssi <[email protected]>
Co-authored-by: Tom Dörr <[email protected]>
Co-authored-by: Tom Dörr <[email protected]>
Co-authored-by: Costa Huang <[email protected]>
1 parent e945ec1 commit 652a6d0
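
The centerpiece of this commit is the first item above: HER moves from a wrapper-based algorithm to a replay-buffer class that plugs into any off-policy algorithm. A minimal usage sketch of the resulting API, assuming the goal-conditioned ``BitFlippingEnv`` helper shipped with SB3 around this version (keyword arguments reflect this point in the code's history and are best double-checked against the docs):

```python
from stable_baselines3 import DQN, HerReplayBuffer
from stable_baselines3.common.envs import BitFlippingEnv

# Goal-conditioned toy env with a Dict observation space
# (observation / achieved_goal / desired_goal).
env = BitFlippingEnv(n_bits=10, continuous=False, max_steps=10)

# HER is now a replay buffer passed to the algorithm,
# instead of a separate wrapper-based algorithm.
model = DQN(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
        max_episode_length=10,
        online_sampling=True,
    ),
    verbose=1,
)
model.learn(total_timesteps=10_000)
```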

55 files changed, +1786 −1026 lines

CODE_OF_CONDUCT.md

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@

# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.

Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at antonin [dot] raffin [at] dlr [dot] de. All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series of actions.

**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.

README.md

Lines changed: 9 additions & 6 deletions
@@ -37,7 +37,7 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin
 | Type hints | :heavy_check_mark: |

-### Planned features (v1.1+)
+### Planned features

 Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones).

@@ -49,11 +49,13 @@ A migration guide from SB2 to SB3 can be found in the [documentation](https://st
 Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/)

-## RL Baselines3 Zoo: A Collection of Trained RL Agents
+## RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents

-[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo). is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines3.
+[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a training framework for Reinforcement Learning (RL).

-It also provides basic scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.
+It provides scripts for training, evaluating agents, tuning hyperparameters, plotting results and recording videos.
+
+In addition, it includes a collection of tuned hyperparameters for common environments and RL algorithms, and agents trained with those settings.

 Goals of this repository:

@@ -92,6 +94,7 @@ Install the Stable Baselines3 package:
 ```
 pip install stable-baselines3[extra]
 ```
+**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)).

 This includes optional dependencies like Tensorboard, OpenCV or `atari-py` to train on Atari games. If you do not need those, you can use:
 ```

@@ -111,9 +114,9 @@ import gym

 from stable_baselines3 import PPO

-env = gym.make('CartPole-v1')
+env = gym.make("CartPole-v1")

-model = PPO('MlpPolicy', env, verbose=1)
+model = PPO("MlpPolicy", env, verbose=1)
 model.learn(total_timesteps=10000)

 obs = env.reset()
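
The quickstart hunk above stops at ``obs = env.reset()`` only because of the diff context window; this commit does not touch the rest of the example. For reference, the standard SB3 rollout loop that follows in the README looks roughly like this (a sketch, not part of this diff):

```python
# Evaluate the trained model for a fixed number of steps.
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()
```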

docs/README.md

Lines changed: 2 additions & 2 deletions
@@ -6,9 +6,9 @@ This folder contains documentation for the RL baselines.
 ### Build the Documentation

 #### Install Sphinx and Theme
-
+Execute this command in the project root:
 ```
-pip install sphinx sphinx-autobuild sphinx-rtd-theme
+pip install -e .[docs]
 ```

 #### Building the Docs

docs/_static/img/net_arch.png (135 KB)

docs/_static/img/sb3_loop.png (165 KB)

docs/_static/img/sb3_policy.png (176 KB)

docs/guide/callbacks.rst

Lines changed: 5 additions & 0 deletions
@@ -185,6 +185,11 @@ It will save the best model if ``best_model_save_path`` folder is specified and
 You can pass a child callback via the ``callback_on_new_best`` argument. It will be triggered each time there is a new best model.

+.. warning::
+
+  You need to make sure that ``eval_env`` is wrapped the same way as the training environment, for instance using the ``VecTransposeImage`` wrapper if you have a channel-last image as input.
+  The ``EvalCallback`` class outputs a warning if it is not the case.
+

 .. code-block:: python
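
The warning added above is easiest to satisfy by building ``eval_env`` through the exact same wrapper chain as the training env. A minimal sketch, assuming an image-based environment (the env id here is illustrative):

```python
import gym

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage

def make_env():
    # Illustrative: any env with channel-last image observations.
    return gym.make("CarRacing-v0")

# Both envs go through the same wrapper chain, so EvalCallback
# sees observations in the same (transposed) format as training.
train_env = VecTransposeImage(DummyVecEnv([make_env]))
eval_env = VecTransposeImage(DummyVecEnv([make_env]))

model = PPO("CnnPolicy", train_env, verbose=1)
eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/", eval_freq=5000)
model.learn(total_timesteps=50_000, callback=eval_callback)
```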

docs/guide/custom_env.rst

Lines changed: 8 additions & 2 deletions
@@ -13,6 +13,12 @@ That is to say, your environment must implement the following methods (and inher
 channel-first or channel-last.

+.. note::
+
+  Although SB3 supports both channel-last and channel-first images as input, we recommend using the channel-first convention when possible.
+  Under the hood, when a channel-last image is passed, SB3 uses a ``VecTransposeImage`` wrapper to re-order the channels.
+

 .. code-block:: python

@@ -29,9 +35,9 @@ That is to say, your environment must implement the following methods (and inher
         # They must be gym.spaces objects
         # Example when using discrete actions:
         self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
-        # Example for using image as input (can be channel-first or channel-last):
+        # Example for using image as input (channel-first; channel-last also works):
         self.observation_space = spaces.Box(low=0, high=255,
-                                            shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)
+                                            shape=(N_CHANNELS, HEIGHT, WIDTH), dtype=np.uint8)

     def step(self, action):
         ...
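
Putting the note and the snippet together, a complete minimal environment using the recommended channel-first convention could look like the following sketch (shapes and dynamics are placeholders); it can be validated with ``stable_baselines3.common.env_checker.check_env``:

```python
import gym
import numpy as np
from gym import spaces

N_CHANNELS, HEIGHT, WIDTH = 3, 64, 64

class CustomImageEnv(gym.Env):
    """Minimal sketch of a channel-first image env (placeholder dynamics)."""

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(N_CHANNELS, HEIGHT, WIDTH), dtype=np.uint8
        )

    def reset(self):
        # Placeholder: return a random observation.
        return self.observation_space.sample()

    def step(self, action):
        obs = self.observation_space.sample()
        reward, done, info = 0.0, False, {}
        return obs, reward, done, info
```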

docs/guide/custom_policy.rst

Lines changed: 108 additions & 4 deletions
@@ -3,8 +3,8 @@
 Custom Policy Network
 =====================

-Stable Baselines3 provides policy networks for images (CnnPolicies)
-and other type of input features (MlpPolicies).
+Stable Baselines3 provides policy networks for images (CnnPolicies),
+other types of input features (MlpPolicies) and multiple different inputs (MultiInputPolicies).

 .. warning::

@@ -13,9 +13,49 @@ and other type of input features (MlpPolicies).
   which handles bounds more correctly.

+SB3 Policy
+^^^^^^^^^^

-Custom Policy Architecture
-^^^^^^^^^^^^^^^^^^^^^^^^^^
+SB3 networks are separated into two main parts (see figure below):
+
+- A features extractor (usually shared between actor and critic when applicable, to save computation)
+  whose role is to extract features (i.e. convert to a feature vector) from high-dimensional observations, for instance, a CNN that extracts features from images.
+  This is the ``features_extractor_class`` parameter. You can change the default parameters of that features extractor
+  by passing a ``features_extractor_kwargs`` parameter.
+
+- A (fully-connected) network that maps the features to actions/value. Its architecture is controlled by the ``net_arch`` parameter.
+
+.. note::
+
+  All observations are first pre-processed (e.g. images are normalized, discrete obs are converted to one-hot vectors, ...) before being fed to the features extractor.
+  In the case of vector observations, the features extractor is just a ``Flatten`` layer.
+
+.. image:: ../_static/img/net_arch.png
+
+SB3 policies are usually composed of several networks (actor/critic networks + target networks when applicable) together
+with the associated optimizers.
+
+Each of these networks has a features extractor followed by a fully-connected network.
+
+.. note::
+
+  When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology.
+  In SB3, "policy" refers to the class that handles all the networks useful for training,
+  so not only the network used to predict actions (the "learned controller").
+
+.. image:: ../_static/img/sb3_policy.png
+
+.. .. figure:: https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif
+
+Custom Network Architecture
+^^^^^^^^^^^^^^^^^^^^^^^^^^^

 One way of customising the policy network architecture is to pass arguments when creating the model,
 using ``policy_kwargs`` parameter:

@@ -109,6 +149,70 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
     model.learn(1000)

+Multiple Inputs and Dictionary Observations
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Stable Baselines3 supports handling of multiple inputs by using a ``Dict`` Gym space. This can be done using
+``MultiInputPolicy``, which by default uses the ``CombinedExtractor`` feature extractor to turn multiple
+inputs into a single vector, handled by the ``net_arch`` network.
+
+By default, ``CombinedExtractor`` processes multiple inputs as follows:
+
+1. If the input is an image (automatically detected, see ``common.preprocessing.is_image_space``), process it with the Nature Atari CNN and
+   output a latent vector of size ``64``.
+2. If the input is not an image, flatten it (no layers).
+3. Concatenate all previous vectors into one long vector and pass it to the policy.
+
+Much like above, you can define custom feature extractors. The following example assumes the environment has two keys in the
+observation space dictionary: "image" is a (1,H,W) image, and "vector" is a (D,)-dimensional vector. We process "image" with a simple
+downsampling and "vector" with a single linear layer.
+
+.. code-block:: python
+
+  import gym
+  import torch as th
+  from torch import nn
+
+  from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
+
+  class CustomCombinedExtractor(BaseFeaturesExtractor):
+      def __init__(self, observation_space: gym.spaces.Dict):
+          # We do not know features-dim here before going over all the items,
+          # so put something dummy for now. PyTorch requires calling
+          # nn.Module.__init__ before adding modules.
+          super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1)
+
+          extractors = {}
+
+          total_concat_size = 0
+          # We need to know the size of the output of this extractor,
+          # so go over all the spaces and compute the output feature sizes.
+          for key, subspace in observation_space.spaces.items():
+              if key == "image":
+                  # We will just downsample one channel of the image by 4x4 and flatten.
+                  # Assume the image is single-channel (subspace.shape[0] == 1).
+                  extractors[key] = nn.Sequential(nn.MaxPool2d(4), nn.Flatten())
+                  total_concat_size += (subspace.shape[1] // 4) * (subspace.shape[2] // 4)
+              elif key == "vector":
+                  # Run through a single linear layer.
+                  extractors[key] = nn.Linear(subspace.shape[0], 16)
+                  total_concat_size += 16
+
+          self.extractors = nn.ModuleDict(extractors)
+
+          # Update the features dim manually
+          self._features_dim = total_concat_size
+
+      def forward(self, observations) -> th.Tensor:
+          encoded_tensor_list = []
+
+          # self.extractors contains nn.Modules that do all the processing.
+          for key, extractor in self.extractors.items():
+              encoded_tensor_list.append(extractor(observations[key]))
+          # Return a (B, self._features_dim) PyTorch tensor, where B is the batch dimension.
+          return th.cat(encoded_tensor_list, dim=1)
+

 On-Policy Algorithms
 ^^^^^^^^^^^^^^^^^^^^
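
Once defined, ``CustomCombinedExtractor`` is wired in through ``policy_kwargs`` like any other features extractor. A short usage sketch (``env`` is assumed to expose a ``Dict`` observation space with "image" and "vector" keys matching the extractor above):

```python
from stable_baselines3 import PPO

# Plug the custom extractor into the policy via policy_kwargs.
policy_kwargs = dict(features_extractor_class=CustomCombinedExtractor)
model = PPO("MultiInputPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10_000)
```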

docs/guide/developer.rst

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,9 @@ Each algorithm has two main methods:
 - ``.train()`` which updates the parameters using samples from the buffer

+.. image:: ../_static/img/sb3_loop.png
+

 Where to start?
 ===============
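
The figure added above (``sb3_loop.png``) illustrates how collecting experience and training alternate. As a rough schematic only, not the actual SB3 source, the loop can be pictured as:

```python
# Rough schematic of the learn() loop described above; the real logic
# lives in the on/off-policy base classes and has a richer signature.
def learn(self, total_timesteps):
    while self.num_timesteps < total_timesteps:
        # 1. Interact with the environment and store transitions
        #    in the rollout/replay buffer (advances self.num_timesteps)
        self.collect_rollouts(self.env)
        # 2. Update the network parameters using samples from the buffer
        self.train()
```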
