Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,31 @@
Changelog
==========


Pre-Release 0.10.0a0 (WIP)
------------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^

Bug Fixes:
^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^
- Improved typing coverage
- Improved error messages for unsupported spaces

Documentation:
^^^^^^^^^^^^^^


Pre-Release 0.9.0 (2020-10-03)
------------------------------

Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/atari_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
except ImportError:
cv2 = None

from stable_baselines3.common.type_aliases import GymStepReturn
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn


class NoopResetEnv(gym.Wrapper):
Expand Down Expand Up @@ -146,7 +146,7 @@ def step(self, action: int) -> GymStepReturn:

return max_frame, total_reward, done, info

def reset(self, **kwargs):
def reset(self, **kwargs) -> GymObs:
return self.env.reset(**kwargs)


Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def set_parameters(
load_path_or_dict: Union[str, Dict[str, Dict]],
exact_match: bool = True,
device: Union[th.device, str] = "auto",
):
) -> None:
"""
Load parameters from a given zip-file or a nested dictionary containing parameters for
different modules (see ``get_parameters``).
Expand Down Expand Up @@ -610,7 +610,7 @@ def load(
model.policy.reset_noise() # pytype: disable=attribute-error
return model

def get_parameters(self):
def get_parameters(self) -> Dict[str, Dict]:
"""
Return the parameters of the agent. This includes parameters from different networks, e.g.
critics (value functions) and policies (pi functions).
Expand Down
8 changes: 6 additions & 2 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from abc import ABC, abstractmethod
from typing import Generator, Optional, Union

import numpy as np
Expand All @@ -16,7 +17,7 @@
from stable_baselines3.common.vec_env import VecNormalize


class BaseBuffer(object):
class BaseBuffer(ABC):
"""
Base class that represent a buffer (rollout or replay)

Expand Down Expand Up @@ -102,7 +103,10 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
batch_inds = np.random.randint(0, upper_bound, size=batch_size)
return self._get_samples(batch_inds, env=env)

def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None):
@abstractmethod
def _get_samples(
self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
"""
:param batch_inds:
:param env:
Expand Down
11 changes: 6 additions & 5 deletions stable_baselines3/common/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import gym
import numpy as np
Expand Down Expand Up @@ -217,9 +217,10 @@ class CheckpointCallback(BaseCallback):
:param save_freq:
:param save_path: Path to the folder where the model will be saved.
:param name_prefix: Common prefix to the saved models
:param verbose:
"""

def __init__(self, save_freq: int, save_path: str, name_prefix="rl_model", verbose=0):
def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
super(CheckpointCallback, self).__init__(verbose)
self.save_freq = save_freq
self.save_path = save_path
Expand Down Expand Up @@ -247,7 +248,7 @@ class ConvertCallback(BaseCallback):
:param verbose:
"""

def __init__(self, callback, verbose=0):
def __init__(self, callback: Callable, verbose: int = 0):
super(ConvertCallback, self).__init__(verbose)
self.callback = callback

Expand Down Expand Up @@ -314,7 +315,7 @@ def __init__(
self.evaluations_timesteps = []
self.evaluations_length = []

def _init_callback(self):
def _init_callback(self) -> None:
# Does not work in some corner cases, where the wrapper is not the same
if not isinstance(self.training_env, type(self.eval_env)):
warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}")
Expand Down Expand Up @@ -450,7 +451,7 @@ def __init__(self, max_episodes: int, verbose: int = 0):
self._total_max_episodes = max_episodes
self.n_episodes = 0

def _init_callback(self):
def _init_callback(self) -> None:
# At start set total max according to number of envirnments
self._total_max_episodes = self.max_episodes * self.training_env.num_envs

Expand Down
6 changes: 3 additions & 3 deletions stable_baselines3/common/cmd_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv


def make_vec_env(
Expand All @@ -19,7 +19,7 @@ def make_vec_env(
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
):
) -> VecEnv:
"""
Create a wrapped, monitored ``VecEnv``.
By default it uses a ``DummyVecEnv`` which is usually faster
Expand Down Expand Up @@ -85,7 +85,7 @@ def make_atari_env(
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
):
) -> VecEnv:
"""
Create a wrapped, monitored VecEnv for Atari.
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/distributions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Probability distributions."""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import gym
import torch as th
Expand All @@ -19,7 +19,7 @@ def __init__(self):
super(Distribution, self).__init__()

@abstractmethod
def proba_distribution_net(self, *args, **kwargs):
def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
"""Create the layers and parameters that represent the distribution.

Subclasses must define this, but the arguments and return type vary between
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class SeqWriter(object):
sequence writer
"""

def write_sequence(self, sequence: List):
def write_sequence(self, sequence: List) -> None:
"""
write_sequence an array to file

Expand Down
10 changes: 6 additions & 4 deletions stable_baselines3/common/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import os
import time
from glob import glob
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import gym
import numpy as np
import pandas

from stable_baselines3.common.type_aliases import GymObs, GymStepReturn


class Monitor(gym.Wrapper):
"""
Expand Down Expand Up @@ -62,7 +64,7 @@ def __init__(
self.total_steps = 0
self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()

def reset(self, **kwargs) -> np.ndarray:
def reset(self, **kwargs) -> GymObs:
"""
Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True

Expand All @@ -83,7 +85,7 @@ def reset(self, **kwargs) -> np.ndarray:
self.current_reset_info[key] = value
return self.env.reset(**kwargs)

def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, Any]]:
def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
"""
Step the environment with the given action

Expand Down Expand Up @@ -112,7 +114,7 @@ def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, A
self.total_steps += 1
return observation, reward, done, info

def close(self):
def close(self) -> None:
"""
Closes the environment
"""
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def base_noise(self) -> ActionNoise:
return self._base_noise

@base_noise.setter
def base_noise(self, base_noise: ActionNoise):
def base_noise(self, base_noise: ActionNoise) -> None:
if base_noise is None:
raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
if not isinstance(base_noise, ActionNoise):
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel":
model.to(device)
return model

def load_from_vector(self, vector: np.ndarray):
def load_from_vector(self, vector: np.ndarray) -> None:
"""
Load parameters from a 1D vector.

Expand Down
6 changes: 3 additions & 3 deletions stable_baselines3/common/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_im
return obs.float()

else:
raise NotImplementedError()
raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")


def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
Expand All @@ -100,7 +100,7 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
# Number of binary features
return (int(observation_space.n),)
else:
raise NotImplementedError()
raise NotImplementedError(f"{observation_space} observation space is not supported")


def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
Expand Down Expand Up @@ -139,4 +139,4 @@ def get_action_dim(action_space: spaces.Space) -> int:
# Number of binary actions
return int(action_space.n)
else:
raise NotImplementedError()
raise NotImplementedError(f"{action_space} action space is not supported")
17 changes: 10 additions & 7 deletions stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No


@functools.singledispatch
def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose=0, suffix=None):
def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None):
"""
Opens a path for reading or writing with a preferred suffix and raises debug information.
If the provided path is a derivative of io.BufferedIOBase it ensures that the file
Expand All @@ -197,6 +197,7 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb
:param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
:return:
"""
if not isinstance(path, io.BufferedIOBase):
raise TypeError("Path parameter has invalid type.", io.BufferedIOBase)
Expand All @@ -214,7 +215,7 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb


@open_path.register(str)
def open_path_str(path: str, mode: str, verbose=0, suffix=None) -> io.BufferedIOBase:
def open_path_str(path: str, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
"""
Open a path given by a string. If writing to the path, the function ensures
that the path exists.
Expand All @@ -226,12 +227,13 @@ def open_path_str(path: str, mode: str, verbose=0, suffix=None) -> io.BufferedIO
:param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
:return:
"""
return open_path(pathlib.Path(path), mode, verbose, suffix)


@open_path.register(pathlib.Path)
def open_path_pathlib(path: pathlib.Path, mode: str, verbose=0, suffix=None) -> io.BufferedIOBase:
def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
"""
Open a path given by a string. If writing to the path, the function ensures
that the path exists.
Expand All @@ -244,6 +246,7 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose=0, suffix=None) ->
:param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
:return:
"""
if mode not in ("w", "r"):
raise ValueError("Expected mode to be either 'w' or 'r'.")
Expand Down Expand Up @@ -286,7 +289,7 @@ def save_to_zip_file(
data: Dict[str, Any] = None,
params: Dict[str, Any] = None,
pytorch_variables: Dict[str, Any] = None,
verbose=0,
verbose: int = 0,
) -> None:
"""
Save model data to a zip archive.
Expand Down Expand Up @@ -321,7 +324,7 @@ def save_to_zip_file(
archive.writestr("_stable_baselines3_version", stable_baselines3.__version__)


def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0) -> None:
def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
"""
Save an object to path creating the necessary folders along the way.
If the path exists and is a directory, it will raise a warning and rename the path.
Expand All @@ -337,7 +340,7 @@ def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=
pickle.dump(obj, file_handler)


def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0) -> Any:
def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
"""
Load an object from the path. If a suffix is provided in the path, it will use that suffix.
If the path does not exist, it will attempt to load using the .pkl suffix.
Expand All @@ -355,7 +358,7 @@ def load_from_zip_file(
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
load_data: bool = True,
device: Union[th.device, str] = "auto",
verbose=0,
verbose: int = 0,
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
"""
Load model data from a .zip archive
Expand Down
Loading