Closed
Labels: bug (Something isn't working)
Description
### 🐛 Bug
Usually we use `model.save(path)` to save the model as a zip file.
But the policy also has its own `save` method, and calling it with gSDE enabled (`use_sde=True`) raises an error.
### To Reproduce
```python
import pybullet_envs  # registers the PyBullet environments with Gym
from stable_baselines3 import PPO

model = PPO('MlpPolicy', 'HopperBulletEnv-v0', use_sde=True)
model.learn(100)
model.policy.save("/tmp/sde_policy")
```

```
Traceback (most recent call last):
  File "try_sde_with_policy_save.py", line 6, in <module>
    model.policy.save("/tmp/sde_policy")
  File "/home/liusida/code/code_trysb3/stable-baselines3/stable_baselines3/common/policies.py", line 152, in save
    th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
  File "/home/liusida/code/code_trysb3/stable-baselines3/stable_baselines3/common/policies.py", line 438, in _get_constructor_parameters
    sde_net_arch=default_none_kwargs["sde_net_arch"],
KeyError: 'sde_net_arch'
```

### Expected behavior
The policy (the PyTorch module) should be saved as a file.
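The traceback suggests that `_get_constructor_parameters` indexes a dict of distribution kwargs with `["sde_net_arch"]`, and that this dict only falls back to a `defaultdict` (returning `None` for any key) when gSDE is off. The plain-Python sketch below reproduces that pattern under the assumption that, with gSDE enabled, the stored kwargs are a regular dict with no `"sde_net_arch"` key. The helpers `build_dist_kwargs`, `constructor_params_buggy` and `constructor_params_fixed`, and the key values, are hypothetical simplifications, not the actual Stable-Baselines3 code; the `.get()`-based version is one possible fix, not necessarily the one upstream adopted.

```python
import collections
from typing import Any, Dict, Optional


def build_dist_kwargs(use_sde: bool, sde_net_arch: Optional[list] = None) -> Optional[Dict[str, Any]]:
    """Hypothetical stand-in for how a policy might store its distribution kwargs."""
    if not use_sde:
        return None
    # Assumption: with gSDE enabled a plain dict is stored, and it has
    # no "sde_net_arch" key (only a derived "learn_features" flag).
    return {
        "full_std": True,
        "squash_output": False,
        "use_expln": False,
        "learn_features": sde_net_arch is not None,
    }


def constructor_params_buggy(dist_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    # The defaultdict fallback is only used when dist_kwargs is None (gSDE off),
    # so with gSDE on, indexing a missing key raises KeyError.
    default_none_kwargs = dist_kwargs or collections.defaultdict(lambda: None)
    return {
        "squash_output": default_none_kwargs["squash_output"],
        "full_std": default_none_kwargs["full_std"],
        "sde_net_arch": default_none_kwargs["sde_net_arch"],
        "use_expln": default_none_kwargs["use_expln"],
    }


def constructor_params_fixed(dist_kwargs: Optional[Dict[str, Any]],
                             sde_net_arch: Optional[list] = None) -> Dict[str, Any]:
    # .get() turns missing keys into None, and sde_net_arch is taken from the
    # value the policy was constructed with instead of from dist_kwargs.
    kwargs = dist_kwargs or {}
    return {
        "squash_output": kwargs.get("squash_output"),
        "full_std": kwargs.get("full_std"),
        "sde_net_arch": sde_net_arch,
        "use_expln": kwargs.get("use_expln"),
    }


if __name__ == "__main__":
    try:
        constructor_params_buggy(build_dist_kwargs(use_sde=True))
    except KeyError as err:
        print("KeyError:", err)  # prints: KeyError: 'sde_net_arch'
    # Without gSDE the defaultdict fallback hides the missing key.
    print(constructor_params_buggy(build_dist_kwargs(use_sde=False))["sde_net_arch"])  # prints: None
    print(constructor_params_fixed(build_dist_kwargs(use_sde=True)))
```

This mirrors why the bug only shows up with `use_sde=True`: the missing key is masked by the `defaultdict` fallback in the default configuration.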
### System Info
Describe the characteristics of your environment:
- Tested with the pip-installed release and with commit 5d47296
- GPU models and configuration: No GPU
- Python version: 3.8
- PyTorch version: 1.8.1
- Gym version: 0.18.0
### Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)