Skip to content

Commit

Permalink
Merge pull request #123 from edbeeching/sb3-multi-env
Browse files Browse the repository at this point in the history
Adds multiple env support to the SB3 wrapper
  • Loading branch information
edbeeching authored Jul 13, 2023
2 parents 4802865 + d215697 commit 85b787a
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 89 deletions.
13 changes: 9 additions & 4 deletions docs/ADV_STABLE_BASELINES_3.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,26 @@ pip install godot-rl[sb3]
## Basic Environment Usage
Usage instructions for envs **BallChase**, **FlyBy** and **JumperHard.**

Download the env:
### Download the env:

```bash
gdrl.env_from_hub -r edbeeching/godot_rl_<ENV_NAME>
chmod +x examples/godot_rl_<ENV_NAME>/bin/<ENV_NAME>.x86_64 # linux example
```

Train a model from scratch:
### Train a model from scratch:

```bash
gdrl --env=gdrl --env_path=examples/godot_rl_<ENV_NAME>/bin/<ENV_NAME>.x86_64 --viz
```

While the default options for sb3 work reasonably well, you may be interested in changing the hyperparameters.

We recommend taking the [sb3 example](https://github.com/edbeeching/godot_rl_agents/blob/main/examples/stable_baselines3_example.py) and modifying to match your needs.
We recommend taking the [sb3 example](https://github.com/edbeeching/godot_rl_agents/blob/main/examples/stable_baselines3_example.py) and modifying it to match your needs.

This example exposes more parameters for the user to configure, such as `--speedup` to run the environment faster than real time and `--n_parallel` to launch several instances of the game executable in order to accelerate training (not available for in-editor training).


```python
import argparse

Expand All @@ -66,11 +70,12 @@ parser.add_argument(
)

parser.add_argument("--speedup", default=1, type=int, help="whether to speed up the physics in the env")
# Number of game executables to run side by side (passed to the env as n_parallel).
parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to launch - requires --env_path to be set if > 1.")

args, extras = parser.parse_known_args()


env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=True, speedup=args.speedup, convert_action_space=True)
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=True, n_parallel=args.n_parallel, speedup=args.speedup)

model = PPO("MultiInputPolicy", env, ent_coef=0.0001, verbose=2, n_steps=32, tensorboard_log="logs/log")
model.learn(200000)
Expand Down
6 changes: 3 additions & 3 deletions examples/stable_baselines3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
help="The Godot binary to use, do not include for in editor training",
)

parser.add_argument("--speedup", default=1, type=int, help="whether to speed up the physics in the env")

parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env")
parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to launch - requires --env_path to be set if > 1.")
args, extras = parser.parse_known_args()


env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=True, speedup=args.speedup)
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=True, n_parallel=args.n_parallel, speedup=args.speedup)
env = VecMonitor(env)

model = PPO("MultiInputPolicy", env, ent_coef=0.0001, verbose=2, n_steps=32, tensorboard_log="logs/log")
Expand Down
143 changes: 108 additions & 35 deletions godot_rl/core/godot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,40 @@

import numpy as np
from gym import spaces

from typing import Optional
from godot_rl.core.utils import ActionSpaceProcessor, convert_macos_path


class GodotEnv:
MAJOR_VERSION = "0"
MINOR_VERSION = "3"
DEFAULT_PORT = 11008
DEFAULT_TIMEOUT = 60
MAJOR_VERSION = "0" # Versioning for the environment
MINOR_VERSION = "4"
DEFAULT_PORT = 11008 # Default port for communication with Godot Game
DEFAULT_TIMEOUT = 60 # Default socket timeout TODO

def __init__(
self,
env_path=None,
port=DEFAULT_PORT,
show_window=False,
seed=0,
framerate=None,
action_repeat=None,
speedup=None,
convert_action_space=False,
env_path: str=None,
port: int=DEFAULT_PORT,
show_window: bool=False,
seed:int=0,
framerate:Optional[int]=None,
action_repeat:Optional[int]=None,
speedup:Optional[int]=None,
convert_action_space:bool=False,
):
"""
Initialize a new instance of GodotEnv
Args:
env_path (str): path to the godot binary environment.
port (int): Port number for communication.
show_window (bool): flag to display Godot game window.
seed (int): seed to initialize the environment.
framerate (int): the framerate to run the Godot game at.
action_repeat (int): the number of frames to repeat an action for.
speedup (int): the factor to speedup game time by.
convert_action_space (bool): flag to convert action space.
"""

self.proc = None
if env_path is not None and env_path != "debug":
Expand All @@ -51,7 +64,16 @@ def __init__(

atexit.register(self._close)

def _set_platform_suffix(self, env_path):
def _set_platform_suffix(self, env_path: str) -> str:
"""
Set the platform suffix for the given environment path based on the platform.
Args:
env_path (str): The environment path.
Returns:
str: The environment path with the platform suffix.
"""
suffixes = {
"linux": ".x86_64",
"linux2": ".x86_64",
Expand All @@ -62,6 +84,15 @@ def _set_platform_suffix(self, env_path):
return str(pathlib.Path(env_path).with_suffix(suffix))

def check_platform(self, filename: str):
"""
Check the platform and assert the file type
Args:
filename (str): Path of the file to check.
Raises:
AssertionError: If the file type does not match with the platform or file does not exist.
"""
if platform == "linux" or platform == "linux2":
# Linux
assert (
Expand All @@ -83,7 +114,16 @@ def check_platform(self, filename: str):
assert os.path.exists(filename)

def from_numpy(self, action, order_ij=False):
# handles dict to tuple actions
"""
Handles dict to tuple actions
Args:
action: The action to be converted.
order_ij (bool): Order flag.
Returns:
list: The converted action.
"""
result = []

for i in range(self.num_envs):
Expand All @@ -104,14 +144,43 @@ def from_numpy(self, action, order_ij=False):
return result

def step(self, action, order_ij=False):
    """
    Run a single environment step: send the action, then wait for the reply.

    Args:
        action: Action forwarded unchanged to ``step_send``.
        order_ij (bool): Order flag forwarded to ``step_send``.

    Returns:
        The value returned by ``step_recv``.
    """
    self.step_send(action, order_ij=order_ij)
    outcome = self.step_recv()
    return outcome


def step_send(self, action, order_ij=False):
    """
    Convert the action and transmit it as an "action" message.

    The action is first passed through
    ``self.action_space_processor.to_original_dist`` and then through
    ``self.from_numpy`` before being handed to ``self._send_as_json``.

    Args:
        action: Action to be sent.
        order_ij (bool): Order flag forwarded to ``from_numpy``.
    """
    original_action = self.action_space_processor.to_original_dist(action)
    converted = self.from_numpy(original_action, order_ij=order_ij)
    self._send_as_json({"type": "action", "action": converted})

def step_recv(self):
"""
Receive the step response from the Godot environment.
Returns:
tuple: Tuple containing observation, reward, done flag, termination flag, and info.
"""
response = self._get_json_dict()

response["obs"] = self._process_obs(response["obs"])

return (
Expand All @@ -121,17 +190,33 @@ def step(self, action, order_ij=False):
np.array(response["done"]).tolist(), # TODO update API to term, trunc
[{}] * len(response["done"]),
)



def _process_obs(self, response_obs: dict):
"""
Process observation data.
Args:
response_obs (dict): The response observation to be processed.
Returns:
dict: The processed observation data.
"""
for k in response_obs[0].keys():
if "2d" in k:
for sub in response_obs:
sub[k] = self.decode_2d_obs_from_string(sub[k], self.observation_space[k].shape)
sub[k] = self._decode_2d_obs_from_string(sub[k], self.observation_space[k].shape)

return response_obs

def reset(self, seed=None):
"""
Reset the Godot environment.
Returns:
dict: The initial observation data.
"""
message = {
"type": "reset",
}
Expand Down Expand Up @@ -165,6 +250,10 @@ def close(self):
except Exception as e:
print("exception unregistering close method", e)

@property
def action_space(self):
    """Return the ``action_space`` attribute of ``self.action_space_processor``."""
    processor = self.action_space_processor
    return processor.action_space

def _close(self):
print("exit was not clean, using atexit to close env")
self.close()
Expand Down Expand Up @@ -254,31 +343,17 @@ def _get_env_info(self):
)
elif v["space"] == "discrete":
observation_spaces[k] = spaces.Discrete(v["size"])
# elif v["space"] == "repeated": TODO: Add repeated spaces back when we have support and a good example
# assert "max_length" in v
# if v["subspace"] == "box":
# subspace = observation_spaces[k] = spaces.Box(
# low=-1.0,
# high=1.0,
# shape=v["size"],
# dtype=np.float32,
# )
# elif v["subspace"] == "discrete":
# subspace = spaces.Discrete(v["size"])
# observation_spaces[k] = Repeated(subspace, v["max_length"])
else:
print(f"observation space {v['space']} is not supported")
assert 0, f"observation space {v['space']} is not supported"
self.observation_space = spaces.Dict(observation_spaces)

self.num_envs = json_dict["n_agents"]

@property
def action_space(self):
return self.action_space_processor.action_space


@staticmethod
def decode_2d_obs_from_string(
def _decode_2d_obs_from_string(
hex_string,
shape,
):
Expand All @@ -300,15 +375,13 @@ def _get_obs(self):
return self._get_data()

def _clear_socket(self):

self.connection.setblocking(False)
try:
while True:
data = self.connection.recv(4)
if not data:
break
except BlockingIOError as e:
# print("BlockingIOError expection on clear")
pass
self.connection.setblocking(True)

Expand Down
6 changes: 3 additions & 3 deletions godot_rl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
try:
from godot_rl.wrappers.ray_wrapper import rllib_training
except ImportError as e:
print("Error: ", e)
print("Warning: ", e)
def rllib_training(args, extras):
print("Import error when trying to use rllib. If you have not installed the package, try: pip install godot-rl[rllib]")
print("Otherwise try fixing the error above.")
Expand All @@ -34,7 +34,7 @@ def rllib_training(args, extras):
try:
from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training
except ImportError as e:
print("Error: ", e)
print("Warning: ", e)
def stable_baselines_training(args, extras):
print(
"Import error when trying to use sb3. If you have not installed the package, try: pip install godot-rl[sb3]"
Expand All @@ -44,7 +44,7 @@ def stable_baselines_training(args, extras):
try:
from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training, sample_factory_enjoy
except ImportError as e:
print("Error: ", e)
print("Warning: ", e)
def sample_factory_training(args, extras):
print(
"Import error when trying to use sample-factory If you have not installed the package, try: pip install godot-rl[sf]"
Expand Down
Loading

0 comments on commit 85b787a

Please sign in to comment.