Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rllib #140

Merged
merged 32 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
2a335c3
fixes info shape on reset
edbeeching Jul 27, 2023
b78b1be
updates RLLIB wrapper to support latest version
edbeeching Jul 27, 2023
cc72376
update setup and toml file for a minor release
edbeeching Jul 27, 2023
6a735d1
updates docs
edbeeching Jul 27, 2023
326b188
chore(addon): use addon submodule from main branch
visuallization Jul 27, 2023
20c4df7
fix(ray): make mixed action types work with latest numpy version
visuallization Jul 27, 2023
1d4baa5
fix(cleanrl): make cleanrl work again
visuallization Jul 27, 2023
1a1b18f
fix(hp tuning): use gymnasium instead of gym
visuallization Jul 27, 2023
ac1c4bb
adds tests for rllib
edbeeching Jul 27, 2023
0f1d1b1
Update test-ci.yml
edbeeching Jul 27, 2023
a38e4b9
hacky fix to get rllib test to work
edbeeching Jul 27, 2023
c67b68b
last try before I give up
edbeeching Jul 27, 2023
fd7d277
Update .github/workflows/test-ci.yml
Ivan-267 Jul 27, 2023
d56cf4d
Update .github/workflows/test-ci.yml
Ivan-267 Jul 27, 2023
9a5e409
fix(test): split dependency install and cleanup into two steps and se…
visuallization Jul 28, 2023
d1d2e80
fix(test): install rllib from setup.cfg
visuallization Jul 28, 2023
6bc9539
feat(cleanrl): added --viz option to example
visuallization Jul 28, 2023
696f872
fix(test): also update pip in cleanup step
visuallization Jul 28, 2023
888dec7
fix(test): set wheel version
visuallization Jul 28, 2023
6468858
fix(test): also fix windows rllib pip instal
visuallization Jul 28, 2023
0ef70b4
feat(sb3): make it possible to set seed
visuallization Jul 28, 2023
44ac81e
feat(sf): make it possible to set seed
visuallization Jul 28, 2023
4a1a016
fix(sf): make experiment_dir work again
visuallization Jul 28, 2023
69f14b5
feat(rllib): make it possible to set a seed
visuallization Jul 28, 2023
fc20ae4
1 worker, even less timesteps
Ivan-267 Jul 28, 2023
776258c
Attempting to add seed arg
Ivan-267 Jul 28, 2023
ccb7b7f
reverting change (didn't fix the issue)
Ivan-267 Jul 28, 2023
37edec8
Another attempt to add seed
Ivan-267 Jul 28, 2023
e615641
Add --experiment_dir
Ivan-267 Jul 28, 2023
648a152
Add --experiment_name
Ivan-267 Jul 28, 2023
d7488af
updates installtion instructions for rllib
edbeeching Jul 31, 2023
8b21238
Update ADV_RLLIB.md
Ivan-267 Jul 31, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,57 @@ jobs:
- name: Test with pytest
run: |
make test


tests_ubuntu_rllib:
strategy:
matrix:
python-version: [3.8, 3.9, 3.10.10]
os: ['ubuntu-latest']
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install .[test]
pip uninstall -y stable-baselines3 gymnasium
pip install -y ray[rllib]
Ivan-267 marked this conversation as resolved.
Show resolved Hide resolved
- name: Download examples
run: |
make download_examples

- name: Test with pytest
run: |
make test
tests_windows_rllib:
strategy:
matrix:
python-version: [3.8, 3.9, 3.10.10]
os: ['windows-latest']
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install .[test]
pip uninstall -y stable-baselines3 gymnasium
pip install -y ray[rllib]
Ivan-267 marked this conversation as resolved.
Show resolved Hide resolved
- name: Download examples
run: |
make download_examples

- name: Test with pytest
run: |
make test
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ dmypy.json

envs/unity/
logs/
logs.*/
dump/
tmp/
Packaging Python Projects — Python Packaging User Guide_files/
Expand Down
6 changes: 3 additions & 3 deletions docs/ADV_CLEAN_RL.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ You can read more about CleanRL in their [technical paper](https://arxiv.org/abs

# Installation
```bash
pip install godot-rl[clean-rl]
pip install godot-rl[cleanrl]
```

While the default options for clean-rl work reasonably well. You may be interested in changing the hyperparameters.
We recommend taking the [clean-rl example](https://github.com/edbeeching/godot_rl_agents/blob/main/examples/clean_rl_example.py) and modifying to match your needs.
While the default options for cleanrl work reasonably well. You may be interested in changing the hyperparameters.
We recommend taking the [cleanrl example](https://github.com/edbeeching/godot_rl_agents/blob/main/examples/clean_rl_example.py) and modifying to match your needs.

```python
parser.add_argument("--gae-lambda", type=float, default=0.95,
Expand Down
2 changes: 2 additions & 0 deletions docs/ADV_RLLIB.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

## Installation

If you want to train with rllib, create a new environment e.g.: `python -m venv venv.rllib` as rllib's dependencies can conflict with those of sb3 and other libraries.

```bash
# remove sb3 installation with pip uninstall godot-rl[sb3]
pip install godot-rl[rllib]
Expand Down
2 changes: 0 additions & 2 deletions examples/clean_rl_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import time
from distutils.util import strtobool
from collections import deque
import gym
import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -156,7 +155,6 @@ def get_action_and_value(self, x, action=None):
# env setup

envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=True, speedup=args.speedup, convert_action_space=True) # Godot envs are already vectorized
#assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
args.num_envs = envs.num_envs
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
Expand Down
4 changes: 3 additions & 1 deletion examples/stable_baselines3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pathlib

from stable_baselines3.common.callbacks import CheckpointCallback
from godot_rl.core.utils import can_import
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from stable_baselines3 import PPO
Expand All @@ -11,7 +12,8 @@
# To download the env source and binary:
# 1. gdrl.env_from_hub -r edbeeching/godot_rl_BallChase
# 2. chmod +x examples/godot_rl_BallChase/bin/BallChase.x86_64

if can_import("ray"):
print("WARNING, stable baselines and ray[rllib] are not compatable")

parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion examples/stable_baselines3_hp_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import Any
from typing import Dict

import gym
import gymnasium as gym

from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
from godot_rl.core.godot_env import GodotEnv
Expand Down
2 changes: 1 addition & 1 deletion godot_rl/core/godot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def reset(self, seed=None):
response["obs"] = self._process_obs(response["obs"])
assert response["type"] == "reset"
obs = response["obs"]
return obs, {}
return obs, [{}] * self.num_envs

def call(self, method):
message = {
Expand Down
2 changes: 1 addition & 1 deletion godot_rl/wrappers/clean_rl_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import numpy as np
import gym
import gymnasium as gym
from godot_rl.core.utils import lod_to_dol
from godot_rl.core.godot_env import GodotEnv

Expand Down
26 changes: 19 additions & 7 deletions godot_rl/wrappers/ray_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
show_window=False,
framerate=None,
action_repeat=None,
speedup=None,
timeout_wait=60,
config=None,
) -> None:
Expand All @@ -31,30 +32,36 @@ def __init__(
show_window=show_window,
framerate=framerate,
action_repeat=action_repeat,
speedup=speedup
)
super().__init__(
observation_space=self._env.observation_space,
action_space=self._env.action_space,
num_envs=self._env.num_envs,
)

def vector_reset(self) -> List[EnvObsType]:
obs, info = self._env.reset()
return obs
def vector_reset(self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None) -> List[EnvObsType]:
self.obs, info = self._env.reset()
return self.obs, info

def vector_step(
self, actions: List[EnvActionType]
) -> Tuple[List[EnvObsType], List[float], List[bool], List[EnvInfoDict]]:
actions = np.array(actions)
actions = np.array(actions, dtype=np.dtype(object))
self.obs, reward, term, trunc, info = self._env.step(actions, order_ij=True)
return self.obs, reward, term, info
return self.obs, reward, term, trunc, info

def get_unwrapped(self):
return [self._env]

def reset_at(self, index: Optional[int]) -> EnvObsType:
def reset_at(self,
index: Optional[int] = None,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> EnvObsType:
# the env is reset automatically, no need to reset it
return self.obs[index]
return self.obs[index], {}


def register_env():
Expand All @@ -68,6 +75,7 @@ def register_env():
framerate=c["framerate"],
seed=c.worker_index + c["seed"],
action_repeat=c["framerate"],
speedup=c["speedup"],
),
)

Expand Down Expand Up @@ -133,6 +141,10 @@ def rllib_training(args, extras):

checkpoint_freq = 10
checkpoint_at_end = True

exp["config"]["env_config"]["show_window"] = args.viz
exp["config"]["env_config"]["speedup"] = args.speedup

if args.eval or args.export:
checkpoint_freq = 0
exp["config"]["env_config"]["show_window"] = True
Expand Down
4 changes: 3 additions & 1 deletion godot_rl/wrappers/stable_baselines_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

from godot_rl.core.godot_env import GodotEnv
from godot_rl.core.utils import lod_to_dol
from godot_rl.core.utils import can_import, lod_to_dol


class StableBaselinesGodotEnv(VecEnv):
Expand Down Expand Up @@ -129,6 +129,8 @@ def step_wait(self) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray, List
return self.results

def stable_baselines_training(args, extras, n_steps: int = 200000, **kwargs) -> None:
if can_import("ray"):
print("WARNING, stable baselines and ray[rllib] are not compatable")
# Initialize the custom environment
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, **kwargs)
env = VecMonitor(env)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "godot_rl"
version = "0.6.0"
version = "0.6.1"
authors = [
{ name="Edward Beeching", email="[email protected]" },
]
Expand Down
12 changes: 2 additions & 10 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,9 @@ sf =
sample-factory

rllib =
numpy==1.23.5
ray==2.2.0
gymnasium==0.26.3
ray[rllib]
tensorflow_probability

clean-rl =
cleanrl =
wandb

all =
numpy==1.23.5
sample-factory
ray==2.2.0
ray[rllib]
tensorflow_probability
39 changes: 39 additions & 0 deletions tests/fixtures/test_rllib.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

algorithm: PPO

stop:
episode_reward_mean: 5000
training_iteration: 1000
timesteps_total: 200

config:
env: godot
env_config:
framerate: null
action_repeat: null
show_window: false
seed: 0
framework: torch
lambda: 0.95
gamma: 0.95

vf_clip_param: 1.0
clip_param: 0.2
entropy_coeff: 0.001
entropy_coeff_schedule: null
train_batch_size: 1024
sgd_minibatch_size: 128
num_sgd_iter: 16
num_workers: 4
lr: 0.0003
num_envs_per_worker: 16
batch_mode: truncate_episodes
rollout_fragment_length: 16
num_gpus: 1
model:
fcnet_hiddens: [256, 256]
use_lstm: false
lstm_cell_size : 32
framestack: 4
no_done_at_end: false
soft_horizon: false
14 changes: 14 additions & 0 deletions tests/test_rllib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest

from godot_rl.core.utils import cant_import

@pytest.mark.skipif(cant_import("ray"), reason="ray[rllib] is not available")
def test_rllib_training():
from godot_rl.wrappers.ray_wrapper import rllib_training
from godot_rl.main import get_args
args, extras = get_args()
args.config_file = "tests/fixtures/test_rllib.yaml"
args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64"


rllib_training(args, extras)
3 changes: 2 additions & 1 deletion tests/test_sample_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ def test_sample_factory_training():
extras.append('--train_for_env_steps=1000')
extras.append('--device=cpu')

sample_factory_training(args, extras)
sample_factory_training(args, extras)

11 changes: 6 additions & 5 deletions tests/test_sb3_onnx_export.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import os

import pytest
from stable_baselines3 import PPO

from godot_rl.wrappers.onnx.stable_baselines_export import (
export_ppo_model_as_onnx, verify_onnx_export)
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

from godot_rl.core.utils import can_import

@pytest.mark.skipif(can_import("ray"), reason="rllib and sb3 are not compatable")
@pytest.mark.parametrize(
"env_name,port",
[
Expand All @@ -19,6 +16,10 @@
],
)
def test_pytorch_vs_onnx(env_name, port):
from stable_baselines3 import PPO
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx, verify_onnx_export

env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64"
env = StableBaselinesGodotEnv(env_path, port=port)

Expand Down
18 changes: 4 additions & 14 deletions tests/test_sb3_training.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import pytest

from godot_rl.core.godot_env import GodotEnv
from godot_rl.main import get_args
from godot_rl.core.utils import can_import

try:
from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training
except ImportError as e:

def stable_baselines_training(args, extras, **kwargs):
print("Import error when trying to use sb3, this is probably not installed try pip install godot-rl[sb3]")

@pytest.mark.skipif(can_import("ray"), reason="rllib and sb3 are not compatable")
@pytest.mark.parametrize(
"env_name,port",
[
Expand All @@ -20,13 +14,9 @@ def stable_baselines_training(args, extras, **kwargs):
("FlyBy", 12400),
],
)
@pytest.mark.parametrize(
"n_parallel",[
1,2,4
]

)
@pytest.mark.parametrize("n_parallel",[1,2,4])
def test_sb3_training(env_name, port, n_parallel):
from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training
args, extras = get_args()
args.env = "gdrl"
args.env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64"
Expand Down