diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 14f1dc9af2..a24f02a38a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.1.0a4 (WIP) +Release 1.1.0a5 (WIP) --------------------------- Breaking Changes: @@ -23,6 +23,7 @@ Bug Fixes: ^^^^^^^^^^ - Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same) - Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``, it was not optimized anymore (thanks @Atlis) +- Fixed saving of ``A2C`` and ``PPO`` policy when using gSDE (thanks @liusida) Deprecations: ^^^^^^^^^^^^^ @@ -653,4 +654,4 @@ And all the contributors: @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray @tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn -@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis +@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index ea6fa97308..c0067bebf3 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -435,7 +435,7 @@ def _get_constructor_parameters(self) -> Dict[str, Any]: log_std_init=self.log_std_init, squash_output=default_none_kwargs["squash_output"], full_std=default_none_kwargs["full_std"], - sde_net_arch=default_none_kwargs["sde_net_arch"], + sde_net_arch=self.sde_net_arch, use_expln=default_none_kwargs["use_expln"], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone ortho_init=self.ortho_init, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index cd64e7162a..c84ce18990 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.1.0a4 +1.1.0a5 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 9b629ef3ac..8f71f5203d 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -341,7 +341,8 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage): @pytest.mark.parametrize("model_class", MODEL_LIST) @pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) -def test_save_load_policy(tmp_path, model_class, policy_str): +@pytest.mark.parametrize("use_sde", [False, True]) +def test_save_load_policy(tmp_path, model_class, policy_str, use_sde): """ Test saving and loading policy only. @@ -349,6 +350,11 @@ def test_save_load_policy(tmp_path, model_class, policy_str): :param policy_str: (str) Name of the policy. """ kwargs = dict(policy_kwargs=dict(net_arch=[16])) + + # gSDE is only applicable for A2C, PPO and SAC + if use_sde and model_class not in [A2C, PPO, SAC]: + pytest.skip() + if policy_str == "MlpPolicy": env = select_env(model_class) else: @@ -360,6 +366,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str): ) env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN) + if use_sde: + kwargs["use_sde"] = True + env = DummyVecEnv([lambda: env]) # create model