
Commit 90ad116

araffin authored and leor-c committed
Fix default arguments + add bugbear (DLR-RM#363)
* Fix potential bug + add bug bear
* Remove unused variables
* Minor: version bump
1 parent 239dfef commit 90ad116

File tree

10 files changed (+17, -14 lines)

docs/misc/changelog.rst

Lines changed: 3 additions & 2 deletions
@@ -4,7 +4,7 @@ Changelog
 ==========
 
 
-Release 1.1.0a0 (WIP)
+Release 1.1.0a1 (WIP)
 ---------------------------
 
 Breaking Changes:
@@ -15,13 +15,14 @@ New Features:
 
 Bug Fixes:
 ^^^^^^^^^^
-- Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena)
+- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
 
 Deprecations:
 ^^^^^^^^^^^^^
 
 Others:
 ^^^^^^^
+- Added ``flake8-bugbear`` to tests dependencies to find likely bugs
 
 Documentation:
 ^^^^^^^^^^^^^^
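
The bug-fix entry above corresponds to the ``buffer_size: int = int(1e6)`` defaults replaced in the diffs below. As a minimal standalone sketch (not SB3 code; the function and argument names are made up), this is the class of pitfall that flake8-bugbear's default-argument checks target (B006 for mutable defaults, B008 for function calls in defaults): defaults are evaluated once, when the function is defined.

# Illustrative sketch, not SB3 code.
def record_size(size=int(1e6), history=[]):
    # `int(1e6)` was called once, at definition time, and `history` is a
    # single list object shared by every call that relies on the default.
    history.append(size)
    return history

print(record_size())  # [1000000]
print(record_size())  # [1000000, 1000000]  <- state leaks across calls

def record_size_safe(size=1000000, history=None):
    # A plain literal avoids the call in the default, and `None` plus an
    # in-body check gives each call its own list.
    if history is None:
        history = []
    history.append(size)
    return history

Replacing ``int(1e6)`` with the literal ``1000000`` in the signatures below removes the function call from the default entirely.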

setup.py

Lines changed: 4 additions & 2 deletions
@@ -20,8 +20,8 @@
 Repository:
 https://github.com/DLR-RM/stable-baselines3
 
-Medium article:
-https://medium.com/@araffin/df87c4b2fc82
+Blog post:
+https://araffin.github.io/post/sb3/
 
 Documentation:
 https://stable-baselines3.readthedocs.io/en/master/
@@ -94,6 +94,8 @@
 "pytype",
 # Lint code
 "flake8>=3.8",
+# Find likely bugs
+"flake8-bugbear",
 # Sort imports
 "isort>=5.0",
 # Reformat
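
flake8-bugbear is a plugin for flake8: once it is installed alongside flake8 (here via the tests dependencies), its default B-prefixed checks, including the default-argument and unused-loop-variable warnings addressed elsewhere in this commit, run as part of the normal flake8 invocation with no extra configuration. Only the opinionated B9xx warnings need to be enabled explicitly.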

stable_baselines3/common/off_policy_algorithm.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def __init__(
         env: Union[GymEnv, str],
         policy_base: Type[BasePolicy],
         learning_rate: Union[float, Schedule],
-        buffer_size: int = int(1e6),
+        buffer_size: int = 1000000,
         learning_starts: int = 100,
         batch_size: int = 256,
         tau: float = 0.005,

stable_baselines3/common/results_plotter.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def plot_curves(
     plt.figure(title, figsize=figsize)
     max_x = max(xy[0][-1] for xy in xy_list)
     min_x = 0
-    for (i, (x, y)) in enumerate(xy_list):
+    for (_, (x, y)) in enumerate(xy_list):
         plt.scatter(x, y, s=2)
         # Do not plot the smoothed curve at all if the timeseries is shorter than window size.
         if x.shape[0] >= EPISODES_WINDOW:
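
Renaming the unused index to ``_`` above addresses flake8-bugbear's B007 check (loop control variable not used within the loop body). A small standalone sketch of the pattern, with made-up data:

# Illustrative sketch of the B007 pattern, not SB3 code.
xy_list = [([0, 1, 2], [0.1, 0.4, 0.9])]

for i, (x, y) in enumerate(xy_list):  # B007: `i` is never used in the body
    print(x, y)

for _, (x, y) in enumerate(xy_list):  # explicit: the index is intentionally ignored
    print(x, y)

The same convention is applied to the ``gradient_steps`` loops in dqn.py and td3.py further down.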

stable_baselines3/common/torch_layers.py

Lines changed: 2 additions & 2 deletions
@@ -170,7 +170,7 @@ def __init__(
         last_layer_dim_shared = feature_dim
 
         # Iterate through the shared layers and build the shared parts of the network
-        for idx, layer in enumerate(net_arch):
+        for layer in net_arch:
             if isinstance(layer, int):  # Check that this is a shared layer
                 layer_size = layer
                 # TODO: give layer a meaningful name
@@ -192,7 +192,7 @@ def __init__(
         last_layer_dim_vf = last_layer_dim_shared
 
         # Build the non-shared part of the network
-        for idx, (pi_layer_size, vf_layer_size) in enumerate(zip_longest(policy_only_layers, value_only_layers)):
+        for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers):
             if pi_layer_size is not None:
                 assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
                 policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size))
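
Dropping ``enumerate`` above removes indices that were never used (bugbear B007). For context on why each element from ``zip_longest`` is checked against ``None`` in this loop: ``itertools.zip_longest`` pads the shorter of the two lists, so one side of each pair can be missing. A small standalone sketch with made-up layer sizes:

from itertools import zip_longest

# Made-up sizes, only to illustrate the padding behaviour.
policy_only_layers = [64, 64]
value_only_layers = [128]

for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers):
    print(pi_layer_size, vf_layer_size)
# 64 128
# 64 None  <- the value network has no layer at this depth, hence the None checks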

stable_baselines3/ddpg/ddpg.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def __init__(
         policy: Union[str, Type[TD3Policy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 1e-3,
-        buffer_size: int = int(1e6),
+        buffer_size: int = 1000000,
         learning_starts: int = 100,
         batch_size: int = 100,
         tau: float = 0.005,

stable_baselines3/dqn/dqn.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
         self._update_learning_rate(self.policy.optimizer)
 
         losses = []
-        for gradient_step in range(gradient_steps):
+        for _ in range(gradient_steps):
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
 
stable_baselines3/sac/sac.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def __init__(
         policy: Union[str, Type[SACPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 3e-4,
-        buffer_size: int = int(1e6),
+        buffer_size: int = 1000000,
         learning_starts: int = 100,
         batch_size: int = 256,
         tau: float = 0.005,

stable_baselines3/td3/td3.py

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ def __init__(
         policy: Union[str, Type[TD3Policy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 1e-3,
-        buffer_size: int = int(1e6),
+        buffer_size: int = 1000000,  # 1e6
         learning_starts: int = 100,
         batch_size: int = 100,
         tau: float = 0.005,
@@ -131,7 +131,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
 
         actor_losses, critic_losses = [], []
 
-        for gradient_step in range(gradient_steps):
+        for _ in range(gradient_steps):
 
             self._n_updates += 1
             # Sample replay buffer

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-1.1.0a0
+1.1.0a1
