Merged
9 changes: 6 additions & 3 deletions docs/misc/changelog.rst
@@ -4,11 +4,12 @@ Changelog
==========


-Release 1.1.0a3 (WIP)
+Release 1.1.0a4 (WIP)
---------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
+- Renamed ``_last_dones`` and ``dones`` to ``_last_episode_starts`` and ``episode_starts`` in ``RolloutBuffer`` (see the migration sketch below).

New Features:
^^^^^^^^^^^^^
@@ -30,15 +31,17 @@ Others:
^^^^^^^
- Added ``flake8-bugbear`` to tests dependencies to find likely bugs
- Added Code of Conduct
+- Added tests for GAE and lambda return computation

Documentation:
^^^^^^^^^^^^^^
- Added gym pybullet drones project (@JacopoPan)
- Added link to SuperSuit in projects (@justinkterry)
- Fixed DQN example (thanks @ltbd78)
-- Clarify channel-first/channel-last recommendation
+- Clarified channel-first/channel-last recommendation
- Update sphinx environment installation instructions (@tom-doerr)
-- Clarify pip installation in Zsh (@tom-doerr)
+- Clarified pip installation in Zsh (@tom-doerr)
+- Clarified return computation for on-policy algorithms (TD(lambda) estimate was used)
+- Added example for using ``ProcgenEnv``


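For downstream users, the renaming of ``_last_dones``/``dones`` is the main breaking change in this release. A minimal migration sketch in Python, assuming a custom callback that reads the rollout buffer directly (the callback is hypothetical and not part of this PR):

    from stable_baselines3.common.callbacks import BaseCallback

    class InspectBufferCallback(BaseCallback):
        """Hypothetical user callback reading rollout-buffer attributes."""

        def _on_rollout_end(self) -> None:
            buffer = self.model.rollout_buffer
            # Before this PR: the flags were stored as ``buffer.dones``
            # After this PR: the array is ``buffer.episode_starts`` and holds
            # start-of-episode flags (the dones shifted by one timestep)
            print("episode starts in rollout:", buffer.episode_starts.sum())

        def _on_step(self) -> bool:
            return True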
4 changes: 2 additions & 2 deletions stable_baselines3/common/base_class.py
@@ -130,7 +130,7 @@ def __init__(
self.tensorboard_log = tensorboard_log
self.lr_schedule = None # type: Optional[Schedule]
self._last_obs = None # type: Optional[np.ndarray]
-        self._last_dones = None  # type: Optional[np.ndarray]
+        self._last_episode_starts = None  # type: Optional[np.ndarray]
# When using VecNormalize:
self._last_original_obs = None # type: Optional[np.ndarray]
self._episode_num = 0
@@ -377,7 +377,7 @@ def _setup_learn(
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset()
-            self._last_dones = np.zeros((self.env.num_envs,), dtype=bool)
+            self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs()
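A note on the flipped initialization above: under the new semantics, the first observation returned by ``env.reset()`` is by definition the start of an episode, so the flag is True for every env, whereas the old ``_last_dones`` started as False because nothing had terminated yet. A minimal sketch of the two conventions:

    import numpy as np

    num_envs = 4
    # Old convention: no episode has ended right after reset
    last_dones = np.zeros((num_envs,), dtype=bool)
    # New convention: every env is at the first step of a fresh episode
    last_episode_starts = np.ones((num_envs,), dtype=bool)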
37 changes: 25 additions & 12 deletions stable_baselines3/common/buffers.py
@@ -294,7 +294,7 @@ def __init__(
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
-        self.returns, self.dones, self.values, self.log_probs = None, None, None, None
+        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
self.generator_ready = False
self.reset()

@@ -303,7 +303,7 @@ def reset(self) -> None:
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
-        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -312,20 +312,25 @@ def reset(self) -> None:

def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
"""
-        Post-processing step: compute the returns (sum of discounted rewards)
-        and GAE advantage.
-        Adapted from Stable-Baselines PPO2.
+        Post-processing step: compute the lambda-return (TD(lambda) estimate)
+        and GAE(lambda) advantage.

Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
to compute the advantage. To obtain vanilla advantage (A(s) = R - V(S))
where R is the discounted reward with value bootstrap,
set ``gae_lambda=1.0`` during initialization.

-        :param last_values:
-        :param dones:
+        The TD(lambda) estimator also has two special cases:
+        - TD(1) is the Monte-Carlo estimate (sum of discounted rewards)
+        - TD(0) is the one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))
+
+        For more information, see the discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.
+
+        :param last_values: state value estimation for the last step (one for each env)
+        :param dones: if the last step was a terminal step (one bool for each env).

"""
-        # convert to numpy
+        # Convert to numpy
last_values = last_values.clone().cpu().numpy().flatten()

last_gae_lam = 0
@@ -334,21 +334,29 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
next_non_terminal = 1.0 - dones
next_values = last_values
else:
-                next_non_terminal = 1.0 - self.dones[step + 1]
+                next_non_terminal = 1.0 - self.episode_starts[step + 1]
next_values = self.values[step + 1]
delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
self.advantages[step] = last_gae_lam
+        # TD(lambda) estimator, see GitHub PR #375 or "Telescoping in TD(lambda)"
+        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
self.returns = self.advantages + self.values

def add(
-        self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray, value: th.Tensor, log_prob: th.Tensor
+        self,
+        obs: np.ndarray,
+        action: np.ndarray,
+        reward: np.ndarray,
+        episode_start: np.ndarray,
+        value: th.Tensor,
+        log_prob: th.Tensor,
) -> None:
"""
:param obs: Observation
:param action: Action
:param reward:
-        :param done: End of episode signal.
+        :param episode_start: Start of episode signal.
:param value: estimated value of the current state
following the current policy.
:param log_prob: log probability of the action
Expand All @@ -366,7 +379,7 @@ def add(
self.observations[self.pos] = np.array(obs).copy()
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
-        self.dones[self.pos] = np.array(done).copy()
+        self.episode_starts[self.pos] = np.array(episode_start).copy()
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.pos += 1
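The backward recursion above is easy to replay outside the buffer. A standalone numpy sketch of the same computation (mirroring ``compute_returns_and_advantage``; the function name and toy numbers are ours, not part of the PR):

    import numpy as np

    def gae_sketch(rewards, values, episode_starts, last_values, dones, gamma=0.99, gae_lambda=0.95):
        # rewards, values, episode_starts: (n_steps, n_envs) arrays
        # last_values, dones: (n_envs,) arrays for the step after the rollout
        n_steps = rewards.shape[0]
        advantages = np.zeros_like(rewards)
        last_gae_lam = 0.0
        for step in reversed(range(n_steps)):
            if step == n_steps - 1:
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                # episode_starts[step + 1] == 1 means step was terminal,
                # so we must not bootstrap across the episode boundary
                next_non_terminal = 1.0 - episode_starts[step + 1]
                next_values = values[step + 1]
            delta = rewards[step] + gamma * next_values * next_non_terminal - values[step]
            last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
            advantages[step] = last_gae_lam
        # TD(lambda) returns: the regression target for the value function
        returns = advantages + values
        return advantages, returns

    # Toy rollout: 5 steps, 1 env, episode boundary after step 2
    rewards = np.array([[0.0], [0.0], [1.0], [0.0], [0.0]])
    values = np.full((5, 1), 0.5)
    episode_starts = np.array([[1.0], [0.0], [0.0], [1.0], [0.0]])
    adv, ret = gae_sketch(rewards, values, episode_starts, last_values=np.array([0.5]), dones=np.array([0.0]))

With ``gae_lambda=1.0``, ``ret`` reduces to the discounted Monte-Carlo return (with value bootstrap at the rollout edge), matching the TD(1) special case described in the docstring.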
4 changes: 2 additions & 2 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -180,9 +180,9 @@ def collect_rollouts(
if isinstance(self.action_space, gym.spaces.Discrete):
# Reshape in case of discrete action
actions = actions.reshape(-1, 1)
-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_dones, values, log_probs)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
self._last_obs = new_obs
-            self._last_dones = dones
+            self._last_episode_starts = dones

with th.no_grad():
# Compute value for the last timestep
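The two renamed assignments encode a one-step shift: the buffer row for step t stores ``episode_start_t``, which equals ``done_{t-1}``, and the very first flag is set to True by ``_setup_learn``. A small sketch of that relationship (toy values, not from the PR):

    import numpy as np

    dones = np.array([False, False, True, False])  # one env, four steps
    episode_starts = np.empty_like(dones)
    episode_starts[0] = True  # set in _setup_learn: step 0 follows a reset
    episode_starts[1:] = dones[:-1]  # the rollout loop: _last_episode_starts = dones
    print(episode_starts)  # [ True False False  True]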
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-1.1.0a3
+1.1.0a4
114 changes: 114 additions & 0 deletions tests/test_gae.py
@@ -0,0 +1,114 @@
import gym
import numpy as np
import pytest
import torch as th

from stable_baselines3 import A2C, PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy


class CustomEnv(gym.Env):
def __init__(self, max_steps=8):
super(CustomEnv, self).__init__()
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.max_steps = max_steps
self.n_steps = 0

def seed(self, seed):
self.observation_space.seed(seed)

def reset(self):
self.n_steps = 0
return self.observation_space.sample()

def step(self, action):
self.n_steps += 1

done = False
reward = 0.0
if self.n_steps >= self.max_steps:
reward = 1.0
done = True

return self.observation_space.sample(), reward, done, {}


class CheckGAECallback(BaseCallback):
def __init__(self):
super(CheckGAECallback, self).__init__(verbose=0)

def _on_rollout_end(self):
buffer = self.model.rollout_buffer
rollout_size = buffer.size()

max_steps = self.training_env.envs[0].max_steps
gamma = self.model.gamma
gae_lambda = self.model.gae_lambda
value = self.model.policy.constant_value
# We know in advance that the agent will get a single
# reward at the very last timestep of the episode,
# so we can pre-compute the lambda-return and advantage
deltas = np.zeros((rollout_size,))
advantages = np.zeros((rollout_size,))
# Reward should be 1.0 on final timestep of episode
rewards = np.zeros((rollout_size,))
rewards[max_steps - 1 :: max_steps] = 1.0
# Note that these are episode starts (shifted one timestep after the dones)
episode_starts = np.zeros((rollout_size,))
episode_starts[::max_steps] = 1.0

# The final step is always terminal (the next step would have episode_start = 1)
deltas[-1] = rewards[-1] - value
advantages[-1] = deltas[-1]
for n in reversed(range(rollout_size - 1)):
# The value estimates are constant
episode_start_mask = 1.0 - episode_starts[n + 1]
deltas[n] = rewards[n] + gamma * value * episode_start_mask - value
advantages[n] = deltas[n] + gamma * gae_lambda * advantages[n + 1] * episode_start_mask

# TD(lambda) estimate, see Github PR #375
lambda_returns = advantages + value

assert np.allclose(buffer.advantages.flatten(), advantages)
assert np.allclose(buffer.returns.flatten(), lambda_returns)

def _on_step(self):
return True


class CustomPolicy(ActorCriticPolicy):
"""Custom Policy with a constant value function"""

def __init__(self, *args, **kwargs):
super(CustomPolicy, self).__init__(*args, **kwargs)
self.constant_value = 0.0

def forward(self, obs, deterministic=False):
actions, values, log_prob = super().forward(obs, deterministic)
# Overwrite values with the constant value
values = th.ones_like(values) * self.constant_value
return actions, values, log_prob


@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("gae_lambda", [1.0, 0.9])
@pytest.mark.parametrize("gamma", [1.0, 0.99])
@pytest.mark.parametrize("num_episodes", [1, 3])
def test_gae_computation(model_class, gae_lambda, gamma, num_episodes):
env = CustomEnv(max_steps=64)
rollout_size = 64 * num_episodes
model = model_class(
CustomPolicy,
env,
seed=1,
gamma=gamma,
n_steps=rollout_size,
gae_lambda=gae_lambda,
)
model.learn(rollout_size, callback=CheckGAECallback())

# Change constant value so advantage != returns
model.policy.constant_value = 1.0
model.learn(rollout_size, callback=CheckGAECallback())
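As a sanity check on the callback's recursion: for ``gae_lambda=1.0``, the TD(1) return of ``CustomEnv`` should reduce to the discounted Monte-Carlo return, gamma raised to the number of steps remaining before the single terminal reward, independently of the constant value estimate. A quick standalone sketch (not part of the test file):

    import numpy as np

    gamma, value, max_steps = 0.99, 1.0, 8
    rewards = np.zeros(max_steps)
    rewards[-1] = 1.0  # single reward at the terminal step

    advantages = np.zeros(max_steps)
    advantages[-1] = rewards[-1] - value  # terminal step: no bootstrap
    for t in reversed(range(max_steps - 1)):
        delta = rewards[t] + gamma * value - value
        advantages[t] = delta + gamma * 1.0 * advantages[t + 1]  # gae_lambda = 1.0

    returns = advantages + value
    # Monte-Carlo return: gamma ** (steps remaining until the reward)
    expected = gamma ** np.arange(max_steps - 1, -1, -1)
    assert np.allclose(returns, expected)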