3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -14,6 +14,7 @@ New Features:
^^^^^^^^^^^^^
- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed
- Added ``StopTrainingOnMaxEpisodes`` to callback collection (@xicocaio)
- Added ``device`` keyword argument to ``BaseAlgorithm.load()`` (@liorcohen5)

Bug Fixes:
^^^^^^^^^^
@@ -399,4 +400,4 @@ And all the contributors:
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273
@diditforlulz273 @liorcohen5
7 changes: 5 additions & 2 deletions stable_baselines3/common/base_class.py
@@ -316,13 +316,16 @@ def predict(
return self.policy.predict(observation, state, mask, deterministic)

@classmethod
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAlgorithm":
def load(
cls, load_path: str, env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", **kwargs
) -> "BaseAlgorithm":
"""
Load the model from a zip-file

:param load_path: the location of the saved data
:param env: the new environment to run the loaded model on
(can be None if you only need prediction from a trained model) has priority over any saved environment
:param device: (Union[th.device, str]) Device on which the code should run.
:param kwargs: extra arguments to change the model when loading
"""
data, params, tensors = load_from_zip_file(load_path)
@@ -352,7 +355,7 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAl
model = cls(
policy=data["policy_class"],
env=env,
device="auto",
device=device,
_init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
)

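To make the new keyword concrete, here is a minimal stdlib-only sketch of the fallback rule that ``utils.get_device`` applies; the function name ``resolve_device`` and the ``cuda_available`` flag are illustrative, not part of the library:

```python
def resolve_device(device: str, cuda_available: bool) -> str:
    """Sketch of the 'auto' device resolution discussed in this PR.

    'auto' means 'use cuda if a GPU is present'; requesting 'cuda'
    on a machine without a GPU silently falls back to the CPU.
    """
    if device == "auto":
        device = "cuda"
    if device == "cuda" and not cuda_available:
        return "cpu"
    return device
```

Under this rule, ``load(..., device="auto")`` on a CPU-only machine places the model on the CPU, which is the behavior the test changes below have to account for.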
27 changes: 16 additions & 11 deletions tests/test_save_load.py
@@ -70,21 +70,26 @@ def test_save_load(tmp_path, model_class):
# Check
model.save(tmp_path / "test_save.zip")
del model
model = model_class.load(str(tmp_path / "test_save.zip"), env=env)

# check if params are still the same after load
new_params = model.policy.state_dict()
# Check if the model loads as expected for every possible choice of device:
for device in ["auto", "cpu", "cuda"]:
@leor-c (Contributor, Author) commented on Sep 2, 2020:
I noticed that the git code comparison looks quite messy, so I'm elaborating on my changes here to ease the review process for you:
The actual change is the added ``for`` loop that iterates over all possible devices; at each iteration the device parameter is passed to the call to ``load`` (line 76). At the end of each iteration I delete the model (line 92) so it can be loaded cleanly in the next iteration.
Everything else is the same as before, i.e., I've reused the exact same test (inside the new ``for`` loop) to ensure proper loading, and tested all possible values of the new ``device`` argument.

(Member) replied:

It seems that you are actually not testing that the ``device`` parameter was successfully used.
Also, you should skip the ``cuda`` device if no GPU is available.

@leor-c (Contributor, Author) commented on Sep 2, 2020:

  1. You're right. I will work on improving the test.
  2. What should be the expected behavior when a user passes ``device='cuda'`` on a machine with no GPU?
    I noticed that the constructor silently falls back to the CPU in that case, without notifying the user.
    In any case, I think the test should cover all possible inputs while verifying that the outcome matches your expectations. Do you agree?

@leor-c (Contributor, Author) commented on Sep 2, 2020:

In my test I've used the ``utils.get_device()`` function (which is also used inside the constructor) to determine the expected device. This way, if, for example, the behavior of ``get_device`` changes, the test won't break.
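The skip the reviewer asks for can be kept out of the test body by computing the device list up front; a minimal sketch, where ``devices_to_test`` is a hypothetical helper rather than part of the test suite:

```python
def devices_to_test(cuda_available: bool) -> list:
    """Return the load() device choices worth testing on this machine.

    'cuda' is dropped when no GPU is present, as the reviewer suggested.
    """
    devices = ["auto", "cpu"]
    if cuda_available:
        devices.append("cuda")
    return devices
```

In the actual test this would be driven by the real availability check, e.g. ``for device in devices_to_test(th.cuda.is_available()):``.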

model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device)

# Check that all params are the same as before save load procedure now
for key in params:
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
# check if params are still the same after load
new_params = model.policy.state_dict()

# check if model still selects the same actions
new_selected_actions, _ = model.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# Check that all params are the same as before save load procedure now
for key in params:
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."

# check if learn still works
model.learn(total_timesteps=1000, eval_freq=500)
# check if model still selects the same actions
new_selected_actions, _ = model.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)

# check if learn still works
model.learn(total_timesteps=1000, eval_freq=500)

del model

# clear file from os
os.remove(tmp_path / "test_save.zip")
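The reviewer's first point, that the test never checks where the parameters actually landed, could be addressed with a small assertion helper. A sketch follows; the helper name is illustrative, and it assumes only that values expose a ``.device.type`` attribute, as PyTorch tensors do:

```python
def assert_all_on_device(state_dict, expected_type: str) -> None:
    """Fail if any parameter in the state dict is not on the expected device.

    Accepts any mapping from names to objects with a ``.device.type``
    attribute, so ``model.policy.state_dict()`` can be passed directly.
    """
    for name, tensor in state_dict.items():
        actual = tensor.device.type
        assert actual == expected_type, (
            f"{name} is on {actual}, expected {expected_type}"
        )
```

Inside the loop this could be called as ``assert_all_on_device(model.policy.state_dict(), get_device(device).type)``, reusing ``utils.get_device`` to compute the expectation, as the author proposes above.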