diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index b0228a76d4..14f1dc9af2 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -4,11 +4,12 @@ Changelog
 ==========
 
-Release 1.1.0a3 (WIP)
+Release 1.1.0a4 (WIP)
 ---------------------------
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
+- Renamed ``_last_dones`` and ``dones`` to ``_last_episode_starts`` and ``episode_starts`` in ``RolloutBuffer``.
 
 New Features:
 ^^^^^^^^^^^^^
@@ -30,15 +31,17 @@ Others:
 ^^^^^^^
 - Added ``flake8-bugbear`` to tests dependencies to find likely bugs
 - Added Code of Conduct
+- Added tests for GAE and lambda return computation
 
 Documentation:
 ^^^^^^^^^^^^^^
 - Added gym pybullet drones project (@JacopoPan)
 - Added link to SuperSuit in projects (@justinkterry)
 - Fixed DQN example (thanks @ltbd78)
-- Clarify channel-first/channel-last recommendation
+- Clarified channel-first/channel-last recommendation
 - Update sphinx environment installation instructions (@tom-doerr)
-- Clarify pip installation in Zsh (@tom-doerr)
+- Clarified pip installation in Zsh (@tom-doerr)
+- Clarified return computation for on-policy algorithms (TD(lambda) estimate was used)
 
 - Added example for using ``ProcgenEnv``
 
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 3164df1ffd..819a67a241 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -130,7 +130,7 @@ def __init__(
         self.tensorboard_log = tensorboard_log
         self.lr_schedule = None  # type: Optional[Schedule]
         self._last_obs = None  # type: Optional[np.ndarray]
-        self._last_dones = None  # type: Optional[np.ndarray]
+        self._last_episode_starts = None  # type: Optional[np.ndarray]
         # When using VecNormalize:
         self._last_original_obs = None  # type: Optional[np.ndarray]
         self._episode_num = 0
@@ -377,7 +377,7 @@ def _setup_learn(
         # Avoid resetting the environment when calling ``.learn()`` consecutive times
         if reset_num_timesteps or self._last_obs is None:
             self._last_obs = self.env.reset()
-            self._last_dones = np.zeros((self.env.num_envs,), dtype=bool)
+            self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
             # Retrieve unnormalized observation for saving into the buffer
             if self._vec_normalize_env is not None:
                 self._last_original_obs = self._vec_normalize_env.get_original_obs()
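The switch from ``np.zeros``/``_last_dones`` to ``np.ones``/``_last_episode_starts`` above changes what the stored flag means: the buffer now records whether a step is the first step of an episode, so every environment is flagged at reset, and afterwards the flag for step t is simply the done flag of step t - 1. A minimal sketch of that relation (illustrative only, using a hypothetical single-environment rollout; not part of the patch):

# Illustrative sketch (not part of the patch): relation between per-step "dones"
# and the new "episode start" flags for a hypothetical single-environment rollout.
import numpy as np

dones = np.array([0, 0, 1, 0, 1, 0], dtype=bool)  # whether step t ended an episode
episode_starts = np.empty_like(dones)
episode_starts[0] = True          # the first step after env.reset() starts an episode
episode_starts[1:] = dones[:-1]   # step t starts a new episode iff step t - 1 ended one
print(episode_starts)  # [ True False False  True False  True]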
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index ff7897baf6..a6cba8c0d3 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -294,7 +294,7 @@ def __init__(
         self.gae_lambda = gae_lambda
         self.gamma = gamma
         self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
-        self.returns, self.dones, self.values, self.log_probs = None, None, None, None
+        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
         self.generator_ready = False
         self.reset()
 
@@ -303,7 +303,7 @@ def reset(self) -> None:
         self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
-        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -312,20 +312,25 @@ def reset(self) -> None:
 
     def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
         """
-        Post-processing step: compute the returns (sum of discounted rewards)
-        and GAE advantage.
-        Adapted from Stable-Baselines PPO2.
+        Post-processing step: compute the lambda-return (TD(lambda) estimate)
+        and GAE(lambda) advantage.
 
         Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
         to compute the advantage. To obtain vanilla advantage (A(s) = R - V(S))
         where R is the discounted reward with value bootstrap,
         set ``gae_lambda=1.0`` during initialization.
 
-        :param last_values:
-        :param dones:
+        The TD(lambda) estimator also has two special cases:
+        - TD(1) is the Monte-Carlo estimate (sum of discounted rewards)
+        - TD(0) is the one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))
+
+        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.
+
+        :param last_values: state value estimation for the last step (one for each env)
+        :param dones: if the last step was a terminal step (one bool for each env).
         """
-        # convert to numpy
+        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()
 
        last_gae_lam = 0
@@ -334,21 +339,29 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra
                 next_non_terminal = 1.0 - dones
                 next_values = last_values
             else:
-                next_non_terminal = 1.0 - self.dones[step + 1]
+                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                 next_values = self.values[step + 1]
             delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
             last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
             self.advantages[step] = last_gae_lam
+        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
+        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
         self.returns = self.advantages + self.values
 
     def add(
-        self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray, value: th.Tensor, log_prob: th.Tensor
+        self,
+        obs: np.ndarray,
+        action: np.ndarray,
+        reward: np.ndarray,
+        episode_start: np.ndarray,
+        value: th.Tensor,
+        log_prob: th.Tensor,
     ) -> None:
         """
         :param obs: Observation
         :param action: Action
         :param reward:
-        :param done: End of episode signal.
+        :param episode_start: Start of episode signal.
         :param value: estimated value of the current state following the current policy.
         :param log_prob: log probability of the action
@@ -366,7 +379,7 @@ def add(
         self.observations[self.pos] = np.array(obs).copy()
         self.actions[self.pos] = np.array(action).copy()
         self.rewards[self.pos] = np.array(reward).copy()
-        self.dones[self.pos] = np.array(done).copy()
+        self.episode_starts[self.pos] = np.array(episode_start).copy()
         self.values[self.pos] = value.clone().cpu().numpy().flatten()
         self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
         self.pos += 1
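To make the updated docstring concrete: the buffer computes GAE(lambda) advantages with the backward recursion above and then sets returns = advantages + values, which equals the TD(lambda) return. The sketch below is an illustration of that equivalence under simplifying assumptions (a single environment, per-step done flags instead of the stored episode starts, and a bootstrap value ``last_value`` for the step just after the rollout); it is not the library code. It checks the result against the recursive definition of the lambda-return, G_t = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}) for non-terminal steps.

# Illustrative sketch, not library code: single environment, float arrays, and a
# bootstrap value "last_value" for the step just after the rollout are assumed.
import numpy as np


def gae_advantages_and_returns(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """GAE(lambda) advantages and the TD(lambda) returns (advantages + values)."""
    rewards, values, dones = map(np.asarray, (rewards, values, dones))
    n_steps = len(rewards)
    advantages = np.zeros(n_steps)
    last_gae_lam = 0.0
    for step in reversed(range(n_steps)):
        next_non_terminal = 1.0 - dones[step]
        next_value = last_value if step == n_steps - 1 else values[step + 1]
        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    return advantages, advantages + values


def lambda_returns(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Lambda-return from its recursion: G_t = r_t + gamma * ((1 - lambda) * V_{t+1} + lambda * G_{t+1})."""
    rewards, values, dones = map(np.asarray, (rewards, values, dones))
    n_steps = len(rewards)
    returns = np.zeros(n_steps)
    next_return, next_value = last_value, last_value
    for step in reversed(range(n_steps)):
        non_terminal = 1.0 - dones[step]
        returns[step] = rewards[step] + gamma * non_terminal * (
            (1.0 - gae_lambda) * next_value + gae_lambda * next_return
        )
        next_return, next_value = returns[step], values[step]
    return returns


rng = np.random.default_rng(0)
rewards, values = rng.normal(size=8), rng.normal(size=8)
dones = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0])  # one episode boundary mid-rollout
_, returns = gae_advantages_and_returns(rewards, values, dones, last_value=0.5)
assert np.allclose(returns, lambda_returns(rewards, values, dones, last_value=0.5))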
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index c55acb8264..64300dc922 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -180,9 +180,9 @@ def collect_rollouts(
             if isinstance(self.action_space, gym.spaces.Discrete):
                 # Reshape in case of discrete action
                 actions = actions.reshape(-1, 1)
-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_dones, values, log_probs)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
             self._last_obs = new_obs
-            self._last_dones = dones
+            self._last_episode_starts = dones
 
         with th.no_grad():
             # Compute value for the last timestep
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index 55edcadbe1..cd64e7162a 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-1.1.0a3
+1.1.0a4
diff --git a/tests/test_gae.py b/tests/test_gae.py
new file mode 100644
index 0000000000..7f095c05cd
--- /dev/null
+++ b/tests/test_gae.py
@@ -0,0 +1,114 @@
+import gym
+import numpy as np
+import pytest
+import torch as th
+
+from stable_baselines3 import A2C, PPO
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.policies import ActorCriticPolicy
+
+
+class CustomEnv(gym.Env):
+    def __init__(self, max_steps=8):
+        super(CustomEnv, self).__init__()
+        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+        self.max_steps = max_steps
+        self.n_steps = 0
+
+    def seed(self, seed):
+        self.observation_space.seed(seed)
+
+    def reset(self):
+        self.n_steps = 0
+        return self.observation_space.sample()
+
+    def step(self, action):
+        self.n_steps += 1
+
+        done = False
+        reward = 0.0
+        if self.n_steps >= self.max_steps:
+            reward = 1.0
+            done = True
+
+        return self.observation_space.sample(), reward, done, {}
+
+
+class CheckGAECallback(BaseCallback):
+    def __init__(self):
+        super(CheckGAECallback, self).__init__(verbose=0)
+
+    def _on_rollout_end(self):
+        buffer = self.model.rollout_buffer
+        rollout_size = buffer.size()
+
+        max_steps = self.training_env.envs[0].max_steps
+        gamma = self.model.gamma
+        gae_lambda = self.model.gae_lambda
+        value = self.model.policy.constant_value
+        # We know in advance that the agent will get a single
+        # reward at the very last timestep of the episode,
+        # so we can pre-compute the lambda-return and advantage
+        deltas = np.zeros((rollout_size,))
+        advantages = np.zeros((rollout_size,))
+        # Reward should be 1.0 on the final timestep of each episode
+        rewards = np.zeros((rollout_size,))
+        rewards[max_steps - 1 :: max_steps] = 1.0
+        # Note that these are episode starts (shifted by one timestep from the dones)
+        episode_starts = np.zeros((rollout_size,))
+        episode_starts[::max_steps] = 1.0
+
+        # Final step is always terminal (the next step would have episode_start = 1)
+        deltas[-1] = rewards[-1] - value
+        advantages[-1] = deltas[-1]
+        for n in reversed(range(rollout_size - 1)):
+            # Value estimates are constant
+            episode_start_mask = 1.0 - episode_starts[n + 1]
+            deltas[n] = rewards[n] + gamma * value * episode_start_mask - value
+            advantages[n] = deltas[n] + gamma * gae_lambda * advantages[n + 1] * episode_start_mask
+
+        # TD(lambda) estimate, see Github PR #375
+        lambda_returns = advantages + value
+
+        assert np.allclose(buffer.advantages.flatten(), advantages)
+        assert np.allclose(buffer.returns.flatten(), lambda_returns)
+
+    def _on_step(self):
+        return True
+
+
+class CustomPolicy(ActorCriticPolicy):
+    """Custom Policy with a constant value function"""
+
+    def __init__(self, *args, **kwargs):
+        super(CustomPolicy, self).__init__(*args, **kwargs)
+        self.constant_value = 0.0
+
+    def forward(self, obs, deterministic=False):
+        actions, values, log_prob = super().forward(obs, deterministic)
+        # Overwrite the value predictions with a constant value
+        values = th.ones_like(values) * self.constant_value
+        return actions, values, log_prob
+
+
+@pytest.mark.parametrize("model_class", [A2C, PPO])
+@pytest.mark.parametrize("gae_lambda", [1.0, 0.9])
+@pytest.mark.parametrize("gamma", [1.0, 0.99])
+@pytest.mark.parametrize("num_episodes", [1, 3])
+def test_gae_computation(model_class, gae_lambda, gamma, num_episodes):
+    env = CustomEnv(max_steps=64)
+    rollout_size = 64 * num_episodes
+    model = model_class(
+        CustomPolicy,
+        env,
+        seed=1,
+        gamma=gamma,
+        n_steps=rollout_size,
+        gae_lambda=gae_lambda,
+    )
+    model.learn(rollout_size, callback=CheckGAECallback())
+
+    # Change the constant value so that advantages != returns
+    model.policy.constant_value = 1.0
+    model.learn(rollout_size, callback=CheckGAECallback())
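As a small worked check of the arithmetic the test relies on (a hypothetical 4-step episode rather than the 64-step episodes used above, a constant value estimate of 0.0, and the single reward of 1.0 on the terminal step): with gae_lambda=1.0 the lambda-return reduces to the plain discounted Monte-Carlo return, i.e. gamma raised to the number of steps remaining until the reward.

# Hypothetical smaller check of the same arithmetic (not part of the test file):
# one 4-step episode, reward 1.0 only on the terminal step, constant value of 0.0.
import numpy as np

gamma, gae_lambda, value = 0.99, 1.0, 0.0
rewards = np.array([0.0, 0.0, 0.0, 1.0])
dones = np.array([0.0, 0.0, 0.0, 1.0])

advantages = np.zeros(4)
last_gae_lam = 0.0
for step in reversed(range(4)):
    next_non_terminal = 1.0 - dones[step]
    # The bootstrap value is irrelevant on the terminal step because next_non_terminal == 0
    delta = rewards[step] + gamma * value * next_non_terminal - value
    last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
    advantages[step] = last_gae_lam

returns = advantages + value
print(returns)  # approx [0.970299, 0.9801, 0.99, 1.0], i.e. gamma ** (steps left until the reward)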