From e134a67d3a70950062c948e80dbc181813e091e5 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 23 Oct 2020 20:15:04 +0200 Subject: [PATCH 01/10] Update doc and add new example --- docs/guide/callbacks.rst | 2 +- docs/guide/examples.rst | 50 ++++++++++++++++++++++++++++++++++++++++ docs/guide/rl_tips.rst | 22 +++++++++--------- docs/misc/changelog.rst | 2 +- 4 files changed, 63 insertions(+), 13 deletions(-) diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst index 7206c7c0e6..6588f90fba 100644 --- a/docs/guide/callbacks.rst +++ b/docs/guide/callbacks.rst @@ -15,7 +15,7 @@ To build a custom callback, you need to create a class that derives from ``BaseC This will give you access to events (``_on_training_start``, ``_on_step``) and useful variables (like `self.model` for the RL model). -.. You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see :ref:`Examples `), and one for logging additional values with Tensorboard (see :ref:`Tensorboard section `). +You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see :ref:`Examples `), and one for logging additional values with Tensorboard (see :ref:`Tensorboard section `). .. code-block:: python diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 4b702b7454..723b8bce5d 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -19,6 +19,7 @@ notebooks: - `RL Baselines zoo`_ - `PyBullet`_ - `Hindsight Experience Replay`_ +- `Advanced Saving and Loading`_ .. _Getting Started: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb .. _Training, Saving, Loading: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb @@ -28,6 +29,7 @@ notebooks: .. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb .. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb .. _PyBullet: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb +.. _Advanced Saving and Loading: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb .. |colab| image:: ../_static/img/colab.svg @@ -417,6 +419,54 @@ The parking env is a goal-conditioned continuous control task, in which the vehi obs = env.reset() +Advanced Saving and Loading +--------------------------------- + +In Stable-Baselines3 (SB3), you can easily create a test environment for periodic evaluation and use a policy independently from a model. + +.. image:: ../_static/img/colab-badge.svg + :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb + +Stable-Baselines3 allows to automatically create an environment for evaluation. +For that, you only to specify ``create_eval_env=True`` when passing the Gym ID of the environment. +Behind the scene, SB3 uses an :ref:`EvalCallback `. + +.. 
code-block:: python + + from stable_baselines3 import SAC + from stable_baselines3.common.evaluation import evaluate_policy + from stable_baselines3.sac.policies import MlpPolicy + + model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1, learning_rate=1e-3, create_eval_env=True) + # Evaluate the model every 1000 steps on 5 test episodes + # and save the evaluation to the logs folder + model.learn(6000, eval_freq=1000, n_eval_episodes=5, eval_log_path="./logs/") + + # Save the policy independently from the model + # Note: if you don't save the complete model, you cannot continue training afterward + policy = model.policy + policy.save("sac_policy_pendulum.pkl") + + # Retrieve the environment + env = model.get_env() + + # Evaluate the policy + mean_reward, std_reward = evaluate_policy(policy, env, n_eval_episodes=10, deterministic=True) + + print(f"mean_reward={mean_reward:.2f} +/- {std_reward}") + + # Load the policy independently from the model + saved_policy = MlpPolicy.load("sac_policy_pendulum") + + # Evaluate the loaded policy + mean_reward, std_reward = evaluate_policy(saved_policy, env, n_eval_episodes=10, deterministic=True) + + print(f"mean_reward={mean_reward:.2f} +/- {std_reward}") + + + + + Record a Video -------------- diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst index 2226799096..535d04a021 100644 --- a/docs/guide/rl_tips.rst +++ b/docs/guide/rl_tips.rst @@ -146,17 +146,17 @@ for continuous actions problems (cf *Bullet* envs). -.. Goal Environment -.. ----------------- -.. -.. If your environment follows the ``GoalEnv`` interface (cf `HER <../modules/her.html>`_), then you should use -.. HER + (SAC/TD3/DDPG/DQN) depending on the action space. -.. -.. -.. .. note:: -.. -.. The number of workers is an important hyperparameters for experiments with HER -.. +Goal Environment +----------------- + +If your environment follows the ``GoalEnv`` interface (cf :ref:`HER `), then you should use +HER + (SAC/TD3/DDPG/DQN) depending on the action space. + + +.. note:: + + The number of workers is an important hyperparameters for experiments with HER + Tips and Tricks when creating a custom environment diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1474856d36..920e2e66c2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -41,7 +41,7 @@ Documentation: ^^^^^^^^^^^^^^ - Added first draft of migration guide - Enabled doc for ``CnnPolicies`` - +- Added advanced saving and loading example Pre-Release 0.9.0 (2020-10-03) ------------------------------ From 57724dd855de58843aac2ec8a922cdf1e8dc06ed Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 24 Oct 2020 16:09:12 +0200 Subject: [PATCH 02/10] Add save/load replay buffer example --- docs/guide/examples.rst | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 723b8bce5d..5d9f638a64 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -422,13 +422,19 @@ The parking env is a goal-conditioned continuous control task, in which the vehi Advanced Saving and Loading --------------------------------- -In Stable-Baselines3 (SB3), you can easily create a test environment for periodic evaluation and use a policy independently from a model. 
+In this example, we show how to use some advanced features of Stable-Baselines3 (SB3): +how to easily create a test environment to evaluate an agent periodically, +use a policy independently from a model (and how to save it, load it) and save/load a replay buffer. + +By default, the replay buffer is not saved when calling ``model.save()``, in order to save space on the disk (a replay buffer can be up to several GB when using images). +However, SB3 provides a ``save_replay_buffer()`` and ``load_replay_buffer()`` method to save it separately. + .. image:: ../_static/img/colab-badge.svg :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb Stable-Baselines3 allows to automatically create an environment for evaluation. -For that, you only to specify ``create_eval_env=True`` when passing the Gym ID of the environment. +For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment. Behind the scene, SB3 uses an :ref:`EvalCallback `. .. code-block:: python @@ -437,11 +443,31 @@ Behind the scene, SB3 uses an :ref:`EvalCallback `. from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.sac.policies import MlpPolicy - model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1, learning_rate=1e-3, create_eval_env=True) + # Create the model, the training environment + # and the test environment (for evaluation) + model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1, + learning_rate=1e-3, create_eval_env=True) + # Evaluate the model every 1000 steps on 5 test episodes - # and save the evaluation to the logs folder + # and save the evaluation to the "logs/" folder model.learn(6000, eval_freq=1000, n_eval_episodes=5, eval_log_path="./logs/") + # save the model + model.save("sac_pendulum") + + # the saved model does not contain the replay buffer + loaded_model = SAC.load("sac_pendulum") + print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer") + + # now save the replay buffer too + model.save_replay_buffer("sac_replay_buffer") + + # load it into the loaded_model + loaded_model.load_replay_buffer("sac_replay_buffer") + + # now the loaded replay is not empty anymore + print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer") + # Save the policy independently from the model # Note: if you don't save the complete model, you cannot continue training afterward policy = model.policy @@ -465,8 +491,6 @@ Behind the scene, SB3 uses an :ref:`EvalCallback `. - - Record a Video -------------- From 4e7155a2a8b3f2a492d816e927837e965cd39fc3 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 24 Oct 2020 16:43:07 +0200 Subject: [PATCH 03/10] Add save format + export doc --- docs/guide/custom_policy.rst | 6 ++++ docs/guide/export.rst | 60 ++++++++++++++++++++++++++++++++++++ docs/guide/save_format.rst | 59 +++++++++++++++++++++++++++++++++++ docs/index.rst | 2 ++ docs/misc/changelog.rst | 1 + 5 files changed, 128 insertions(+) create mode 100644 docs/guide/export.rst create mode 100644 docs/guide/save_format.rst diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index b490bd7b14..2c590007a9 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -7,6 +7,12 @@ Stable Baselines3 provides policy networks for images (CnnPolicies) and other type of input features (MlpPolicies). +.. 
warning:: + For all algorithms (except DDPG, TD3 and SAC), continuous actions are clipped during training and testing + (to avoid out of bound error). + + + Custom Policy Architecture ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/guide/export.rst b/docs/guide/export.rst new file mode 100644 index 0000000000..35a4c40d5e --- /dev/null +++ b/docs/guide/export.rst @@ -0,0 +1,60 @@ +.. _export: + + +Exporting models +================ + +After training an agent, you may want to deploy/use it in an other language +or framework, like `tensorflowjs `_. +Stable Baselines3 does not include tools to export models to other frameworks, but +this document aims to cover parts that are required for exporting along with +more detailed stories from users of Stable Baselines3. + + +Background +---------- + +In Stable Baselines3, the controller is stored inside policies which convert +observations into actions. Each learning algorithm (e.g. DQN, A2C, SAC) contains +one policy, accesible via ``model.policy``. + +Policies hold enough information to do the inference (i.e. predict actions), +so it is enough to export these policies (cf :ref:`examples `) +to do inference in an another framework. + +.. warning:: + When using CNN policies, the observation is normalized during pre-preprocessing (dividing by 255 to have values in [0, 1]) + + +Export to ONNX +----------------- + +TODO: contributors help is welcomed! + + +Export to C++ +----------------- + +(using PyTorch JIT) +TODO: contributors help is welcomed! + + +Export to tensorflowjs / ONNX-JS +-------------------------------- + +TODO: contributors help is welcomed! +Probably a good starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js + + + +Manual export +------------- + +You can also manually export required parameters (weights) and construct the +network in your desired framework. + +You can access parameters of the model via agents' +:func:`get_parameters ` function. +As policies are also PyTorch modules, you also access ``model.policy.state_dict()`` directly. +To find the architecture of the networks for each algorithm, best is to check the ``policies.py`` file located +in their respective folders. diff --git a/docs/guide/save_format.rst b/docs/guide/save_format.rst new file mode 100644 index 0000000000..ed4f399644 --- /dev/null +++ b/docs/guide/save_format.rst @@ -0,0 +1,59 @@ +.. _save_format: + + +On saving and loading +===================== + +Stable Baselines3 (SB3) stores both neural network parameters and algorithm-related parameters such as +exploration schedule, number of environments and observation/action space. This allows continual learning and easy +use of trained agents without training, but it is not without its issues. Following describes the format +used to save agents in SB3 along with its pros and shortcomings. + +Terminology used in this page: + +- *parameters* refer to neural network parameters (also called "weights"). This is a dictionary + mapping variable name to a PyTorch tensor. +- *data* refers to RL algorithm parameters, e.g. learning rate, exploration schedule, action/observation space. + These depend on the algorithm used. This is a dictionary mapping classes variable names their values. + + +Zip-archive +----------- + +A zip-archived JSON dump, PyTorch state dictionnaries and PyTorch variables. 
The data dictionary (class parameters) +is stored as a JSON file, model parameters and optimizers are serialized with ``torch.save()`` function and these files +are stored under a single .zip archive. + +Any objects that are not JSON serializable are serialized with cloudpickle and stored as base64-encoded +string in the JSON file, along with some information that was stored in the serialization. This allows +inspecting stored objects without deserializing the object itself. + +This format allows skipping elements in the file, i.e. we can skip deserializing objects that are +broken/non-serializable. + +.. This can be done via ``custom_objects`` argument to load functions. + + +File structure: + +:: + + saved_model.zip/ + ├── data JSON file of class-parameters (dictionary) + ├── *.optimizer.pth PyTorch optimizers serialized + ├── policy.pth PyTorch state dictionary of the policy saved + ├── pytorch_variables.pth Additional PyTorch variables + ├── _stable_baselines3_version contains the SB3 version with which the model was saved + + +Pros: + +- More robust to unserializable objects (one bad object does not break everything). +- Saved file can be inspected/extracted with zip-archive explorers and by other languages. + + +Cons: + +- More complex implementation. +- Still relies partly on cloudpickle for complex objects (e.g. custom functions) + with can lead to `incompatibilities `_ between Python versions. diff --git a/docs/index.rst b/docs/index.rst index 9b7edea283..fee63a1026 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -47,6 +47,8 @@ Main Features guide/migration guide/checking_nan guide/developer + guide/save_format + guide/export .. toctree:: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 920e2e66c2..9835648640 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -42,6 +42,7 @@ Documentation: - Added first draft of migration guide - Enabled doc for ``CnnPolicies`` - Added advanced saving and loading example +- Added base doc for exporting models Pre-Release 0.9.0 (2020-10-03) ------------------------------ From 844b913955feef0ef8b96f3ea225d8c1c18ef148 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 24 Oct 2020 17:35:25 +0200 Subject: [PATCH 04/10] Add example for get/set parameters --- docs/guide/examples.rst | 92 +++++++++++++++++++++++++++++++++++++++++ docs/misc/changelog.rst | 2 + 2 files changed, 94 insertions(+) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 5d9f638a64..f9ce6a6a4e 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -491,6 +491,98 @@ Behind the scene, SB3 uses an :ref:`EvalCallback `. +Accessing and modifying model parameters +---------------------------------------- + +You can access model's parameters via ``load_parameters`` and ``get_parameters`` functions, +or via ``model.policy.state_dict()`` (and ``load_state_dict()``), +which use dictionaries that map variable names to PyTorch tensors. + +These functions are useful when you need to e.g. evaluate large set of models with same network structure, +visualize different layers of the network or modify parameters manually. + +Policies also offers a simple way to save/load weights as a NumPy vector, using ``parameters_to_vector()`` +and ``load_from_vector()`` method. + +Following example demonstrates reading parameters, modifying some of them and loading them to model +by implementing `evolution strategy (es) `_ +for solving ``CartPole-v1`` environment. 
The initial guess for parameters is obtained by running +A2C policy gradient updates on the model. + +.. code-block:: python + + from typing import Dict + + import gym + import numpy as np + import torch as th + + from stable_baselines3 import A2C + from stable_baselines3.common.evaluation import evaluate_policy + + + def mutate(params: Dict[str, th.Tensor]) -> Dict[str, th.Tensor]: + """Mutate parameters by adding normal noise to them""" + return dict((name, param + th.randn_like(param)) for name, param in params.items()) + + + # Create policy with a small network + model = A2C( + "MlpPolicy", + "CartPole-v1", + ent_coef=0.0, + policy_kwargs={"net_arch": [32]}, + seed=0, + learning_rate=0.05, + ) + + # Use traditional actor-critic policy gradient updates to + # find good initial parameters + model.learn(total_timesteps=10000) + + # Include only variables with "policy", "action" (policy) or "shared_net" (shared layers) + # in their name: only these ones affect the action. + # NOTE: you can retrieve those parameters using model.get_parameters() too + mean_params = dict( + (key, value) + for key, value in model.policy.state_dict().items() + if ("policy" in key or "shared_net" in key or "action" in key) + ) + + # population size of 50 invdiduals + pop_size = 50 + # Keep top 10% + n_elite = pop_size // 10 + # Retrieve the environment + env = model.get_env() + + for iteration in range(10): + # Create population of candidates and evaluate them + population = [] + for population_i in range(pop_size): + candidate = mutate(mean_params) + # Load new policy parameters to agent. + # Tell function that it should only update parameters + # we give it (policy parameters) + model.policy.load_state_dict(candidate, strict=False) + # Evaluate the candidate + fitness, _ = evaluate_policy(model, env) + population.append((candidate, fitness)) + # Take top 10% and use average over their parameters as next mean parameter + top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:n_elite] + mean_params = dict( + ( + name, + th.stack([candidate[0][name] for candidate in top_candidates]).mean(dim=0), + ) + for name in mean_params.keys() + ) + mean_fitness = sum(top_candidate[1] for top_candidate in top_candidates) / n_elite + print(f"Iteration {iteration + 1:<3} Mean top fitness: {mean_fitness:.2f}") + print(f"Best fitness: {top_candidates[0][1]:.2f}") + + + Record a Video -------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9835648640..da898d7f03 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -43,6 +43,8 @@ Documentation: - Enabled doc for ``CnnPolicies`` - Added advanced saving and loading example - Added base doc for exporting models +- Added example for getting and setting model parameters + Pre-Release 0.9.0 (2020-10-03) ------------------------------ From d613f182bc767b24a52c9fff9a07f8851f56f4b8 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 25 Oct 2020 11:54:25 +0100 Subject: [PATCH 05/10] Typos and minor edits --- docs/guide/export.rst | 9 +++++++-- docs/guide/migration.rst | 1 + docs/guide/save_format.rst | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/guide/export.rst b/docs/guide/export.rst index 35a4c40d5e..30622a28a8 100644 --- a/docs/guide/export.rst +++ b/docs/guide/export.rst @@ -16,7 +16,7 @@ Background In Stable Baselines3, the controller is stored inside policies which convert observations into actions. Each learning algorithm (e.g. 
DQN, A2C, SAC) contains -one policy, accesible via ``model.policy``. +one policy, accessible via ``model.policy``. Policies hold enough information to do the inference (i.e. predict actions), so it is enough to export these policies (cf :ref:`examples `) @@ -55,6 +55,11 @@ network in your desired framework. You can access parameters of the model via agents' :func:`get_parameters ` function. -As policies are also PyTorch modules, you also access ``model.policy.state_dict()`` directly. +As policies are also PyTorch modules, you can also access ``model.policy.state_dict()`` directly. To find the architecture of the networks for each algorithm, best is to check the ``policies.py`` file located in their respective folders. + +.. note:: + + In most cases, we recommend using PyTorch methods ``state_dict()`` and ``load_state_dict()`` from the policy, + unless you need to access the optimizers' state dict too. In that case, you need to call ``get_parameters()``. diff --git a/docs/guide/migration.rst b/docs/guide/migration.rst index 0d54ad4c15..3bbee833e9 100644 --- a/docs/guide/migration.rst +++ b/docs/guide/migration.rst @@ -59,6 +59,7 @@ Moved Files - ``bench/monitor.py`` -> ``common/monitor.py`` - ``logger.py`` -> ``common/logger.py`` - ``results_plotter.py`` -> ``common/results_plotter.py`` +- ``common/cmd_util.py`` -> ``common/env_util.py`` Utility functions are no longer exported from ``common`` module, you should import them with their absolute path, e.g.: diff --git a/docs/guide/save_format.rst b/docs/guide/save_format.rst index ed4f399644..caa4f3f806 100644 --- a/docs/guide/save_format.rst +++ b/docs/guide/save_format.rst @@ -20,7 +20,7 @@ Terminology used in this page: Zip-archive ----------- -A zip-archived JSON dump, PyTorch state dictionnaries and PyTorch variables. The data dictionary (class parameters) +A zip-archived JSON dump, PyTorch state dictionaries and PyTorch variables. The data dictionary (class parameters) is stored as a JSON file, model parameters and optimizers are serialized with ``torch.save()`` function and these files are stored under a single .zip archive. From 8778dbc926c6d6018ce6ff035c5da261aac48223 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 25 Oct 2020 12:31:51 +0100 Subject: [PATCH 06/10] Add results sections --- docs/modules/a2c.rst | 66 ++++++++++++++++++++++++++++++++++++++++++++ docs/modules/dqn.rst | 36 ++++++++++++++++++++++++ docs/modules/her.rst | 35 +++++++++++++++++++++++ docs/modules/ppo.rst | 66 ++++++++++++++++++++++++++++++++++++++++++++ docs/modules/sac.rst | 60 ++++++++++++++++++++++++++++++++++++++++ docs/modules/td3.rst | 58 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 321 insertions(+) diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index 90c5acca20..011eb56e80 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -73,6 +73,72 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments. obs, rewards, dones, info = env.step(action) env.render() + +Results +------- + +Atari Games +^^^^^^^^^^^ + +The complete learning curves are available in the `associated PR #110 `_. + + +PyBullet Environments +^^^^^^^^^^^^^^^^^^^^^ + +Results on the PyBullet benchmark (2M steps) using 6 seeds. +The complete learning curves are available in the `associated issue #48 `_. + + +.. note:: + + Hyperparameters from the `gSDE paper `_ were used (as they are tuned for PyBullet envs). 
+ + +*Gaussian* means that the unstructured Gaussian noise is used for exploration, +*gSDE* (generalized State-Dependent Exploration) is used otherwise. + ++--------------+--------------+--------------+--------------+-------------+ +| Environments | A2C | A2C | PPO | PPO | ++==============+==============+==============+==============+=============+ +| | Gaussian | gSDE | Gaussian | gSDE | ++--------------+--------------+--------------+--------------+-------------+ +| HalfCheetah | 2003 +/- 54 | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 | ++--------------+--------------+--------------+--------------+-------------+ +| Ant | 2286 +/- 72 | 2443 +/- 89 | 2364 +/- 120 | 2782 +/- 76 | ++--------------+--------------+--------------+--------------+-------------+ +| Hopper | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 | ++--------------+--------------+--------------+--------------+-------------+ +| Walker2D | 577 +/- 65 | 839 +/- 56 | 1230 +/- 147 | 2019 +/- 64 | ++--------------+--------------+--------------+--------------+-------------+ + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results (here for PyBullet envs only): + +.. code-block:: bash + + python scripts/all_plots.py -a a2c -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/a2c_results + python scripts/plot_from_file.py -i logs/a2c_results.pkl -latex -l A2C + + Parameters ---------- diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index ca9ccca322..e916a6168b 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -74,6 +74,42 @@ Example if done: obs = env.reset() + +Results +------- + +Atari Games +^^^^^^^^^^^ + +The complete learning curves are available in the `associated PR #110 `_. + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the env id, for instance ``BreakoutNoFrameskip-v4``): + +.. code-block:: bash + + python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a dqn -e Pong Breakout -f logs/ -o logs/dqn_results + python scripts/plot_from_file.py -i logs/dqn_results.pkl -latex -l DQN + + Parameters ---------- diff --git a/docs/modules/her.rst b/docs/modules/her.rst index 355b36d496..61a58cde1e 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -89,6 +89,41 @@ Example obs = env.reset() +Results +------- + +This implementation was tested on the `parking env `_ +using 3 seeds. + +The complete learning curves are available in the `associated PR #120 `_. + + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark: + +.. code-block:: bash + + python train.py --algo her --env parking-v0 --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. 
code-block:: bash + + python scripts/all_plots.py -a her -e parking-v0 -f logs/ --no-million + + Parameters ---------- diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index fdcf12196c..091323a256 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -74,6 +74,72 @@ Train a PPO agent on ``Pendulum-v0`` using 4 environments. obs, rewards, dones, info = env.step(action) env.render() + +Results +------- + +Atari Games +^^^^^^^^^^^ + +The complete learning curves are available in the `associated PR #110 `_. + + +PyBullet Environments +^^^^^^^^^^^^^^^^^^^^^ + +Results on the PyBullet benchmark (2M steps) using 6 seeds. +The complete learning curves are available in the `associated issue #48 `_. + + +.. note:: + + Hyperparameters from the `gSDE paper `_ were used (as they are tuned for PyBullet envs). + + +*Gaussian* means that the unstructured Gaussian noise is used for exploration, +*gSDE* (generalized State-Dependent Exploration) is used otherwise. + ++--------------+--------------+--------------+--------------+-------------+ +| Environments | A2C | A2C | PPO | PPO | ++==============+==============+==============+==============+=============+ +| | Gaussian | gSDE | Gaussian | gSDE | ++--------------+--------------+--------------+--------------+-------------+ +| HalfCheetah | 2003 +/- 54 | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 | ++--------------+--------------+--------------+--------------+-------------+ +| Ant | 2286 +/- 72 | 2443 +/- 89 | 2364 +/- 120 | 2782 +/- 76 | ++--------------+--------------+--------------+--------------+-------------+ +| Hopper | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 | ++--------------+--------------+--------------+--------------+-------------+ +| Walker2D | 577 +/- 65 | 839 +/- 56 | 1230 +/- 147 | 2019 +/- 64 | ++--------------+--------------+--------------+--------------+-------------+ + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo ppo --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results (here for PyBullet envs only): + +.. code-block:: bash + + python scripts/all_plots.py -a ppo -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ppo_results + python scripts/plot_from_file.py -i logs/ppo_results.pkl -latex -l PPO + + Parameters ---------- diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst index 7b37974c93..bbe6bfc159 100644 --- a/docs/modules/sac.rst +++ b/docs/modules/sac.rst @@ -88,6 +88,66 @@ Example if done: obs = env.reset() + +Results +------- + +PyBullet Environments +^^^^^^^^^^^^^^^^^^^^^ + +Results on the PyBullet benchmark (1M steps) using 3 seeds. +The complete learning curves are available in the `associated issue #48 `_. + + +.. note:: + + Hyperparameters from the `gSDE paper `_ were used (as they are tuned for PyBullet envs). + + +*Gaussian* means that the unstructured Gaussian noise is used for exploration, +*gSDE* (generalized State-Dependent Exploration) is used otherwise. 
+ ++--------------+--------------+--------------+--------------+ +| Environments | SAC | SAC | TD3 | ++==============+==============+==============+==============+ +| | Gaussian | gSDE | Gaussian | ++--------------+--------------+--------------+--------------+ +| HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35 | ++--------------+--------------+--------------+--------------+ +| Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43 | ++--------------+--------------+--------------+--------------+ +| Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126 | ++--------------+--------------+--------------+--------------+ +| Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185 | ++--------------+--------------+--------------+--------------+ + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo sac --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a sac -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/sac_results + python scripts/plot_from_file.py -i logs/sac_results.pkl -latex -l SAC + + Parameters ---------- diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst index fbe6aabd50..118b2f4b41 100644 --- a/docs/modules/td3.rst +++ b/docs/modules/td3.rst @@ -84,6 +84,64 @@ Example obs, rewards, dones, info = env.step(action) env.render() +Results +------- + +PyBullet Environments +^^^^^^^^^^^^^^^^^^^^^ + +Results on the PyBullet benchmark (1M steps) using 3 seeds. +The complete learning curves are available in the `associated issue #48 `_. + + +.. note:: + + Hyperparameters from the `gSDE paper `_ were used (as they are tuned for PyBullet envs). + + +*Gaussian* means that the unstructured Gaussian noise is used for exploration, +*gSDE* (generalized State-Dependent Exploration) is used otherwise. + ++--------------+--------------+--------------+--------------+ +| Environments | SAC | SAC | TD3 | ++==============+==============+==============+==============+ +| | Gaussian | gSDE | Gaussian | ++--------------+--------------+--------------+--------------+ +| HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35 | ++--------------+--------------+--------------+--------------+ +| Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43 | ++--------------+--------------+--------------+--------------+ +| Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126 | ++--------------+--------------+--------------+--------------+ +| Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185 | ++--------------+--------------+--------------+--------------+ + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo td3 --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. 
code-block:: bash + + python scripts/all_plots.py -a td3 -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/td3_results + python scripts/plot_from_file.py -i logs/td3_results.pkl -latex -l TD3 + Parameters ---------- From 4e7578e3b1193ae0de921ef7d076257f8b09481a Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 25 Oct 2020 12:46:17 +0100 Subject: [PATCH 07/10] Add note about performance --- README.md | 3 +++ docs/index.rst | 1 + docs/modules/ddpg.rst | 17 +++++++++++++---- docs/modules/dqn.rst | 4 +++- docs/modules/td3.rst | 3 ++- 5 files changed, 22 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 3415e4d091..7f928416e9 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,9 @@ These algorithms will make it easier for the research community and industry to ## Main Features +**The performance of each algorithm was tested** (see *Results* section in their respective page), +you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details. + | **Features** | **Stable-Baselines3** | | --------------------------- | ----------------------| diff --git a/docs/index.rst b/docs/index.rst index 6a29092f7f..c60e5f397a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -26,6 +26,7 @@ Main Features - Tests, high code coverage and type hints - Clean code - Tensorboard support +- **The performance of each algorithm was tested** (see *Results* section in their respective page) .. toctree:: diff --git a/docs/modules/ddpg.rst b/docs/modules/ddpg.rst index 8add6982a5..f9d447e07d 100644 --- a/docs/modules/ddpg.rst +++ b/docs/modules/ddpg.rst @@ -10,12 +10,19 @@ DDPG trick for DQN with the deterministic policy gradient, to obtain an algorithm for continuous actions. +.. note:: + + As ``DDPG`` can be seen as a special case of its successor :ref:`TD3 `, + they share the same policies. + + .. rubric:: Available Policies .. autosummary:: :nosignatures: MlpPolicy + CnnPolicy Notes @@ -25,10 +32,6 @@ Notes - DDPG Paper: https://arxiv.org/abs/1509.02971 - OpenAI Spinning Guide for DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html -.. note:: - - The default policy for DDPG uses a ReLU activation, to match the original paper, whereas most other algorithms' MlpPolicy uses a tanh activation. - to match the original paper Can I use? @@ -81,6 +84,12 @@ Example obs, rewards, dones, info = env.step(action) env.render() +Results +------- + +As ``DDPG`` is currently treated as a special case of :ref:`TD3 `, +this implementation can be trusted as TD3 results are macthing the one from the original implementation. + Parameters ---------- diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index e916a6168b..388307cbea 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -6,7 +6,9 @@ DQN === -`Deep Q Network (DQN) `_ +`Deep Q Network (DQN) `_ builds on `Fitted Q-Iteration (FQI) `_ +and make use of different tricks to stabilize the learning with neural networks: it uses a replay buffer, a target network and gradient clipping. + .. rubric:: Available Policies diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst index 118b2f4b41..2ecf8c9d30 100644 --- a/docs/modules/td3.rst +++ b/docs/modules/td3.rst @@ -8,7 +8,7 @@ TD3 `Twin Delayed DDPG (TD3) `_ Addressing Function Approximation Error in Actor-Critic Methods. 
-TD3 is a direct successor of DDPG and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing. +TD3 is a direct successor of :ref:`DDPG ` and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing. We recommend reading `OpenAI Spinning guide on TD3 `_ to learn more about those. @@ -18,6 +18,7 @@ We recommend reading `OpenAI Spinning guide on TD3 Date: Tue, 27 Oct 2020 11:13:12 +0100 Subject: [PATCH 08/10] Add DDPG results --- docs/modules/ddpg.rst | 60 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/docs/modules/ddpg.rst b/docs/modules/ddpg.rst index f9d447e07d..c14f5da203 100644 --- a/docs/modules/ddpg.rst +++ b/docs/modules/ddpg.rst @@ -13,7 +13,7 @@ trick for DQN with the deterministic policy gradient, to obtain an algorithm for .. note:: As ``DDPG`` can be seen as a special case of its successor :ref:`TD3 `, - they share the same policies. + they share the same policies and same implementation. .. rubric:: Available Policies @@ -87,8 +87,62 @@ Example Results ------- -As ``DDPG`` is currently treated as a special case of :ref:`TD3 `, -this implementation can be trusted as TD3 results are macthing the one from the original implementation. +PyBullet Environments +^^^^^^^^^^^^^^^^^^^^^ + +Results on the PyBullet benchmark (1M steps) using 6 seeds. +The complete learning curves are available in the `associated issue #48 `_. + + +.. note:: + + Hyperparameters of :ref:`TD3 ` from the `gSDE paper `_ were used for ``DDPG``. + + +*Gaussian* means that the unstructured Gaussian noise is used for exploration, +*gSDE* (generalized State-Dependent Exploration) is used otherwise. + ++--------------+--------------+--------------+--------------+ +| Environments | DDPG | TD3 | SAC | ++==============+==============+==============+==============+ +| | Gaussian | Gaussian | gSDE | ++--------------+--------------+--------------+--------------+ +| HalfCheetah | 2272 +/- 69 | 2774 +/- 35 | 2984 +/- 202 | ++--------------+--------------+--------------+--------------+ +| Ant | 1651 +/- 407 | 3305 +/- 43 | 3102 +/- 37 | ++--------------+--------------+--------------+--------------+ +| Hopper | 1201 +/- 211 | 2429 +/- 126 | 2262 +/- 1 | ++--------------+--------------+--------------+--------------+ +| Walker2D | 882 +/- 186 | 2063 +/- 185 | 2136 +/- 67 | ++--------------+--------------+--------------+--------------+ + + + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo ddpg --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. 
code-block:: bash + + python scripts/all_plots.py -a ddpg -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ddpg_results + python scripts/plot_from_file.py -i logs/ddpg_results.pkl -latex -l DDPG + Parameters From 47baaca44dd05dd7709af1065ca5db62f9ac2566 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 27 Oct 2020 22:16:12 +0100 Subject: [PATCH 09/10] Address comments --- docs/guide/custom_policy.rst | 5 +++-- docs/guide/examples.rst | 5 +++-- docs/guide/export.rst | 8 +++++--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 2c590007a9..887fea325c 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -8,8 +8,9 @@ and other type of input features (MlpPolicies). .. warning:: - For all algorithms (except DDPG, TD3 and SAC), continuous actions are clipped during training and testing - (to avoid out of bound error). + For A2C and PPO, continuous actions are clipped during training and testing + (to avoid out of bound error). SAC, DDPG and TD3 squash the action, using a ``tanh()`` transformation, + which allows to correctly handle bounds. diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index f9ce6a6a4e..a47c734f6a 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -434,7 +434,7 @@ However, SB3 provides a ``save_replay_buffer()`` and ``load_replay_buffer()`` me :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb Stable-Baselines3 allows to automatically create an environment for evaluation. -For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment. +For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment while creating the agent. Behind the scene, SB3 uses an :ref:`EvalCallback `. .. code-block:: python @@ -469,7 +469,8 @@ Behind the scene, SB3 uses an :ref:`EvalCallback `. print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer") # Save the policy independently from the model - # Note: if you don't save the complete model, you cannot continue training afterward + # Note: if you don't save the complete model with `model.save()` + # you cannot continue training afterward policy = model.policy policy.save("sac_policy_pendulum.pkl") diff --git a/docs/guide/export.rst b/docs/guide/export.rst index 30622a28a8..ccbb99cbe7 100644 --- a/docs/guide/export.rst +++ b/docs/guide/export.rst @@ -15,15 +15,17 @@ Background ---------- In Stable Baselines3, the controller is stored inside policies which convert -observations into actions. Each learning algorithm (e.g. DQN, A2C, SAC) contains -one policy, accessible via ``model.policy``. +observations into actions. Each learning algorithm (e.g. DQN, A2C, SAC) +contains a policy object which represents the currently learned behavior, +accessible via ``model.policy``. Policies hold enough information to do the inference (i.e. predict actions), so it is enough to export these policies (cf :ref:`examples `) to do inference in an another framework. .. warning:: - When using CNN policies, the observation is normalized during pre-preprocessing (dividing by 255 to have values in [0, 1]) + When using CNN policies, the observation is normalized during pre-preprocessing. 
+ This pre-processing is done *inside* the policy (dividing by 255 to have values in [0, 1]) Export to ONNX From f11db607751f659f09036a020364bc8292c06e51 Mon Sep 17 00:00:00 2001 From: "Anssi \"Miffyli\" Kanervisto" Date: Wed, 28 Oct 2020 03:03:25 +0200 Subject: [PATCH 10/10] Fix grammar/wording --- docs/guide/custom_policy.rst | 2 +- docs/guide/examples.rst | 4 ++-- docs/guide/export.rst | 8 ++++---- docs/guide/save_format.rst | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 887fea325c..8970cda0a8 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -10,7 +10,7 @@ and other type of input features (MlpPolicies). .. warning:: For A2C and PPO, continuous actions are clipped during training and testing (to avoid out of bound error). SAC, DDPG and TD3 squash the action, using a ``tanh()`` transformation, - which allows to correctly handle bounds. + which handles bounds more correctly. diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index a47c734f6a..2e9f2b2384 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -433,7 +433,7 @@ However, SB3 provides a ``save_replay_buffer()`` and ``load_replay_buffer()`` me .. image:: ../_static/img/colab-badge.svg :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb -Stable-Baselines3 allows to automatically create an environment for evaluation. +Stable-Baselines3 automatic creation of an environment for evaluation. For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment while creating the agent. Behind the scene, SB3 uses an :ref:`EvalCallback `. @@ -507,7 +507,7 @@ and ``load_from_vector()`` method. Following example demonstrates reading parameters, modifying some of them and loading them to model by implementing `evolution strategy (es) `_ -for solving ``CartPole-v1`` environment. The initial guess for parameters is obtained by running +for solving the ``CartPole-v1`` environment. The initial guess for parameters is obtained by running A2C policy gradient updates on the model. .. code-block:: python diff --git a/docs/guide/export.rst b/docs/guide/export.rst index ccbb99cbe7..8be495194b 100644 --- a/docs/guide/export.rst +++ b/docs/guide/export.rst @@ -4,7 +4,7 @@ Exporting models ================ -After training an agent, you may want to deploy/use it in an other language +After training an agent, you may want to deploy/use it in another language or framework, like `tensorflowjs `_. Stable Baselines3 does not include tools to export models to other frameworks, but this document aims to cover parts that are required for exporting along with @@ -21,7 +21,7 @@ accessible via ``model.policy``. Policies hold enough information to do the inference (i.e. predict actions), so it is enough to export these policies (cf :ref:`examples `) -to do inference in an another framework. +to do inference in another framework. .. warning:: When using CNN policies, the observation is normalized during pre-preprocessing. @@ -31,14 +31,14 @@ to do inference in an another framework. Export to ONNX ----------------- -TODO: contributors help is welcomed! +TODO: help is welcomed! Export to C++ ----------------- (using PyTorch JIT) -TODO: contributors help is welcomed! +TODO: help is welcomed! 
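
Until a full example lands here, below is a minimal, untested sketch of how a trained policy could be traced with TorchScript and then loaded from C++ via ``torch::jit::load``. The ``JitActorWrapper`` helper, the dummy observation shape and the rescaling caveat are illustrative assumptions (based on the SAC actor), not part of the SB3 API, and may need adapting for other algorithms or library versions.

.. code-block:: python

    # Rough sketch (untested): trace a trained SAC actor with TorchScript
    # so it can be loaded from C++ with torch::jit::load.
    # "JitActorWrapper" is a hypothetical helper, not part of the SB3 API.
    import torch as th

    from stable_baselines3 import SAC


    class JitActorWrapper(th.nn.Module):
        def __init__(self, actor: th.nn.Module):
            super().__init__()
            self.actor = actor

        def forward(self, observation: th.Tensor) -> th.Tensor:
            # Deterministic action; for SAC the raw actor output is squashed
            # to [-1, 1] and must be rescaled to the action bounds separately
            return self.actor(observation, deterministic=True)


    # Keep everything on CPU to simplify the export
    model = SAC("MlpPolicy", "Pendulum-v0", device="cpu")
    model.learn(total_timesteps=1000)

    wrapper = JitActorWrapper(model.policy.actor)
    # Dummy observation with the right shape (batch of one)
    dummy_obs = th.randn(1, *model.observation_space.shape)
    traced_actor = th.jit.trace(wrapper, dummy_obs)
    traced_actor.save("sac_actor_traced.pt")

The resulting file can then be loaded from C++ with ``torch::jit::load("sac_actor_traced.pt")`` and run on a ``torch::Tensor`` observation.
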
Export to tensorflowjs / ONNX-JS diff --git a/docs/guide/save_format.rst b/docs/guide/save_format.rst index caa4f3f806..38dc233747 100644 --- a/docs/guide/save_format.rst +++ b/docs/guide/save_format.rst @@ -14,7 +14,7 @@ Terminology used in this page: - *parameters* refer to neural network parameters (also called "weights"). This is a dictionary mapping variable name to a PyTorch tensor. - *data* refers to RL algorithm parameters, e.g. learning rate, exploration schedule, action/observation space. - These depend on the algorithm used. This is a dictionary mapping classes variable names their values. + These depend on the algorithm used. This is a dictionary mapping classes variable names to their values. Zip-archive @@ -49,7 +49,7 @@ File structure: Pros: - More robust to unserializable objects (one bad object does not break everything). -- Saved file can be inspected/extracted with zip-archive explorers and by other languages. +- Saved files can be inspected/extracted with zip-archive explorers and by other languages. Cons:
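
As a closing illustration of the archive layout described above, here is a minimal sketch that inspects a saved model using only the Python standard library, without deserializing it through SB3. The file name ``sac_pendulum.zip`` assumes the model saved in the earlier example (SB3 normally appends the ``.zip`` extension); adjust it to the file actually produced by ``model.save()``.

.. code-block:: python

    # Minimal sketch: peek inside an SB3 save file without loading the model.
    # "sac_pendulum.zip" is assumed to come from model.save("sac_pendulum")
    # in the advanced saving/loading example above.
    import json
    import zipfile

    with zipfile.ZipFile("sac_pendulum.zip") as archive:
        # Entries described in the file structure above:
        # data, policy.pth, *.optimizer.pth, pytorch_variables.pth, ...
        print(archive.namelist())

        # The "data" entry is plain JSON and can be read directly
        with archive.open("data") as file_handler:
            data = json.load(file_handler)
        print(sorted(data.keys()))

This only reads metadata; loading the model for inference still goes through ``SAC.load()`` (or the corresponding algorithm class).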