Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 1.1.0a2 (WIP)
Release 1.1.0a3 (WIP)
---------------------------

Breaking Changes:
Expand All @@ -21,6 +21,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
- Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``: after loading a saved model, it was no longer optimized (thanks @Atlis)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -649,4 +650,4 @@ And all the contributors:
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis
4 changes: 3 additions & 1 deletion stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,9 @@ def load(
# put other pytorch variables back in place
if pytorch_variables is not None:
for name in pytorch_variables:
recursive_setattr(model, name, pytorch_variables[name])
# Set the data attribute directly to avoid issues when using optimizers
# See https://github.com/DLR-RM/stable-baselines3/issues/391
recursive_setattr(model, name + ".data", pytorch_variables[name].data)

# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.1.0a2
1.1.0a3
17 changes: 17 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,23 @@ def test_exclude_include_saved_params(tmp_path, model_class):
os.remove(tmp_path / "test_save.zip")


def test_save_load_pytorch_var(tmp_path):
model = SAC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200)
model.save(str(tmp_path / "sac_pendulum"))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super nit, but inclined to split this as:

Suggested change
model.save(str(tmp_path / "sac_pendulum"))
save_path = str(tmp_path / "sac_pendulum")
model.save(save_path)

and then reuse save_path on line 245 below

env = model.get_env()
ent_coef_before = model.log_ent_coef

del model

model = SAC.load(str(tmp_path / "sac_pendulum"), env=env)
assert th.allclose(ent_coef_before, model.log_ent_coef)
model.learn(200)
ent_coef_after = model.log_ent_coef
# Check that the entropy coefficient is still optimized
assert not th.allclose(ent_coef_before, ent_coef_after)


@pytest.mark.parametrize("model_class", [A2C, TD3])
def test_save_load_env_cnn(tmp_path, model_class):
"""
Expand Down