Skip to content

Commit

Permalink
Merge pull request #134 from edbeeching/updates-sb3-version
Browse files Browse the repository at this point in the history
Updates support of sb3 to latest version
  • Loading branch information
edbeeching authored Jul 23, 2023
2 parents 8721bd5 + e6251f0 commit f90378a
Show file tree
Hide file tree
Showing 13 changed files with 119 additions and 43 deletions.
30 changes: 27 additions & 3 deletions .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,35 @@ on:
branches: [ main ]

jobs:
tests:
tests_ubuntu:
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
os: ['ubuntu-latest', 'windows-latest']
python-version: [3.8, 3.9, 3.10.10]
os: ['ubuntu-latest']
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install .[test,sf]
- name: Download examples
run: |
make download_examples
- name: Test with pytest
run: |
make test
tests_windows:
strategy:
matrix:
python-version: [3.8, 3.9, 3.10.10]
os: ['windows-latest']
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
Expand Down
26 changes: 26 additions & 0 deletions examples/sample_factory_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import argparse
from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training, sample_factory_enjoy


def get_args(argv=None):
    """Parse the command-line options understood by this example script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            this function always did before the parameter was added.
            Passing an explicit list lets tests and embedding code avoid
            depending on the interpreter's own command line.

    Returns:
        The ``(args, extras)`` tuple produced by
        ``ArgumentParser.parse_known_args``: ``args`` is the namespace
        holding the options declared below, ``extras`` is the list of
        argument strings this parser did not recognise.
    """
    # Abbreviations are disabled so that a prefix of one of our flags that
    # appears among the unrecognised options is never consumed here.
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--env_path", default=None, type=str, help="Godot binary to use")
    parser.add_argument("--eval", default=False, action="store_true", help="whether to eval the model")
    parser.add_argument("--speedup", default=1, type=int, help="whether to speed up the physics in the env")
    parser.add_argument("--export", default=False, action="store_true", help="wheter to export the model")
    parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process")

    # parse_known_args (rather than parse_args) keeps unknown options in
    # the returned extras list instead of exiting with an error.
    return parser.parse_known_args(argv)



def main():
    """Script entry point: parse the CLI and run evaluation or training.

    ``--eval`` selects ``sample_factory_enjoy``; otherwise
    ``sample_factory_training`` is run. Both receive the parsed namespace
    and the list of arguments this script's parser did not recognise.
    """
    parsed, passthrough = get_args()
    # Only the selected entry point is looked up and called.
    entry_point = sample_factory_enjoy if parsed.eval else sample_factory_training
    entry_point(parsed, passthrough)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion examples/stable_baselines3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
parser.add_argument(
"--experiment_name",
default="Experiment",
default="experiment",
type=str,
help="The name of the experiment, which will be displayed in tensorboard",
)
Expand Down
2 changes: 1 addition & 1 deletion godot_rl/core/godot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sys import platform

import numpy as np
from gym import spaces
from gymnasium import spaces
from typing import Optional
from godot_rl.core.utils import ActionSpaceProcessor, convert_macos_path

Expand Down
17 changes: 15 additions & 2 deletions godot_rl/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import gym
import numpy as np
import importlib
import re

import gymnasium as gym
import numpy as np



def lod_to_dol(lod):
return {k: [dic[k] for dic in lod] for k in lod[0]}
Expand Down Expand Up @@ -103,3 +106,13 @@ def to_original_dist(self, action):
raise NotImplementedError

return original_action

def can_import(module_name):
    """Return True when ``module_name`` can be imported, False otherwise.

    This is the logical negation of ``cant_import`` and shares its
    behaviour: the module is actually imported as part of the check.
    """
    unavailable = cant_import(module_name)
    return not unavailable

def cant_import(module_name):
    """Return True when importing ``module_name`` raises ImportError.

    The check performs a real import via ``importlib.import_module``, so
    a successful probe leaves the module loaded. Only ``ImportError`` (and
    its subclasses, e.g. ``ModuleNotFoundError``) is treated as "cannot
    import"; any other exception raised while importing propagates.
    """
    try:
        importlib.import_module(module_name)
    except ImportError:
        return True
    else:
        return False
14 changes: 11 additions & 3 deletions godot_rl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,18 @@ def get_args():
parser.add_argument("--export", default=False, action="store_true", help="wheter to export the model")
parser.add_argument("--num_gpus", default=None, type=int, help="Number of GPUs to use [only for rllib]")
parser.add_argument("--experiment_dir", default=None, type=str, help="The name of the the experiment directory, in which the tensorboard logs are getting stored")
parser.add_argument("--experiment_name", default=None, type=str, help="The name of the the experiment, which will be displayed in tensborboard")
parser.add_argument("--experiment_name", default="experiment", type=str, help="The name of the the experiment, which will be displayed in tensborboard")
parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process")

return parser.parse_known_args()

args, extras = parser.parse_known_args()
if args.experiment_dir is None:
args.experiment_dir = f"logs/{args.trainer}"

if args.trainer == "sf" and args.env_path is None:
print("WARNING: the sample-factory intergration is not designed to run in interactive mode, please export you game to use this trainer")


return args, extras


def main():
Expand Down
22 changes: 9 additions & 13 deletions godot_rl/wrappers/sample_factory_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import argparse
import sys
from functools import partial
import random
import numpy as np
from gym import spaces
from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
from sample_factory.envs.env_utils import register_env
from sample_factory.train import run_rl
from sample_factory.enjoy import enjoy

from godot_rl.core.godot_env import GodotEnv
from godot_rl.core.utils import lod_to_dol
from gymnasium import Env


class SampleFactoryEnvWrapperBatched(GodotEnv):
class SampleFactoryEnvWrapperBatched(GodotEnv, Env):
@property
def unwrapped(self):
return self
Expand All @@ -22,7 +20,7 @@ def unwrapped(self):
def num_agents(self):
return self.num_envs

def reset(self, seed=None):
def reset(self, seed=None, options=None):
obs, info = super().reset(seed=seed)
obs = lod_to_dol(obs)
return {k: np.array(v) for k, v in obs.items()}, info
Expand All @@ -45,22 +43,20 @@ def render():
return


class SampleFactoryEnvWrapperNonBatched(GodotEnv):
class SampleFactoryEnvWrapperNonBatched(GodotEnv, Env):
@property
def unwrapped(self):
return self

@property
def num_agents(self):
return self.num_envs

def reset(self, seed=None):
def reset(self, seed=None, options=None):
obs, info = super().reset(seed=seed)
return self.to_numpy(obs), info

def step(self, action):
obs, reward, term, trunc, info = super().step(action, order_ij=True)

return self.to_numpy(obs), np.array(reward), np.array(term), np.array(trunc) * 0, info

@staticmethod
Expand All @@ -78,7 +74,7 @@ def render():

def make_godot_env_func(env_path, full_env_name, cfg=None, env_config=None, render_mode=None, speedup=1, viz=False):
seed = 0
port = 21008 + cfg.base_port
port = cfg.base_port
print("BASE PORT ", cfg.base_port)
show_window = False
if env_config:
Expand Down Expand Up @@ -156,7 +152,7 @@ def add_gdrl_env_args(_env, p: argparse.ArgumentParser, evaluation=False):
# apparently env.render(mode="human") is not supported anymore and we need to specify the render mode in
# the env actor
p.add_argument("--render_mode", default="human", type=str, help="")
p.add_argument("--base_port", default=0, type=int, help="")
p.add_argument("--base_port", default=GodotEnv.DEFAULT_PORT, type=int, help="")

p.add_argument(
"--env_agents",
Expand All @@ -183,7 +179,7 @@ def parse_gdrl_args(argv=None, evaluation=False):
add_gdrl_env_args(partial_cfg.env, parser, evaluation=evaluation)
gdrl_override_defaults(partial_cfg.env, parser)
final_cfg = parse_full_cfg(parser, argv)
args, _ = parser.parse_known_args()
args, _ = parser.parse_known_args(argv)
final_cfg.train_dir = args.experiment_dir or "logs/sf"
final_cfg.experiment = args.experiment_name or final_cfg.experiment
return final_cfg
Expand All @@ -192,7 +188,7 @@ def parse_gdrl_args(argv=None, evaluation=False):
def sample_factory_training(args, extras):
register_gdrl_env(args)
cfg = parse_gdrl_args(argv=extras, evaluation=args.eval)
cfg.base_port = random.randint(20000, 22000)
#cfg.base_port = random.randint(20000, 22000)
status = run_rl(cfg)
return status

Expand Down
8 changes: 5 additions & 3 deletions godot_rl/wrappers/stable_baselines_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
Expand Down Expand Up @@ -109,8 +109,10 @@ def env_is_wrapped(self, wrapper_class: type, indices: Optional[List[int]] = Non
def env_method(self):
raise NotImplementedError()

def get_attr(self):
raise NotImplementedError()
def get_attr(self, attr_name: str, indices = None) -> List[Any]:
if attr_name == "render_mode":
return [None for _ in range(self.num_envs)]
raise AttributeError("get attr not fully implemented in godot-rl StableBaselinesWrapper")

def seed(self):
raise NotImplementedError()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "godot_rl"
version = "0.4.8"
version = "0.5.1"
authors = [
{ name="Edward Beeching", email="[email protected]" },
]
Expand Down
18 changes: 4 additions & 14 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ install_requires =
tensorboard
wget
huggingface_hub>=0.10
gym==0.26.2
stable-baselines3==1.2.0
gymnasium
stable-baselines3
huggingface_sb3
onnx
onnxruntime
Expand Down Expand Up @@ -44,14 +44,8 @@ dev =
isort>=5.0.0
pyyaml>=5.3.1

sb3 =
gym==0.26.2
stable-baselines3==1.2.0
huggingface_sb3

sf =
sample-factory==2.0.3
gym==0.26.2
sample-factory

rllib =
numpy==1.23.5
Expand All @@ -64,11 +58,7 @@ clean-rl =

all =
numpy==1.23.5
gym==0.26.2
stable-baselines3==1.2.0
sample-factory==2.0.3
sample-factory
ray==2.2.0
ray[rllib]

huggingface_sb3
tensorflow_probability
2 changes: 1 addition & 1 deletion tests/test_action_space_preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from gym.spaces import Tuple, Dict, Box, Discrete
from gymnasium.spaces import Tuple, Dict, Box, Discrete
from godot_rl.core.godot_env import GodotEnv
from godot_rl.core.utils import ActionSpaceProcessor

Expand Down
16 changes: 16 additions & 0 deletions tests/test_sample_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from godot_rl.core.utils import cant_import

@pytest.mark.skipif(cant_import("sample_factory"), reason="sample_factory is not available")
def test_sample_factory_training():
    """Smoke-test a short, CPU-only sample-factory training run.

    Skipped when ``sample_factory`` cannot be imported. The run uses the
    JumperHard example binary, which is presumably fetched beforehand by
    the examples download step -- verify against the CI workflow.
    """
    # Imported lazily so that collecting this module does not require
    # sample-factory to be installed.
    from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training
    from examples.sample_factory_example import get_args

    # Only the namespace with default option values is kept; the
    # pass-through arguments are replaced by the fixed list below.
    args, _ = get_args()
    args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64"
    extras = [
        '--env=gdrl',
        '--train_for_env_steps=1000',
        '--device=cpu',
    ]

    sample_factory_training(args, extras)
3 changes: 2 additions & 1 deletion tests/test_sb3_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_sb3_training(env_name, port, n_parallel):
args, extras = get_args()
args.env = "gdrl"
args.env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64"
args.experiment_name = f"test_{env_name}_{n_parallel}"
starting_port = port + n_parallel

stable_baselines_training(args, extras, n_steps=1000, port=starting_port, n_parallel=n_parallel)
stable_baselines_training(args, extras, n_steps=10, port=starting_port, n_parallel=n_parallel)

0 comments on commit f90378a

Please sign in to comment.