Adds the ability to use more algorithms from sbx and sb3-contrib (with e.g. MlpPolicy) #163
This is great, but I think the code could be more compact. The filename and class name are also enormous; it would be great to shorten them.
…_single_obs_wrapper.py
Thanks for the suggestions, I implemented the solutions and updated the return types in the file.
Thanks for making these changes, LGTM
Thanks for the review. I'll just add a small change to allow changing the dictionary key from "obs" to any value, which I found useful while testing CNN usage with the camera example, and then I can merge it.
This change adds an additional SB3 wrapper class that uses a single observation space ("obs"), like our CleanRL implementation, without modifying the original wrapper.
With this class, algorithms like ARS from SB3 Contrib and PPO from SBX (which currently doesn't have MultiInputPolicy) can be tested easily on any environment that doesn't require multiple observation spaces, which is often the case for the examples.
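The idea of exposing a single observation entry can be sketched in plain Python as below. This is only an illustration under stated assumptions, not the wrapper added in this PR: `ToyDictEnv` and `SingleObsWrapper` are hypothetical names, the toy env stands in for an environment that returns dictionary observations, and the `obs_key` parameter mirrors the configurable key mentioned in the review thread.

```python
from typing import Any, Dict, List, Tuple


class ToyDictEnv:
    """Hypothetical stand-in for an env that returns dict observations."""

    def __init__(self) -> None:
        self._step_count = 0

    def reset(self) -> Dict[str, List[float]]:
        self._step_count = 0
        return {"obs": [0.0, 0.0], "camera": [1.0]}

    def step(
        self, action: float
    ) -> Tuple[Dict[str, List[float]], float, bool, Dict[str, Any]]:
        self._step_count += 1
        obs = {"obs": [float(self._step_count), action], "camera": [1.0]}
        done = self._step_count >= 3  # end the toy episode after 3 steps
        return obs, 1.0, done, {}


class SingleObsWrapper:
    """Hypothetical wrapper: expose only one entry of the dict observation."""

    def __init__(self, env: ToyDictEnv, obs_key: str = "obs") -> None:
        self.env = env
        self.obs_key = obs_key  # which dictionary entry to pass through

    def reset(self) -> List[float]:
        return self.env.reset()[self.obs_key]

    def step(
        self, action: float
    ) -> Tuple[List[float], float, bool, Dict[str, Any]]:
        obs, reward, done, info = self.env.step(action)
        # Drop every key except the selected one, so a policy that expects
        # a single flat observation (e.g. MlpPolicy) can consume it.
        return obs[self.obs_key], reward, done, info


if __name__ == "__main__":
    env = SingleObsWrapper(ToyDictEnv())
    print(env.reset())
    print(env.step(0.5))
```

A real wrapper would also need to expose a matching single observation space to the algorithm instead of a dictionary space; that part is omitted here.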
Usage:
In stable_baselines_3_example.py:
First import the SingleObsSpace variant of the env wrapper:
+ from godot_rl.wrappers.sbg_single_obs_wrapper import SBGSingleObsEnv
Then (after the installation of needed packages with pip), import any algorithms to be used:
The env just needs its class name replaced:
And then you can use e.g. the SBX PPO, SB3 Contrib ARS or any other algorithm that may not support the MultiInputPolicy:
Here's a brief attempt at starting training with ARS (the env is slightly modified for some experiments and doesn't have the correct obs, but this was only meant to check that training starts, not to test learning performance):
ars_training_test.mp4