Skip to content
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
68edcc7
refactor stacking obs
qgallouedec Dec 24, 2022
1990204
Improve docstring
qgallouedec Dec 24, 2022
a3eca5e
remove all StackedDictObservations
qgallouedec Dec 24, 2022
d558cba
Update tests and make stacked obs clearer
qgallouedec Dec 24, 2022
05608a7
Fix type check
qgallouedec Dec 24, 2022
ece5b0a
fix stacked_observation_space
qgallouedec Dec 24, 2022
41f208f
undo init change, deprecate StackedDictObservations
qgallouedec Dec 25, 2022
fc19fed
deprecate stack_observation_space
qgallouedec Dec 25, 2022
2490bac
type hints
qgallouedec Dec 25, 2022
c1277da
ignore pytype errors
qgallouedec Dec 26, 2022
e5df306
undo vecenv doc change
qgallouedec Dec 26, 2022
0db3b0e
Deprecation warning in StackedDictObs docstring
qgallouedec Dec 26, 2022
cf1e199
Fix vec_env.rst
qgallouedec Dec 26, 2022
6e507aa
Fix __all__ sorting
qgallouedec Dec 26, 2022
2f61a06
fix pytype ignore statement
qgallouedec Dec 26, 2022
2982401
Merge branch 'new-stack-observation' of https://github.com/DLR-RM/sta…
qgallouedec Dec 26, 2022
aad068a
Merge branch 'master' into new-stack-observation
qgallouedec Jan 2, 2023
2ca28b2
Merge branch 'master' into new-stack-observation
araffin Jan 13, 2023
985e825
Merge branch 'master' into new-stack-observation
araffin Jan 23, 2023
232f3eb
Merge branch 'master' into new-stack-observation
araffin Jan 30, 2023
3fdb18f
Update docstring
araffin Jan 30, 2023
927ab27
Merge branch 'master' into new-stack-observation
qgallouedec Feb 2, 2023
7d55106
stack
qgallouedec Feb 2, 2023
e874585
Merge branch 'master' into new-stack-observation
araffin Feb 6, 2023
b9b46ea
Remove n_stack
araffin Feb 6, 2023
80a3e5c
Merge branch 'master' into new-stack-observation
araffin Feb 6, 2023
b3dff6c
Update changelog
araffin Feb 6, 2023
bbee2ea
Simplify code
araffin Feb 6, 2023
c03f60a
Rename test file
araffin Feb 6, 2023
a677405
Re-use variable for shift
araffin Feb 6, 2023
f01a2be
Fix doc build
araffin Feb 6, 2023
424fde8
Remove pytype comment
qgallouedec Feb 6, 2023
0c2917c
Disable pytype error
araffin Feb 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions docs/guide/vec_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,6 @@ StackedObservations
.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedObservations
:members:

StackedDictObservations
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedDictObservations
:members:

VecNormalize
~~~~~~~~~~~~

Expand Down
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ Changelog
==========


Release 1.8.0a3 (WIP)
Release 1.8.0a4 (WIP)
--------------------------


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
- Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed)

New Features:
^^^^^^^^^^^^^
Expand All @@ -36,6 +37,7 @@ Others:
- Fixed ``tests/test_tensorboard.py`` type hint
- Fixed ``tests/test_vec_normalize.py`` type hint
- Fixed ``stable_baselines3/common/monitor.py`` type hint
- Added tests for ``StackedObservations``

Documentation:
^^^^^^^^^^^^^^
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ exclude = (?x)(
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/base_vec_env.py$
| stable_baselines3/common/vec_env/dummy_vec_env.py$
| stable_baselines3/common/vec_env/stacked_observations.py$
| stable_baselines3/common/vec_env/subproc_vec_env.py$
| stable_baselines3/common/vec_env/util.py$
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$
Expand Down
3 changes: 1 addition & 2 deletions stable_baselines3/common/vec_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
Expand Down Expand Up @@ -78,7 +78,6 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
"VecEnv",
"VecEnvWrapper",
"DummyVecEnv",
"StackedDictObservations",
"StackedObservations",
"SubprocVecEnv",
"VecCheckNan",
Expand Down
317 changes: 129 additions & 188 deletions stable_baselines3/common/vec_env/stacked_observations.py

Large diffs are not rendered by default.

47 changes: 12 additions & 35 deletions stable_baselines3/common/vec_env/vec_frame_stack.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,40 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

import numpy as np
from gym import spaces

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations


class VecFrameStack(VecEnvWrapper):
"""
Frame stacking wrapper for vectorized environment. Designed for image observations.

Uses the StackedObservations class, or StackedDictObservations depending on the observations space

:param venv: the vectorized environment to wrap
:param venv: Vectorized environment to wrap
:param n_stack: Number of frames to stack
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None (the default), automatically detect the channel to stack over in case of image observation, or fall back to "last".
Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
"""

def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None):
self.venv = venv
self.n_stack = n_stack

wrapped_obs_space = venv.observation_space

if isinstance(wrapped_obs_space, spaces.Box):
assert not isinstance(
channels_order, dict
), f"Expected None or string for channels_order but received {channels_order}"
self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)

elif isinstance(wrapped_obs_space, spaces.Dict):
self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)

else:
raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces")
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Mapping[str, str]]] = None) -> None:
assert isinstance(
venv.observation_space, (spaces.Box, spaces.Dict)
), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces"

observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space)
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order)
observation_space = self.stacked_obs.stacked_observation_space
super().__init__(venv, observation_space=observation_space)

def step_wait(
self,
) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
observations, rewards, dones, infos = self.venv.step_wait()

observations, infos = self.stackedobs.update(observations, dones, infos)

observations, infos = self.stacked_obs.update(observations, dones, infos)
return observations, rewards, dones, infos

def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Reset all environments
"""
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch

observation = self.stackedobs.reset(observation)
observation = self.stacked_obs.reset(observation)
return observation

def close(self) -> None:
self.venv.close()
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.8.0a3
1.8.0a4
Loading