
Save/load error due to model and env observation space mismatch #202

@kwabenantim

Description


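Loading a saved model with `PPO.load` fails with a `ValueError`. The observation space stored with the model, `Box(4, 36, 36)`, does not match the environment's observation space, `Box(36, 36, 4)`.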

Reproduction

import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy

class CustomEnv(gym.Env):
    def __init__(self):
        super(CustomEnv, self).__init__()
        self.reward_range = (-1, 1)
        self.action_space = gym.spaces.Discrete(2)
        # Channel-last (height, width, channels) image observations
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(36, 36, 4), dtype=np.uint8)

    def reset(self):
        # Return observations with the dtype declared by observation_space
        return np.ones((36, 36, 4), dtype=np.uint8)

    def step(self, action):
        # (obs, reward, done, info), as in the gym 0.17 API
        return np.ones((36, 36, 4), dtype=np.uint8), 0, True, {}

custom_env = CustomEnv()
model = PPO(CnnPolicy, custom_env)
print(custom_env.observation_space)  # space reported by the env
print(model.observation_space)       # space stored in the model

model.save("custom_env")
model = PPO.load("custom_env", custom_env)  # raises ValueError

Output and Stack Trace

Box(36, 36, 4)
Box(4, 36, 36)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-176-a54409b23a82> in <module>
      1 model.save("custom_env")
----> 2 model = PPO.load("custom_env", custom_env)

/opt/conda/lib/python3.7/site-packages/stable_baselines3/common/base_class.py in load(cls, path, env, device, **kwargs)
    595             cls._wrap_env(env, data["verbose"])
    596             # Check if given env is valid
--> 597             check_for_correct_spaces(env, data["observation_space"], data["action_space"])
    598         else:
    599             # Use stored env, if one exists. If not, continue as is (can be used for predict)

/opt/conda/lib/python3.7/site-packages/stable_baselines3/common/utils.py in check_for_correct_spaces(env, observation_space, action_space)
    204     """
    205     if observation_space != env.observation_space:
--> 206         raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}")
    207     if action_space != env.action_space:
    208         raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")

ValueError: Observation spaces do not match: Box(4, 36, 36) != Box(36, 36, 4)
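Reading the traceback, `load` calls `cls._wrap_env(env, data["verbose"])` on line 595 without assigning the result, then passes the original `env` to `check_for_correct_spaces`. If `_wrap_env` returns a new wrapped env (for image observations, one that transposes channel-last to channel-first, matching the printed `Box(4, 36, 36)`) instead of modifying `env` in place, the check compares the stored channel-first space with the unwrapped channel-last one, which reproduces the error above. The sketch below models only that ordering in plain Python. `Space`, `Env`, `wrap_env` and `load_check` are hypothetical stand-ins reduced to shapes, not SB3 code, and the behaviour attributed to the real `_wrap_env` is an assumption.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Space:
    """Hypothetical stand-in for a Box space, reduced to its shape."""
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class Env:
    """Hypothetical stand-in for an environment exposing an observation space."""
    observation_space: Space


def wrap_env(env: Env) -> Env:
    """Hypothetical wrapper: returns a NEW env whose image space is channel-first."""
    h, w, c = env.observation_space.shape
    return Env(Space((c, h, w)))


def check_for_correct_spaces(env: Env, observation_space: Space) -> None:
    # Same comparison as in the traceback, restricted to observation spaces.
    if observation_space != env.observation_space:
        raise ValueError(
            f"Observation spaces do not match: {observation_space} != {env.observation_space}"
        )


def load_check(env: Env, saved_space: Space, keep_wrapper: bool) -> Env:
    if keep_wrapper:
        env = wrap_env(env)  # compare against the wrapped env
    else:
        wrap_env(env)  # result discarded, as in base_class.py line 595 above
    check_for_correct_spaces(env, saved_space)
    return env


raw_env = Env(Space((36, 36, 4)))  # channel-last, like CustomEnv
saved = Space((4, 36, 36))         # channel-first space stored with the model

print(load_check(raw_env, saved, keep_wrapper=True).observation_space)
try:
    load_check(raw_env, saved, keep_wrapper=False)
except ValueError as err:
    print(err)
```

Under this assumption, keeping the wrapper's return value (`keep_wrapper=True`) lets the spaces match, while discarding it raises the same kind of `ValueError` as the report.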

System Info

  • Library was installed from source
  • Python version 3.7.6
  • PyTorch version 1.5.1
  • Gym version 0.17.2

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)
