-
Notifications
You must be signed in to change notification settings - Fork 221
Add TRPO #40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add TRPO #40
Changes from 7 commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
f779a9f
Feat: adding TRPO algorithm (WIP)
cyprienc 98bc5b2
Feat: adding TRPO algorithm (WIP)
cyprienc 97ece67
Feat: adding TRPO algorithm (WIP)
cyprienc 799b140
Feat: adding TRPO algorithm (WIP)
cyprienc dc73462
Feat: adding TRPO algorithm (WIP)
cyprienc 9b8a222
feat: TRPO - addressing PR comments
cyprienc 869dce9
refactor: TRPO - policier
cyprienc 347dcc0
feat: using updated ActorCriticPolicy from SB3
cyprienc 35d7256
Bump version for `get_distribution` support
araffin 9cfcb54
Add basic test
araffin 974174a
Reformat
araffin b6bd449
[ci skip] Fix changelog
araffin c88951c
fix: setting train mode for trpo
cyprienc 1f7e99d
fix: batch_size type hint in trpo.py
cyprienc 6540371
style: renaming variables + docstring in trpo.py
cyprienc 3a26c05
Merge branch 'master' into master
araffin f003e88
Merge branch 'master' into master
araffin a33409e
Merge branch 'master' into master
araffin 8ecf40e
Rename + cleanup
araffin 45f4ea6
Move grad computation to separate method
araffin cc4b5ab
Remove grad norm clipping
araffin fc7a6c7
Remove n epochs and add sub-sampling
araffin 66723ff
Update defaults
araffin 63a263f
Merge branch 'master' into master
araffin bf583de
Merge branch 'master' into cyprienc/master
araffin e983348
Add Doc
araffin 439d79b
Add more test and fixes for CNN
araffin d9483dc
Update doc + add benchmark
araffin fff84e4
Add tests + update doc
araffin 95dddf4
Fix doc
araffin 661fe15
Improve names for conjugate gradient
araffin a24e7c0
Update comments
araffin 342fe53
Update changelog
araffin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| """Policies: abstract base class and concrete implementations.""" | ||
|
|
||
| from stable_baselines3.common.distributions import Distribution | ||
| from stable_baselines3.common.policies import ActorCriticPolicy as _ActorCriticPolicy | ||
|
|
||
|
|
||
| class ActorCriticPolicy(_ActorCriticPolicy): | ||
| """ | ||
| Policy class for actor-critic algorithms (has both policy and value prediction). | ||
| Used by A2C, PPO and the likes. | ||
| """ | ||
|
|
||
| def get_distribution(self) -> Distribution: | ||
| """ | ||
| Get the current action distribution | ||
| :return: Action distribution | ||
| """ | ||
| return self.action_dist | ||
|
|
||
|
|
||
| # This is just to propagate get_distribution | ||
| class ActorCriticCnnPolicy(ActorCriticPolicy): | ||
| pass | ||
|
|
||
|
|
||
| # This is just to propagate get_distribution | ||
| class MultiInputActorCriticPolicy(ActorCriticPolicy): | ||
| pass | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy | ||
| from sb3_contrib.trpo.trpo import TRPO |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # This file is here just to define MlpPolicy/CnnPolicy | ||
| # that work for TRPO | ||
| from sb3_contrib.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, MultiInputActorCriticPolicy | ||
| from stable_baselines3.common.policies import register_policy | ||
|
|
||
|
|
||
| MlpPolicy = ActorCriticPolicy | ||
| CnnPolicy = ActorCriticCnnPolicy | ||
| MultiInputPolicy = MultiInputActorCriticPolicy | ||
|
|
||
| register_policy("MlpPolicy", ActorCriticPolicy) | ||
| register_policy("CnnPolicy", ActorCriticCnnPolicy) | ||
| register_policy("MultiInputPolicy", MultiInputPolicy) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.