-
Notifications
You must be signed in to change notification settings - Fork 2k
Dictionary Observations #243
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
96 commits
Select commit
Hold shift + click to select a range
c8d0914
First commit
J-Travnik b7f3386
Fixing missing refs from a quick merge from master
J-Travnik 21fecd3
Reformat
araffin b2a1c14
Adding DictBuffers
J-Travnik 1cda60b
Adding DictBuffers
J-Travnik 8a04e61
Reformat
araffin 86b3c14
Minor reformat
araffin f60d439
added slow dict test. Added SACMultiInputPolicy for future. Added pri…
J-Travnik 3d2e041
Merge branch 'feat/dict_observations' of https://github.com/J-Travnik…
J-Travnik da1de6e
Ran black on buffers
J-Travnik 51249da
Ran isort
J-Travnik 761a67f
Adding StackedObservations classes used within VecStackEnvs wrappers.…
J-Travnik 3cb69f5
Running isort :facepalm
J-Travnik 82fe425
Fixed typing issues
araffin 201799d
Adding docstrings and typing. Using util for moving data to device.
J-Travnik 683bbf2
Fixed trailing commas
J-Travnik 887e007
Merging pull of previous format
J-Travnik 15ceb35
Fix types
araffin f9cab8a
Minor edits
araffin 5b178f4
Avoid duplicating code
araffin d692027
Fix calls to parents
araffin b5249ec
Merge branch 'master' into feat/dict_observations
araffin 8b22f96
Adding assert to buffers. Updating changelong
J-Travnik 70dfa83
Running format on buffers
J-Travnik b2b7d6f
Merge branch 'master' into feat/dict_observations
araffin a006b5a
Adding multi-input policies to dqn,td3,a2c. Fixing warnings. Fixed bu…
J-Travnik a94e6df
Merge branch 'master' into feat/dict_observations
araffin 12361ae
Fixing warnings, splitting is_vectorized_observation into multiple fu…
J-Travnik 9c6390b
Merge branch 'feat/dict_observations' of https://github.com/J-Travnik…
J-Travnik 51cb4e4
Merge branch 'master' into feat/dict_observations
araffin ce0f1a4
Created envs folder in common. Updated imports. Moved stacked_obs to …
J-Travnik 9eee82a
Moved envs to envs directory. Moved stacked obs to vec_envs. Started …
J-Travnik c6a8705
Merge branch 'master' into feat/dict_observations
araffin c3d2138
Fixes
araffin c893faa
Merged with master. Added miniscule delay to prevent zero divide on d…
J-Travnik 935cef9
Running code style
J-Travnik a07497b
Update docstrings on torch_layers
Miffyli 96d1e64
Decapitalize non-constant variables
Miffyli 245f4ab
Using NatureCNN architecture in combined extractor. Increasing img si…
J-Travnik 715fec8
merged with latest
J-Travnik 4dc1625
Update doc
araffin 57c1926
Update doc
araffin 0fa3650
Fix format
araffin 20b217a
Merge branch 'master' into feat/dict_observations
araffin f6ab0bc
Merge branch 'master' into feat/dict_observations
araffin 6206b36
Removing NineRoom env. Using nested preprocess. Removing mutable defa…
J-Travnik f064972
running code style
J-Travnik 90d2577
Passing channel check through to stacked dict observations.
J-Travnik 8f37cb2
Running black
J-Travnik 2984756
Adding channel control to SimpleMultiObsEnv. Passing check_channels t…
J-Travnik 324ef43
Remove optimize memory for dict buffers
araffin 2fdcfc6
Update doc
araffin 2bab0a3
Move identity env
araffin b1ec40d
Minor edits + bump version
araffin 12d42e9
Update doc
araffin 5f45044
Fix doc build
araffin 510821b
Bug fixes + add support for more type of dict env
araffin 0b09976
Merge branch 'master' into feat/dict_observations
araffin 8d9183f
Fixes + add multi env test
araffin 04170df
Merge branch 'master' into feat/dict_observations
araffin b9c4f05
Merge branch 'master' into feat/dict_observations
araffin 3bb747a
Add support for vectranspose
Miffyli cda8c21
Fix stacked obs for dict and add tests
Miffyli f770217
Add check for nested spaces. Fix dict-subprocvecenv test
Miffyli 5cbde19
Fix (single) pytype error
Miffyli 4464744
Simplify CombinedExtractor
Miffyli 1f5553a
Fix tests
araffin dda6990
Merge branch 'master' into feat/dict_observations
araffin 6716567
Fix check
araffin e756793
Merge branch 'master' into feat/dict_observations
araffin 32b899f
Fix for net_arch with dict and vector obs
araffin 6652df3
Merge branch 'master' into feat/dict_observations
araffin ec3356e
Fixes
araffin a369bb1
Merge branch 'master' into feat/dict_observations
araffin 4f787fa
Add consistency test
araffin e945ec1
Update env checker
araffin 12c8be0
Merge branch 'master' into feat/dict_observations
araffin a4851b1
Merge branch 'master' into feat/dict_observations
araffin 4138f96
Add some docs on dict obs
Miffyli 4f12135
Merge branch 'master' into feat/dict_observations
araffin 613a141
Merge branch 'master' into feat/dict_observations
araffin 0bcfa11
Update default CNN feature vector size
Miffyli 10f4f6b
Merge branch 'master' into feat/dict_observations
araffin 652a6d0
Refactor HER (#351)
araffin bcd97cd
Merge remote-tracking branch 'github/tmp/dict-obs' into feat/dict_obs…
araffin 89607cf
Update doc and minor fixes
araffin 495cf5d
Update doc
araffin 0ea5c61
Added note about MultiInputPolicy in error of NatureCNN
J-Travnik f8351ab
Merge branch 'master' into feat/dict_observations
araffin 94cb760
Address comments
Miffyli 5d56c34
Naming clarifications
Miffyli ab75dcd
merge master
Miffyli c30916e
Actually saving the file would be nice
Miffyli 0acea97
Fix edge case when doing online sampling with HER
araffin d6a59f9
Cleanup
araffin ce848fb
Add sanity check
araffin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| import argparse | ||
J-Travnik marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| import gym | ||
| import numpy as np | ||
|
|
||
| from stable_baselines3 import PPO, SAC | ||
| from stable_baselines3.common.policies import MultiInputActorCriticPolicy | ||
| from stable_baselines3.common.vec_env import ( | ||
| DummyVecEnv, | ||
| VecFrameStack, | ||
| VecTransposeImage, | ||
| ) | ||
|
|
||
| from stable_baselines3.common.multi_input_envs import ( | ||
| SimpleMultiObsEnv, | ||
| NineRoomMultiObsEnv, | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="Runs the multi_input_tests script") | ||
| parser.add_argument( | ||
| "--timesteps", | ||
| type=int, | ||
| default=30000, | ||
| help="Number of timesteps to train for (default: 20000)", | ||
| ) | ||
| parser.add_argument( | ||
| "--num_envs", | ||
| type=int, | ||
| default=10, | ||
| help="Number of environments to use (default: 10)", | ||
| ) | ||
| parser.add_argument( | ||
| "--frame_stacks", | ||
| type=int, | ||
| default=1, | ||
| help="Number of stacked frames to use (default: 4)", | ||
| ) | ||
| parser.add_argument( | ||
| "--room9", | ||
| action="store_true", | ||
| help="If true, uses more complex 9 room environment", | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| ENV_CLS = NineRoomMultiObsEnv if args.room9 else SimpleMultiObsEnv | ||
|
|
||
| make_env = lambda: ENV_CLS(random_start=True) | ||
|
|
||
| env = DummyVecEnv([make_env for i in range(args.num_envs)]) | ||
| if args.frame_stacks > 1: | ||
| env = VecFrameStack(env, n_stack=args.frame_stacks) | ||
|
|
||
| model = PPO(MultiInputActorCriticPolicy, env) | ||
|
|
||
| model.learn(args.timesteps) | ||
| env.close() | ||
| print("Done training, starting testing") | ||
|
|
||
| make_env = lambda: ENV_CLS(random_start=False) | ||
| test_env = DummyVecEnv([make_env]) | ||
| if args.frame_stacks > 1: | ||
| test_env = VecFrameStack(test_env, n_stack=args.frame_stacks) | ||
|
|
||
| obs = test_env.reset() | ||
| num_episodes = 1 | ||
| trajectories = [[]] | ||
| i_step, i_episode = 0, 0 | ||
| while i_episode < num_episodes: | ||
| action, _states = model.predict(obs, deterministic=False) | ||
| obs, reward, done, info = test_env.step(action) | ||
| test_env.render() | ||
| trajectories[-1].append((test_env.get_attr("state")[0], action[0])) | ||
|
|
||
| i_step += 1 | ||
|
|
||
| if done[0]: | ||
| if info[0]["got_to_end"]: | ||
| print(f"Episode {i_episode} : Got to end in {i_step} steps") | ||
| else: | ||
| print(f"Episode {i_episode} : Did not get to end") | ||
| obs = test_env.reset() | ||
| i_step = 0 | ||
| trajectories.append([]) | ||
| i_episode += 1 | ||
|
|
||
| test_env.close() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.