diff --git a/README.md b/README.md
index 3415e4d091..7f928416e9 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,9 @@ These algorithms will make it easier for the research community and industry to
 
 ## Main Features
 
+**The performance of each algorithm was tested** (see the *Results* section on their respective pages);
+you can take a look at issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.
+
 
 | **Features**                | **Stable-Baselines3** |
 | --------------------------- | ----------------------|
diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst
index 7206c7c0e6..6588f90fba 100644
--- a/docs/guide/callbacks.rst
+++ b/docs/guide/callbacks.rst
@@ -15,7 +15,7 @@ To build a custom callback, you need to create a class that derives from ``BaseC
 This will give you access to events (``_on_training_start``, ``_on_step``) and useful variables (like `self.model` for the RL model).
 
-.. You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see :ref:`Examples `), and one for logging additional values with Tensorboard (see :ref:`Tensorboard section `).
+You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see :ref:`Examples `), and one for logging additional values with Tensorboard (see :ref:`Tensorboard section `).
 
 .. code-block:: python
 
diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst
index b490bd7b14..8970cda0a8 100644
--- a/docs/guide/custom_policy.rst
+++ b/docs/guide/custom_policy.rst
@@ -7,6 +7,13 @@ Stable Baselines3 provides policy networks for images (CnnPolicies)
 and other type of input features (MlpPolicies).
 
+.. warning::
+    For A2C and PPO, continuous actions are clipped during training and testing
+    (to avoid out-of-bound errors). SAC, DDPG and TD3 squash the action, using a ``tanh()`` transformation,
+    which handles bounds more correctly.
+
+
+
 Custom Policy Architecture
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst
index 4b702b7454..2e9f2b2384 100644
--- a/docs/guide/examples.rst
+++ b/docs/guide/examples.rst
@@ -19,6 +19,7 @@ notebooks:
 - `RL Baselines zoo`_
 - `PyBullet`_
 - `Hindsight Experience Replay`_
+- `Advanced Saving and Loading`_
 
 .. _Getting Started: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb
 .. _Training, Saving, Loading: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb
@@ -28,6 +29,7 @@ notebooks:
 .. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
 .. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb
 .. _PyBullet: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
+.. _Advanced Saving and Loading: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb
 
 .. |colab| image:: ../_static/img/colab.svg
@@ -417,6 +419,171 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
     obs = env.reset()
 
+
+Advanced Saving and Loading
+---------------------------------
+
+In this example, we show how to use some advanced features of Stable-Baselines3 (SB3):
+how to easily create a test environment to evaluate an agent periodically,
+how to use a policy independently of a model (and how to save and load it), and how to save/load a replay buffer.
+
+By default, the replay buffer is not saved when calling ``model.save()``, in order to save space on the disk (a replay buffer can be up to several GB when using images).
+However, SB3 provides ``save_replay_buffer()`` and ``load_replay_buffer()`` methods to save and load it separately.
+
+
+.. image:: ../_static/img/colab-badge.svg
+   :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb
+
+Stable-Baselines3 also supports the automatic creation of an environment for evaluation.
+For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment while creating the agent.
+Behind the scenes, SB3 uses an :ref:`EvalCallback `.
+
+.. code-block:: python
+
+    from stable_baselines3 import SAC
+    from stable_baselines3.common.evaluation import evaluate_policy
+    from stable_baselines3.sac.policies import MlpPolicy
+
+    # Create the model, the training environment
+    # and the test environment (for evaluation)
+    model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1,
+                learning_rate=1e-3, create_eval_env=True)
+
+    # Evaluate the model every 1000 steps on 5 test episodes
+    # and save the evaluation to the "logs/" folder
+    model.learn(6000, eval_freq=1000, n_eval_episodes=5, eval_log_path="./logs/")
+
+    # save the model
+    model.save("sac_pendulum")
+
+    # the saved model does not contain the replay buffer
+    loaded_model = SAC.load("sac_pendulum")
+    print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
+
+    # now save the replay buffer too
+    model.save_replay_buffer("sac_replay_buffer")
+
+    # load it into the loaded_model
+    loaded_model.load_replay_buffer("sac_replay_buffer")
+
+    # now the loaded replay buffer is not empty anymore
+    print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
+
+    # Save the policy independently from the model
+    # Note: if you don't save the complete model with `model.save()`
+    # you cannot continue training afterward
+    policy = model.policy
+    policy.save("sac_policy_pendulum.pkl")
+
+    # Retrieve the environment
+    env = model.get_env()
+
+    # Evaluate the policy
+    mean_reward, std_reward = evaluate_policy(policy, env, n_eval_episodes=10, deterministic=True)
+
+    print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
+
+    # Load the policy independently from the model
+    saved_policy = MlpPolicy.load("sac_policy_pendulum.pkl")
+
+    # Evaluate the loaded policy
+    mean_reward, std_reward = evaluate_policy(saved_policy, env, n_eval_episodes=10, deterministic=True)
+
+    print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
+
+
+
+Accessing and modifying model parameters
+----------------------------------------
+
+You can access a model's parameters via the ``load_parameters`` and ``get_parameters`` functions,
+or via ``model.policy.state_dict()`` (and ``load_state_dict()``),
+which use dictionaries that map variable names to PyTorch tensors.
+
+These functions are useful when you need to e.g. evaluate a large set of models with the same network structure,
+visualize different layers of the network or modify parameters manually.
+
+Policies also offer a simple way to save/load weights as a NumPy vector, using the ``parameters_to_vector()``
+and ``load_from_vector()`` methods.
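+A minimal sketch of this vector round-trip (the agent and file name below are arbitrary examples):
+
+.. code-block:: python
+
+    import numpy as np
+
+    from stable_baselines3 import A2C
+
+    model = A2C("MlpPolicy", "CartPole-v1")
+
+    # Export all policy weights as a single flat NumPy vector
+    weights = model.policy.parameters_to_vector()
+    np.save("a2c_policy_weights.npy", weights)
+
+    # ... and later load them back into a policy with the same architecture
+    model.policy.load_from_vector(np.load("a2c_policy_weights.npy"))
+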
+The following example demonstrates reading parameters, modifying some of them and loading them back into the model,
+by implementing an `evolution strategy (ES) `_
+for solving the ``CartPole-v1`` environment. The initial guess for the parameters is obtained by running
+A2C policy gradient updates on the model.
+
+.. code-block:: python
+
+    from typing import Dict
+
+    import gym
+    import numpy as np
+    import torch as th
+
+    from stable_baselines3 import A2C
+    from stable_baselines3.common.evaluation import evaluate_policy
+
+
+    def mutate(params: Dict[str, th.Tensor]) -> Dict[str, th.Tensor]:
+        """Mutate parameters by adding normal noise to them"""
+        return dict((name, param + th.randn_like(param)) for name, param in params.items())
+
+
+    # Create policy with a small network
+    model = A2C(
+        "MlpPolicy",
+        "CartPole-v1",
+        ent_coef=0.0,
+        policy_kwargs={"net_arch": [32]},
+        seed=0,
+        learning_rate=0.05,
+    )
+
+    # Use traditional actor-critic policy gradient updates to
+    # find good initial parameters
+    model.learn(total_timesteps=10000)
+
+    # Include only variables with "policy", "action" (policy) or "shared_net" (shared layers)
+    # in their name: only these affect the action.
+    # NOTE: you can retrieve those parameters using model.get_parameters() too
+    mean_params = dict(
+        (key, value)
+        for key, value in model.policy.state_dict().items()
+        if ("policy" in key or "shared_net" in key or "action" in key)
+    )
+
+    # Population size of 50 individuals
+    pop_size = 50
+    # Keep top 10%
+    n_elite = pop_size // 10
+    # Retrieve the environment
+    env = model.get_env()
+
+    for iteration in range(10):
+        # Create population of candidates and evaluate them
+        population = []
+        for population_i in range(pop_size):
+            candidate = mutate(mean_params)
+            # Load new policy parameters to agent.
+            # Tell the function that it should only update parameters
+            # we give it (policy parameters)
+            model.policy.load_state_dict(candidate, strict=False)
+            # Evaluate the candidate
+            fitness, _ = evaluate_policy(model, env)
+            population.append((candidate, fitness))
+        # Take top 10% and use average over their parameters as next mean parameter
+        top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:n_elite]
+        mean_params = dict(
+            (
+                name,
+                th.stack([candidate[0][name] for candidate in top_candidates]).mean(dim=0),
+            )
+            for name in mean_params.keys()
+        )
+        mean_fitness = sum(top_candidate[1] for top_candidate in top_candidates) / n_elite
+        print(f"Iteration {iteration + 1:<3} Mean top fitness: {mean_fitness:.2f}")
+        print(f"Best fitness: {top_candidates[0][1]:.2f}")
+
+
+
 Record a Video
 --------------
diff --git a/docs/guide/export.rst b/docs/guide/export.rst
new file mode 100644
index 0000000000..8be495194b
--- /dev/null
+++ b/docs/guide/export.rst
@@ -0,0 +1,67 @@
+.. _export:
+
+
+Exporting models
+================
+
+After training an agent, you may want to deploy/use it in another language
+or framework, like `tensorflowjs `_.
+Stable Baselines3 does not include tools to export models to other frameworks, but
+this document aims to cover parts that are required for exporting along with
+more detailed stories from users of Stable Baselines3.
+
+
+Background
+----------
+
+In Stable Baselines3, the controller is stored inside policies which convert
+observations into actions. Each learning algorithm (e.g. DQN, A2C, SAC)
+contains a policy object which represents the currently learned behavior,
+accessible via ``model.policy``.
+
+Policies hold enough information to do inference (i.e. predict actions),
+so it is enough to export these policies (cf :ref:`examples `)
+to do inference in another framework.
+
+.. warning::
+    When using CNN policies, the observation is normalized during pre-processing.
+    This pre-processing is done *inside* the policy (dividing by 255 to have values in [0, 1]).
+
+
+Export to ONNX
+-----------------
+
+TODO: help is welcome! A possible starting point is sketched below.
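+
+The snippet below is only an untested sketch, not an officially supported export path: it wraps the deterministic
+sub-networks of a PPO ``MlpPolicy`` (``mlp_extractor``, ``action_net`` and ``value_net``) in a plain PyTorch module,
+skips SB3 observation pre-processing and action clipping, and the class and file names are arbitrary.
+
+.. code-block:: python
+
+    import torch as th
+
+    from stable_baselines3 import PPO
+
+
+    class OnnxablePolicy(th.nn.Module):
+        """Wrap the sub-networks of an A2C/PPO policy into a plain PyTorch module."""
+
+        def __init__(self, extractor, action_net, value_net):
+            super().__init__()
+            self.extractor = extractor
+            self.action_net = action_net
+            self.value_net = value_net
+
+        def forward(self, observation):
+            # Only the mean action and the value estimate are exported here
+            action_hidden, value_hidden = self.extractor(observation)
+            return self.action_net(action_hidden), self.value_net(value_hidden)
+
+
+    # An untrained model is enough to test the export itself
+    model = PPO("MlpPolicy", "Pendulum-v0")
+    model.policy.to("cpu")
+    onnxable_model = OnnxablePolicy(
+        model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
+    )
+
+    # Pendulum-v0 observations have 3 dimensions
+    dummy_input = th.randn(1, 3)
+    th.onnx.export(onnxable_model, dummy_input, "ppo_pendulum.onnx", opset_version=9)
+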
+
+Export to C++
+-----------------
+
+(using PyTorch JIT)
+TODO: help is welcome!
+
+
+Export to tensorflowjs / ONNX-JS
+--------------------------------
+
+TODO: contributors' help is welcome!
+Probably a good starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js
+
+
+
+Manual export
+-------------
+
+You can also manually export the required parameters (weights) and construct the
+network in your desired framework.
+
+You can access the parameters of the model via the agent's
+:func:`get_parameters ` function.
+As policies are also PyTorch modules, you can also access ``model.policy.state_dict()`` directly.
+To find the architecture of the networks for each algorithm, the best way is to check the ``policies.py`` file located
+in their respective folders.
+
+.. note::
+
+    In most cases, we recommend using the PyTorch methods ``state_dict()`` and ``load_state_dict()`` from the policy,
+    unless you need to access the optimizers' state dict too. In that case, you need to call ``get_parameters()``.
diff --git a/docs/guide/migration.rst b/docs/guide/migration.rst
index 0d54ad4c15..3bbee833e9 100644
--- a/docs/guide/migration.rst
+++ b/docs/guide/migration.rst
@@ -59,6 +59,7 @@ Moved Files
 - ``bench/monitor.py`` -> ``common/monitor.py``
 - ``logger.py`` -> ``common/logger.py``
 - ``results_plotter.py`` -> ``common/results_plotter.py``
+- ``common/cmd_util.py`` -> ``common/env_util.py``
 
 Utility functions are no longer exported from ``common`` module, you should import them with their absolute path, e.g.:
 
diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst
index 2226799096..535d04a021 100644
--- a/docs/guide/rl_tips.rst
+++ b/docs/guide/rl_tips.rst
@@ -146,17 +146,17 @@ for continuous actions problems (cf *Bullet* envs).
 
 
 
-.. Goal Environment
-.. -----------------
-..
-.. If your environment follows the ``GoalEnv`` interface (cf `HER <../modules/her.html>`_), then you should use
-.. HER + (SAC/TD3/DDPG/DQN) depending on the action space.
-..
-..
-.. .. note::
-..
-.. The number of workers is an important hyperparameters for experiments with HER
-..
+Goal Environment
+-----------------
+
+If your environment follows the ``GoalEnv`` interface (cf :ref:`HER `), then you should use
+HER + (SAC/TD3/DDPG/DQN) depending on the action space.
+
+
+.. note::
+
+    The number of workers is an important hyperparameter for experiments with HER
+
 
 Tips and Tricks when creating a custom environment
diff --git a/docs/guide/save_format.rst b/docs/guide/save_format.rst
new file mode 100644
index 0000000000..38dc233747
--- /dev/null
+++ b/docs/guide/save_format.rst
@@ -0,0 +1,59 @@
+.. _save_format:
+
+
+On saving and loading
+=====================
+
+Stable Baselines3 (SB3) stores both neural network parameters and algorithm-related parameters such as
+exploration schedule, number of environments and observation/action space. This allows continual learning and easy
+use of trained agents without training, but it is not without its issues. The following describes the format
+used to save agents in SB3 along with its pros and shortcomings.
+
+Terminology used in this page:
+
+- *parameters* refer to neural network parameters (also called "weights"). This is a dictionary
+  mapping variable names to PyTorch tensors.
+- *data* refers to RL algorithm parameters, e.g. learning rate, exploration schedule, action/observation space.
+  These depend on the algorithm used. This is a dictionary mapping class variable names to their values.
+
+
+Zip-archive
+-----------
+
+A zip-archived JSON dump, PyTorch state dictionaries and PyTorch variables. The data dictionary (class parameters)
+is stored as a JSON file, model parameters and optimizers are serialized with the ``torch.save()`` function, and these files
+are stored under a single .zip archive.
+
+Any objects that are not JSON serializable are serialized with cloudpickle and stored as a base64-encoded
+string in the JSON file, along with some information about the serialized object. This allows
+inspecting stored objects without deserializing the object itself.
+
+This format allows skipping elements in the file, i.e. we can skip deserializing objects that are
+broken/non-serializable.
+
+.. This can be done via ``custom_objects`` argument to load functions.
+
+
+File structure:
+
+::
+
+  saved_model.zip/
+  ├── data                        JSON file of class-parameters (dictionary)
+  ├── *.optimizer.pth             PyTorch optimizers serialized
+  ├── policy.pth                  PyTorch state dictionary of the policy saved
+  ├── pytorch_variables.pth       Additional PyTorch variables
+  ├── _stable_baselines3_version  contains the SB3 version with which the model was saved
+
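+As a quick illustration of this layout, a saved archive can be inspected without SB3, using only the
+Python standard library (a small sketch; ``sac_pendulum.zip`` is just an example file name):
+
+.. code-block:: python
+
+    import json
+    import zipfile
+
+    with zipfile.ZipFile("sac_pendulum.zip") as archive:
+        # List the files stored in the archive, e.g. ['data', 'policy.pth', ...]
+        print(archive.namelist())
+        # The "data" entry is plain JSON: print the saved class parameters' names
+        data = json.loads(archive.read("data"))
+        print(list(data.keys()))
+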
+
+Pros:
+
+- More robust to unserializable objects (one bad object does not break everything).
+- Saved files can be inspected/extracted with zip-archive explorers and by other languages.
+
+
+Cons:
+
+- More complex implementation.
+- Still relies partly on cloudpickle for complex objects (e.g. custom functions),
+  which can lead to `incompatibilities `_ between Python versions.
diff --git a/docs/index.rst b/docs/index.rst
index cfbb6fe5b6..c60e5f397a 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -26,6 +26,7 @@ Main Features
 - Tests, high code coverage and type hints
 - Clean code
 - Tensorboard support
+- **The performance of each algorithm was tested** (see the *Results* section on their respective pages)
 
 
 .. toctree::
@@ -48,6 +49,8 @@ Main Features
    guide/migration
    guide/checking_nan
    guide/developer
+   guide/save_format
+   guide/export
 
 
 ..
toctree:: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 554965cd1f..22e6f37769 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -47,6 +47,9 @@ Documentation: - Added first draft of migration guide - Added intro to `imitation `_ library (@shwang) - Enabled doc for ``CnnPolicies`` +- Added advanced saving and loading example +- Added base doc for exporting models +- Added example for getting and setting model parameters Pre-Release 0.9.0 (2020-10-03) diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index 90c5acca20..011eb56e80 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -73,6 +73,72 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments. obs, rewards, dones, info = env.step(action) env.render() + +Results +------- + +Atari Games +^^^^^^^^^^^ + +The complete learning curves are available in the `associated PR #110 `_. + + +PyBullet Environments +^^^^^^^^^^^^^^^^^^^^^ + +Results on the PyBullet benchmark (2M steps) using 6 seeds. +The complete learning curves are available in the `associated issue #48 `_. + + +.. note:: + + Hyperparameters from the `gSDE paper `_ were used (as they are tuned for PyBullet envs). + + +*Gaussian* means that the unstructured Gaussian noise is used for exploration, +*gSDE* (generalized State-Dependent Exploration) is used otherwise. + ++--------------+--------------+--------------+--------------+-------------+ +| Environments | A2C | A2C | PPO | PPO | ++==============+==============+==============+==============+=============+ +| | Gaussian | gSDE | Gaussian | gSDE | ++--------------+--------------+--------------+--------------+-------------+ +| HalfCheetah | 2003 +/- 54 | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 | ++--------------+--------------+--------------+--------------+-------------+ +| Ant | 2286 +/- 72 | 2443 +/- 89 | 2364 +/- 120 | 2782 +/- 76 | ++--------------+--------------+--------------+--------------+-------------+ +| Hopper | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 | ++--------------+--------------+--------------+--------------+-------------+ +| Walker2D | 577 +/- 65 | 839 +/- 56 | 1230 +/- 147 | 2019 +/- 64 | ++--------------+--------------+--------------+--------------+-------------+ + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results (here for PyBullet envs only): + +.. code-block:: bash + + python scripts/all_plots.py -a a2c -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/a2c_results + python scripts/plot_from_file.py -i logs/a2c_results.pkl -latex -l A2C + + Parameters ---------- diff --git a/docs/modules/ddpg.rst b/docs/modules/ddpg.rst index 8add6982a5..c14f5da203 100644 --- a/docs/modules/ddpg.rst +++ b/docs/modules/ddpg.rst @@ -10,12 +10,19 @@ DDPG trick for DQN with the deterministic policy gradient, to obtain an algorithm for continuous actions. +.. note:: + + As ``DDPG`` can be seen as a special case of its successor :ref:`TD3 `, + they share the same policies and same implementation. + + .. rubric:: Available Policies .. 
autosummary::
    :nosignatures:
 
    MlpPolicy
+   CnnPolicy
 
 
 Notes
@@ -25,10 +32,6 @@ Notes
 - DDPG Paper: https://arxiv.org/abs/1509.02971
 - OpenAI Spinning Guide for DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html
 
-.. note::
-
-    The default policy for DDPG uses a ReLU activation, to match the original paper, whereas most other algorithms' MlpPolicy uses a tanh activation.
-    to match the original paper
 
 Can I use?
@@ -81,6 +84,66 @@ Example
     obs, rewards, dones, info = env.step(action)
     env.render()
 
+Results
+-------
+
+PyBullet Environments
+^^^^^^^^^^^^^^^^^^^^^
+
+Results on the PyBullet benchmark (1M steps) using 6 seeds.
+The complete learning curves are available in the `associated issue #48 `_.
+
+
+.. note::
+
+    Hyperparameters of :ref:`TD3 ` from the `gSDE paper `_ were used for ``DDPG``.
+
+
+*Gaussian* means that the unstructured Gaussian noise is used for exploration,
+*gSDE* (generalized State-Dependent Exploration) is used otherwise.
+
++--------------+--------------+--------------+--------------+
+| Environments | DDPG         | TD3          | SAC          |
++==============+==============+==============+==============+
+|              | Gaussian     | Gaussian     | gSDE         |
++--------------+--------------+--------------+--------------+
+| HalfCheetah  | 2272 +/- 69  | 2774 +/- 35  | 2984 +/- 202 |
++--------------+--------------+--------------+--------------+
+| Ant          | 1651 +/- 407 | 3305 +/- 43  | 3102 +/- 37  |
++--------------+--------------+--------------+--------------+
+| Hopper       | 1201 +/- 211 | 2429 +/- 126 | 2262 +/- 1   |
++--------------+--------------+--------------+--------------+
+| Walker2D     | 882 +/- 186  | 2063 +/- 185 | 2136 +/- 67  |
++--------------+--------------+--------------+--------------+
+
+
+
+How to replicate the results?
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Clone the `rl-zoo repo `_:
+
+.. code-block:: bash
+
+  git clone https://github.com/DLR-RM/rl-baselines3-zoo
+  cd rl-baselines3-zoo/
+
+
+Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
+
+.. code-block:: bash
+
+  python train.py --algo ddpg --env $ENV_ID --eval-episodes 10 --eval-freq 10000
+
+
+Plot the results:
+
+.. code-block:: bash
+
+  python scripts/all_plots.py -a ddpg -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ddpg_results
+  python scripts/plot_from_file.py -i logs/ddpg_results.pkl -latex -l DDPG
+
+
 Parameters
 ----------
diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst
index ca9ccca322..388307cbea 100644
--- a/docs/modules/dqn.rst
+++ b/docs/modules/dqn.rst
@@ -6,7 +6,9 @@ DQN
 ===
 
-`Deep Q Network (DQN) `_
+`Deep Q Network (DQN) `_ builds on `Fitted Q-Iteration (FQI) `_
+and makes use of several tricks to stabilize learning with neural networks: it uses a replay buffer, a target network and gradient clipping.
+
 
 .. rubric:: Available Policies
 
@@ -74,6 +76,42 @@ Example
       if done:
         obs = env.reset()
 
+
+Results
+-------
+
+Atari Games
+^^^^^^^^^^^
+
+The complete learning curves are available in the `associated PR #110 `_.
+
+
+How to replicate the results?
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Clone the `rl-zoo repo `_:
+
+.. code-block:: bash
+
+  git clone https://github.com/DLR-RM/rl-baselines3-zoo
+  cd rl-baselines3-zoo/
+
+
+Run the benchmark (replace ``$ENV_ID`` by the env id, for instance ``BreakoutNoFrameskip-v4``):
+
+.. code-block:: bash
+
+  python train.py --algo dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000
+
+
+Plot the results:
+
+..
code-block:: bash + + python scripts/all_plots.py -a dqn -e Pong Breakout -f logs/ -o logs/dqn_results + python scripts/plot_from_file.py -i logs/dqn_results.pkl -latex -l DQN + + Parameters ---------- diff --git a/docs/modules/her.rst b/docs/modules/her.rst index 355b36d496..61a58cde1e 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -89,6 +89,41 @@ Example obs = env.reset() +Results +------- + +This implementation was tested on the `parking env `_ +using 3 seeds. + +The complete learning curves are available in the `associated PR #120 `_. + + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark: + +.. code-block:: bash + + python train.py --algo her --env parking-v0 --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a her -e parking-v0 -f logs/ --no-million + + Parameters ---------- diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index fdcf12196c..091323a256 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -74,6 +74,72 @@ Train a PPO agent on ``Pendulum-v0`` using 4 environments. obs, rewards, dones, info = env.step(action) env.render() + +Results +------- + +Atari Games +^^^^^^^^^^^ + +The complete learning curves are available in the `associated PR #110 `_. + + +PyBullet Environments +^^^^^^^^^^^^^^^^^^^^^ + +Results on the PyBullet benchmark (2M steps) using 6 seeds. +The complete learning curves are available in the `associated issue #48 `_. + + +.. note:: + + Hyperparameters from the `gSDE paper `_ were used (as they are tuned for PyBullet envs). + + +*Gaussian* means that the unstructured Gaussian noise is used for exploration, +*gSDE* (generalized State-Dependent Exploration) is used otherwise. + ++--------------+--------------+--------------+--------------+-------------+ +| Environments | A2C | A2C | PPO | PPO | ++==============+==============+==============+==============+=============+ +| | Gaussian | gSDE | Gaussian | gSDE | ++--------------+--------------+--------------+--------------+-------------+ +| HalfCheetah | 2003 +/- 54 | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 | ++--------------+--------------+--------------+--------------+-------------+ +| Ant | 2286 +/- 72 | 2443 +/- 89 | 2364 +/- 120 | 2782 +/- 76 | ++--------------+--------------+--------------+--------------+-------------+ +| Hopper | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 | ++--------------+--------------+--------------+--------------+-------------+ +| Walker2D | 577 +/- 65 | 839 +/- 56 | 1230 +/- 147 | 2019 +/- 64 | ++--------------+--------------+--------------+--------------+-------------+ + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo ppo --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results (here for PyBullet envs only): + +.. 
code-block:: bash + + python scripts/all_plots.py -a ppo -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ppo_results + python scripts/plot_from_file.py -i logs/ppo_results.pkl -latex -l PPO + + Parameters ---------- diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst index 7b37974c93..bbe6bfc159 100644 --- a/docs/modules/sac.rst +++ b/docs/modules/sac.rst @@ -88,6 +88,66 @@ Example if done: obs = env.reset() + +Results +------- + +PyBullet Environments +^^^^^^^^^^^^^^^^^^^^^ + +Results on the PyBullet benchmark (1M steps) using 3 seeds. +The complete learning curves are available in the `associated issue #48 `_. + + +.. note:: + + Hyperparameters from the `gSDE paper `_ were used (as they are tuned for PyBullet envs). + + +*Gaussian* means that the unstructured Gaussian noise is used for exploration, +*gSDE* (generalized State-Dependent Exploration) is used otherwise. + ++--------------+--------------+--------------+--------------+ +| Environments | SAC | SAC | TD3 | ++==============+==============+==============+==============+ +| | Gaussian | gSDE | Gaussian | ++--------------+--------------+--------------+--------------+ +| HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35 | ++--------------+--------------+--------------+--------------+ +| Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43 | ++--------------+--------------+--------------+--------------+ +| Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126 | ++--------------+--------------+--------------+--------------+ +| Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185 | ++--------------+--------------+--------------+--------------+ + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo sac --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a sac -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/sac_results + python scripts/plot_from_file.py -i logs/sac_results.pkl -latex -l SAC + + Parameters ---------- diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst index fbe6aabd50..2ecf8c9d30 100644 --- a/docs/modules/td3.rst +++ b/docs/modules/td3.rst @@ -8,7 +8,7 @@ TD3 `Twin Delayed DDPG (TD3) `_ Addressing Function Approximation Error in Actor-Critic Methods. -TD3 is a direct successor of DDPG and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing. +TD3 is a direct successor of :ref:`DDPG ` and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing. We recommend reading `OpenAI Spinning guide on TD3 `_ to learn more about those. @@ -18,6 +18,7 @@ We recommend reading `OpenAI Spinning guide on TD3 `_. + + +.. note:: + + Hyperparameters from the `gSDE paper `_ were used (as they are tuned for PyBullet envs). + + +*Gaussian* means that the unstructured Gaussian noise is used for exploration, +*gSDE* (generalized State-Dependent Exploration) is used otherwise. 
+ ++--------------+--------------+--------------+--------------+ +| Environments | SAC | SAC | TD3 | ++==============+==============+==============+==============+ +| | Gaussian | gSDE | Gaussian | ++--------------+--------------+--------------+--------------+ +| HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35 | ++--------------+--------------+--------------+--------------+ +| Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43 | ++--------------+--------------+--------------+--------------+ +| Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126 | ++--------------+--------------+--------------+--------------+ +| Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185 | ++--------------+--------------+--------------+--------------+ + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo td3 --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a td3 -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/td3_results + python scripts/plot_from_file.py -i logs/td3_results.pkl -latex -l TD3 + Parameters ----------