Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 1.1.0a2 (WIP)
Release 1.1.0a3 (WIP)
---------------------------

Breaking Changes:
Expand All @@ -21,6 +21,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
- Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``, it was not optimized anymore (thanks @Atlis)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -649,4 +650,4 @@ And all the contributors:
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis
4 changes: 3 additions & 1 deletion stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,9 @@ def load(
# put other pytorch variables back in place
if pytorch_variables is not None:
for name in pytorch_variables:
recursive_setattr(model, name, pytorch_variables[name])
# Set the data attribute directly to avoid issue when using optimizers
# See https://github.com/DLR-RM/stable-baselines3/issues/391
recursive_setattr(model, name + ".data", pytorch_variables[name].data)

# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.1.0a2
1.1.0a3
18 changes: 18 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,24 @@ def test_exclude_include_saved_params(tmp_path, model_class):
os.remove(tmp_path / "test_save.zip")


def test_save_load_pytorch_var(tmp_path):
model = SAC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200)
save_path = str(tmp_path / "sac_pendulum")
model.save(save_path)
env = model.get_env()
ent_coef_before = model.log_ent_coef

del model

model = SAC.load(save_path, env=env)
assert th.allclose(ent_coef_before, model.log_ent_coef)
model.learn(200)
ent_coef_after = model.log_ent_coef
# Check that the entropy coefficient is still optimized
assert not th.allclose(ent_coef_before, ent_coef_after)


@pytest.mark.parametrize("model_class", [A2C, TD3])
def test_save_load_env_cnn(tmp_path, model_class):
"""
Expand Down