-
Notifications
You must be signed in to change notification settings - Fork 2k
Closed
Labels
bug (Something isn't working)
Description
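Loading a saved PPO model that uses `CnnPolicy` fails with a `ValueError`. The observation space stored in the model (`Box(4, 36, 36)`, channel-first) does not match the environment's original channel-last space (`Box(36, 36, 4)`), even though the same environment was used to create the model.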
Reproduction
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy
class CustomEnv(gym.Env):
    """Minimal one-step environment with a (36, 36, 4) uint8 image-like observation.

    Every observation is an array of ones, and each episode terminates after
    a single step with zero reward.
    """

    def __init__(self):
        super().__init__()
        self.reward_range = (-1, 1)
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(36, 36, 4), dtype=np.uint8
        )

    def _observation(self):
        """Build an observation whose shape AND dtype match ``observation_space``."""
        # np.ones defaults to float64. Use the declared uint8 dtype so that
        # returned observations actually belong to the advertised space.
        return np.ones(
            self.observation_space.shape, dtype=self.observation_space.dtype
        )

    def reset(self):
        """Start a new episode and return the initial observation."""
        return self._observation()

    def step(self, action):
        """Ignore ``action``; return (observation, reward=0, done=True, info={})."""
        return self._observation(), 0, True, {}
custom_env = CustomEnv()
model = PPO(CnnPolicy, custom_env)

# Show the env's observation space next to the one the model reports,
# one per line.
print(custom_env.observation_space, model.observation_space, sep="\n")

# Save, then reload while passing the same env. This load is the call that
# raises the ValueError shown in the traceback.
model.save("custom_env")
model = PPO.load("custom_env", env=custom_env)
Output and Stack Trace
Box(36, 36, 4)
Box(4, 36, 36)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-176-a54409b23a82> in <module>
1 model.save("custom_env")
----> 2 model = PPO.load("custom_env", custom_env)
/opt/conda/lib/python3.7/site-packages/stable_baselines3/common/base_class.py in load(cls, path, env, device, **kwargs)
595 cls._wrap_env(env, data["verbose"])
596 # Check if given env is valid
--> 597 check_for_correct_spaces(env, data["observation_space"], data["action_space"])
598 else:
599 # Use stored env, if one exists. If not, continue as is (can be used for predict)
/opt/conda/lib/python3.7/site-packages/stable_baselines3/common/utils.py in check_for_correct_spaces(env, observation_space, action_space)
204 """
205 if observation_space != env.observation_space:
--> 206 raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}")
207 if action_space != env.action_space:
208 raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")
ValueError: Observation spaces do not match: Box(4, 36, 36) != Box(36, 36, 4)
System Info
- Library was installed from source
- Python version 3.7.6
- PyTorch version 1.5.1
- Gym version 0.17.2
yijionglin and mdiephuis
Metadata
Assignees
Labels
bug (Something isn't working)