diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 5bc6933b..a3cb6c0b 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -55,3 +55,61 @@ jobs: - name: Test with pytest run: | make test + + + tests_ubuntu_rllib: + strategy: + matrix: + python-version: [3.8, 3.9, 3.10.10] + os: ['ubuntu-latest'] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel==0.38.4 + # cpu version of pytorch + pip install .[test] + - name: Clean up dependencies + run: | + pip uninstall -y stable-baselines3 gymnasium + pip install .[rllib] + - name: Download examples + run: | + make download_examples + + - name: Test with pytest + run: | + make test + tests_windows_rllib: + strategy: + matrix: + python-version: [3.8, 3.9, 3.10.10] + os: ['windows-latest'] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel==0.38.4 + # cpu version of pytorch + pip install .[test] + - name: Clean up dependencies + run: | + pip uninstall -y stable-baselines3 gymnasium + pip install .[rllib] + - name: Download examples + run: | + make download_examples + + - name: Test with pytest + run: | + make test diff --git a/.gitignore b/.gitignore index d22416ac..f5f13110 100644 --- a/.gitignore +++ b/.gitignore @@ -136,6 +136,7 @@ dmypy.json envs/unity/ logs/ +logs.*/ dump/ tmp/ Packaging Python Projects — Python Packaging User Guide_files/ diff --git a/docs/ADV_CLEAN_RL.md b/docs/ADV_CLEAN_RL.md index 04861b13..a622790f 100644 --- a/docs/ADV_CLEAN_RL.md +++ b/docs/ADV_CLEAN_RL.md @@ -17,11 
+17,11 @@ You can read more about CleanRL in their [technical paper](https://arxiv.org/abs # Installation ```bash -pip install godot-rl[clean-rl] +pip install godot-rl[cleanrl] ``` -While the default options for clean-rl work reasonably well. You may be interested in changing the hyperparameters. -We recommend taking the [clean-rl example](https://github.com/edbeeching/godot_rl_agents/blob/main/examples/clean_rl_example.py) and modifying to match your needs. +While the default options for cleanrl work reasonably well, you may be interested in changing the hyperparameters. +We recommend taking the [cleanrl example](https://github.com/edbeeching/godot_rl_agents/blob/main/examples/clean_rl_example.py) and modifying it to match your needs. ```python parser.add_argument("--gae-lambda", type=float, default=0.95, diff --git a/docs/ADV_RLLIB.md b/docs/ADV_RLLIB.md index 9f3eba77..0ae95ee5 100644 --- a/docs/ADV_RLLIB.md +++ b/docs/ADV_RLLIB.md @@ -4,9 +4,14 @@ ## Installation +If you want to train with rllib, create a new environment (e.g. `python -m venv venv.rllib`), as rllib's dependencies can conflict with those of sb3 and other libraries. +Due to a version clash with gymnasium, stable-baselines3 must be uninstalled before installing rllib. 
```bash -# remove sb3 installation with pip uninstall godot-rl[sb3] -pip install godot-rl[rllib] +pip install godot-rl +# remove sb3 and gymnasium installations +pip uninstall -y stable-baselines3 gymnasium +# install rllib +pip install ray[rllib] ``` ## Basic Environment Usage diff --git a/examples/clean_rl_example.py b/examples/clean_rl_example.py index 09b6f572..1f123b65 100644 --- a/examples/clean_rl_example.py +++ b/examples/clean_rl_example.py @@ -5,7 +5,6 @@ import time from distutils.util import strtobool from collections import deque -import gym import numpy as np import torch import torch.nn as nn @@ -17,6 +16,9 @@ def parse_args(): # fmt: off parser = argparse.ArgumentParser() + parser.add_argument("--viz", default=False, type=bool, + help="If set, the simulation will be displayed in a window during training. Otherwise " + "training will run without rendering the simualtion. This setting does not apply to in-editor training.") parser.add_argument("--experiment_dir", default="logs/cleanrl", type=str, help="The name of the experiment directory, in which the tensorboard logs are getting stored") parser.add_argument("--experiment_name", default=os.path.basename(__file__).rstrip(".py"), type=str, @@ -155,8 +157,7 @@ def get_action_and_value(self, x, action=None): # env setup - envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=True, speedup=args.speedup, convert_action_space=True) # Godot envs are already vectorized - #assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + envs = env = CleanRLGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, convert_action_space=True) # Godot envs are already vectorized args.num_envs = envs.num_envs args.batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = int(args.batch_size // args.num_minibatches) diff --git a/examples/sample_factory_example.py b/examples/sample_factory_example.py index 8c30aa8a..2c4e10a6 100644 
--- a/examples/sample_factory_example.py +++ b/examples/sample_factory_example.py @@ -7,8 +7,17 @@ def get_args(): parser.add_argument("--env_path", default=None, type=str, help="Godot binary to use") parser.add_argument("--eval", default=False, action="store_true", help="whether to eval the model") parser.add_argument("--speedup", default=1, type=int, help="whether to speed up the physics in the env") - parser.add_argument("--export", default=False, action="store_true", help="wheter to export the model") + parser.add_argument("--seed", default=0, type=int, help="environment seed") + parser.add_argument("--export", default=False, action="store_true", help="whether to export the model") parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process") + parser.add_argument("--experiment_dir", default="logs/sf", type=str, + help="The name of the experiment directory, in which the tensorboard logs are getting stored") + parser.add_argument( + "--experiment_name", + default="experiment", + type=str, + help="The name of the experiment, which will be displayed in tensorboard. ", + ) return parser.parse_known_args() @@ -23,4 +32,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/stable_baselines3_example.py b/examples/stable_baselines3_example.py index fd9c22fc..ce035c5b 100644 --- a/examples/stable_baselines3_example.py +++ b/examples/stable_baselines3_example.py @@ -3,6 +3,7 @@ import pathlib from stable_baselines3.common.callbacks import CheckpointCallback +from godot_rl.core.utils import can_import from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx from stable_baselines3 import PPO @@ -11,7 +12,8 @@ # To download the env source and binary: # 1. gdrl.env_from_hub -r edbeeching/godot_rl_BallChase # 2. 
chmod +x examples/godot_rl_BallChase/bin/BallChase.x86_64 - +if can_import("ray"): + print("WARNING, stable baselines and ray[rllib] are not compatable") parser = argparse.ArgumentParser(allow_abbrev=False) parser.add_argument( @@ -34,6 +36,12 @@ help="The name of the experiment, which will be displayed in tensorboard and " "for checkpoint directory and name (if enabled).", ) +parser.add_argument( + "--seed", + type=int, + default=0, + help="seed of the experiment" +) parser.add_argument( "--resume_model_path", default=None, @@ -80,8 +88,8 @@ parser.add_argument( "--viz", action="store_true", - help="If set, the window(s) with the Godot environment(s) will be displayed, otherwise " - "training will run without rendering the game. Does not apply to in-editor training.", + help="If set, the simulation will be displayed in a window during training. Otherwise " + "training will run without rendering the simualtion. This setting does not apply to in-editor training.", default=False ) parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env") @@ -105,7 +113,7 @@ if args.env_path is None and args.viz: print("Info: Using --viz without --env_path set has no effect, in-editor training will always render.") -env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=args.viz, n_parallel=args.n_parallel, +env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=args.viz, seed=args.seed, n_parallel=args.n_parallel, speedup=args.speedup) env = VecMonitor(env) diff --git a/examples/stable_baselines3_hp_tuning.py b/examples/stable_baselines3_hp_tuning.py index 618f5c88..7e280f8c 100644 --- a/examples/stable_baselines3_hp_tuning.py +++ b/examples/stable_baselines3_hp_tuning.py @@ -23,7 +23,7 @@ from typing import Any from typing import Dict -import gym +import gymnasium as gym from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv from godot_rl.core.godot_env import GodotEnv diff --git 
a/godot_rl/core/godot_env.py b/godot_rl/core/godot_env.py index 6fed3637..1c888215 100644 --- a/godot_rl/core/godot_env.py +++ b/godot_rl/core/godot_env.py @@ -225,7 +225,7 @@ def reset(self, seed=None): response["obs"] = self._process_obs(response["obs"]) assert response["type"] == "reset" obs = response["obs"] - return obs, {} + return obs, [{}] * self.num_envs def call(self, method): message = { diff --git a/godot_rl/main.py b/godot_rl/main.py index be20f8ba..f856e673 100644 --- a/godot_rl/main.py +++ b/godot_rl/main.py @@ -62,6 +62,7 @@ def get_args(): parser.add_argument("--experiment_dir", default=None, type=str, help="The name of the the experiment directory, in which the tensorboard logs are getting stored") parser.add_argument("--experiment_name", default="experiment", type=str, help="The name of the the experiment, which will be displayed in tensborboard") parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process") + parser.add_argument("--seed", default=0, type=int, help="seed of the experiment") args, extras = parser.parse_known_args() if args.experiment_dir is None: diff --git a/godot_rl/wrappers/clean_rl_wrapper.py b/godot_rl/wrappers/clean_rl_wrapper.py index 9f42f874..bd73d4e2 100644 --- a/godot_rl/wrappers/clean_rl_wrapper.py +++ b/godot_rl/wrappers/clean_rl_wrapper.py @@ -1,6 +1,6 @@ import numpy as np -import gym +import gymnasium as gym from godot_rl.core.utils import lod_to_dol from godot_rl.core.godot_env import GodotEnv diff --git a/godot_rl/wrappers/ray_wrapper.py b/godot_rl/wrappers/ray_wrapper.py index 1eb5b135..01a6c195 100644 --- a/godot_rl/wrappers/ray_wrapper.py +++ b/godot_rl/wrappers/ray_wrapper.py @@ -20,6 +20,7 @@ def __init__( show_window=False, framerate=None, action_repeat=None, + speedup=None, timeout_wait=60, config=None, ) -> None: @@ -31,6 +32,7 @@ def __init__( show_window=show_window, framerate=framerate, action_repeat=action_repeat, + speedup=speedup ) super().__init__( 
observation_space=self._env.observation_space, @@ -38,23 +40,28 @@ def __init__( num_envs=self._env.num_envs, ) - def vector_reset(self) -> List[EnvObsType]: - obs, info = self._env.reset() - return obs + def vector_reset(self, *, seeds: Optional[List[int]] = None, options: Optional[List[dict]] = None) -> List[EnvObsType]: + self.obs, info = self._env.reset() + return self.obs, info def vector_step( self, actions: List[EnvActionType] ) -> Tuple[List[EnvObsType], List[float], List[bool], List[EnvInfoDict]]: - actions = np.array(actions) + actions = np.array(actions, dtype=np.dtype(object)) self.obs, reward, term, trunc, info = self._env.step(actions, order_ij=True) - return self.obs, reward, term, info + return self.obs, reward, term, trunc, info def get_unwrapped(self): return [self._env] - def reset_at(self, index: Optional[int]) -> EnvObsType: + def reset_at(self, + index: Optional[int] = None, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> EnvObsType: # the env is reset automatically, no need to reset it - return self.obs[index] + return self.obs[index], {} def register_env(): @@ -68,6 +75,7 @@ def register_env(): framerate=c["framerate"], seed=c.worker_index + c["seed"], action_repeat=c["framerate"], + speedup=c["speedup"], ), ) @@ -118,6 +126,8 @@ def rllib_training(args, extras): register_env() exp["config"]["env_config"]["env_path"] = args.env_path + exp["config"]["env_config"]["seed"] = args.seed + if args.env_path is not None: run_name = exp["algorithm"] + "/" + pathlib.Path(args.env_path).stem else: @@ -133,6 +143,10 @@ def rllib_training(args, extras): checkpoint_freq = 10 checkpoint_at_end = True + + exp["config"]["env_config"]["show_window"] = args.viz + exp["config"]["env_config"]["speedup"] = args.speedup + if args.eval or args.export: checkpoint_freq = 0 exp["config"]["env_config"]["show_window"] = True diff --git a/godot_rl/wrappers/sample_factory_wrapper.py b/godot_rl/wrappers/sample_factory_wrapper.py index 
a5da8e7e..c38f8b44 100644 --- a/godot_rl/wrappers/sample_factory_wrapper.py +++ b/godot_rl/wrappers/sample_factory_wrapper.py @@ -72,32 +72,32 @@ def render(): return -def make_godot_env_func(env_path, full_env_name, cfg=None, env_config=None, render_mode=None, speedup=1, viz=False): - seed = 0 +def make_godot_env_func(env_path, full_env_name, cfg=None, env_config=None, render_mode=None, seed=0, speedup=1, viz=False): port = cfg.base_port print("BASE PORT ", cfg.base_port) show_window = False + _seed = seed if env_config: port += 1 + env_config.env_id - seed += 1 + env_config.env_id + _seed += 1 + env_config.env_id print("env id", env_config.env_id) if viz: # print("creating viz env") show_window = env_config.env_id == 0 if cfg.batched_sampling: env = SampleFactoryEnvWrapperBatched( - env_path=env_path, port=port, seed=seed, show_window=show_window, speedup=speedup + env_path=env_path, port=port, seed=_seed, show_window=show_window, speedup=speedup ) else: env = SampleFactoryEnvWrapperNonBatched( - env_path=env_path, port=port, seed=seed, show_window=show_window, speedup=speedup + env_path=env_path, port=port, seed=_seed, show_window=show_window, speedup=speedup ) return env def register_gdrl_env(args): - make_env = partial(make_godot_env_func, args.env_path, speedup=args.speedup, viz=args.viz) + make_env = partial(make_godot_env_func, args.env_path, speedup=args.speedup, seed=args.seed, viz=args.viz) register_env("gdrl", make_env) @@ -152,6 +152,7 @@ def add_gdrl_env_args(_env, p: argparse.ArgumentParser, evaluation=False): # apparently env.render(mode="human") is not supported anymore and we need to specify the render mode in # the env actor p.add_argument("--render_mode", default="human", type=str, help="") + p.add_argument("--base_port", default=GodotEnv.DEFAULT_PORT, type=int, help="") p.add_argument( @@ -160,26 +161,14 @@ def add_gdrl_env_args(_env, p: argparse.ArgumentParser, evaluation=False): type=int, help="Num agents in each envpool (if used)", ) - 
p.add_argument( - "--experiment_dir", - default="logs/sf", - type=str, - help="The name of the experiment directory, in which the tensorboard logs are getting stored", - ) - p.add_argument( - "--experiment_name", - default=None, - type=str, - help="The name of the experiment, which will be displayed in tensorboard", - ) -def parse_gdrl_args(argv=None, evaluation=False): +def parse_gdrl_args(args, argv=None, evaluation=False): parser, partial_cfg = parse_sf_args(argv=argv, evaluation=evaluation) add_gdrl_env_args(partial_cfg.env, parser, evaluation=evaluation) gdrl_override_defaults(partial_cfg.env, parser) final_cfg = parse_full_cfg(parser, argv) - args, _ = parser.parse_known_args(argv) + final_cfg.train_dir = args.experiment_dir or "logs/sf" final_cfg.experiment = args.experiment_name or final_cfg.experiment return final_cfg @@ -187,7 +176,7 @@ def parse_gdrl_args(argv=None, evaluation=False): def sample_factory_training(args, extras): register_gdrl_env(args) - cfg = parse_gdrl_args(argv=extras, evaluation=args.eval) + cfg = parse_gdrl_args(args=args, argv=extras, evaluation=args.eval) #cfg.base_port = random.randint(20000, 22000) status = run_rl(cfg) return status @@ -195,7 +184,7 @@ def sample_factory_training(args, extras): def sample_factory_enjoy(args, extras): register_gdrl_env(args) - cfg = parse_gdrl_args(argv=extras, evaluation=args.eval) + cfg = parse_gdrl_args(args=args, argv=extras, evaluation=args.eval) status = enjoy(cfg) return status diff --git a/godot_rl/wrappers/stable_baselines_wrapper.py b/godot_rl/wrappers/stable_baselines_wrapper.py index 03dad780..fb723e3d 100644 --- a/godot_rl/wrappers/stable_baselines_wrapper.py +++ b/godot_rl/wrappers/stable_baselines_wrapper.py @@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional, Tuple, Union from godot_rl.core.godot_env import GodotEnv -from godot_rl.core.utils import lod_to_dol +from godot_rl.core.utils import can_import, lod_to_dol class StableBaselinesGodotEnv(VecEnv): - def 
__init__(self, env_path: Optional[str] = None, n_parallel: int = 1, **kwargs) -> None: + def __init__(self, env_path: Optional[str] = None, n_parallel: int = 1, seed: int = 0, **kwargs) -> None: # If we are doing editor training, n_parallel must be 1 if env_path is None and n_parallel > 1: raise ValueError("You must provide the path to a exported game executable if n_parallel > 1") @@ -19,7 +19,7 @@ def __init__(self, env_path: Optional[str] = None, n_parallel: int = 1, **kwargs port = kwargs.pop("port", GodotEnv.DEFAULT_PORT) # Create a list of GodotEnv instances - self.envs = [GodotEnv(env_path=env_path, convert_action_space=True, port=port+p, seed=p, **kwargs) for p in range(n_parallel)] + self.envs = [GodotEnv(env_path=env_path, convert_action_space=True, port=port+p, seed=seed+p, **kwargs) for p in range(n_parallel)] # Store the number of parallel environments self.n_parallel = n_parallel @@ -114,7 +114,7 @@ def get_attr(self, attr_name: str, indices = None) -> List[Any]: return [None for _ in range(self.num_envs)] raise AttributeError("get attr not fully implemented in godot-rl StableBaselinesWrapper") - def seed(self): + def seed(self, seed = None): raise NotImplementedError() def set_attr(self): @@ -129,6 +129,8 @@ def step_wait(self) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray, List return self.results def stable_baselines_training(args, extras, n_steps: int = 200000, **kwargs) -> None: + if can_import("ray"): + print("WARNING, stable baselines and ray[rllib] are not compatable") # Initialize the custom environment env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=args.viz, speedup=args.speedup, **kwargs) env = VecMonitor(env) diff --git a/godot_rl_agents_plugin b/godot_rl_agents_plugin index 3984fd12..5b09dc90 160000 --- a/godot_rl_agents_plugin +++ b/godot_rl_agents_plugin @@ -1 +1 @@ -Subproject commit 3984fd124a2b941a446f4614bb0eacd09a2468f5 +Subproject commit 5b09dc906eae1e037c4f8b0b09a1ffe11340802f diff --git 
a/pyproject.toml b/pyproject.toml index f2fe56de..f2ce4faf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "godot_rl" -version = "0.6.0" +version = "0.6.1" authors = [ { name="Edward Beeching", email="edbeeching@gmail.com" }, ] diff --git a/setup.cfg b/setup.cfg index ddba42b9..a6dc219f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,17 +48,9 @@ sf = sample-factory rllib = - numpy==1.23.5 - ray==2.2.0 + gymnasium==0.26.3 ray[rllib] - tensorflow_probability -clean-rl = +cleanrl = wandb -all = - numpy==1.23.5 - sample-factory - ray==2.2.0 - ray[rllib] - tensorflow_probability diff --git a/tests/fixtures/test_rllib.yaml b/tests/fixtures/test_rllib.yaml new file mode 100644 index 00000000..3c237b3c --- /dev/null +++ b/tests/fixtures/test_rllib.yaml @@ -0,0 +1,39 @@ + +algorithm: PPO + +stop: + episode_reward_mean: 5000 + training_iteration: 1000 + timesteps_total: 100 + +config: + env: godot + env_config: + framerate: null + action_repeat: null + show_window: false + seed: 0 + framework: torch + lambda: 0.95 + gamma: 0.95 + + vf_clip_param: 1.0 + clip_param: 0.2 + entropy_coeff: 0.001 + entropy_coeff_schedule: null + train_batch_size: 1024 + sgd_minibatch_size: 128 + num_sgd_iter: 16 + num_workers: 1 + lr: 0.0003 + num_envs_per_worker: 16 + batch_mode: truncate_episodes + rollout_fragment_length: 16 + num_gpus: 1 + model: + fcnet_hiddens: [256, 256] + use_lstm: false + lstm_cell_size : 32 + framestack: 4 + no_done_at_end: false + soft_horizon: false diff --git a/tests/test_rllib.py b/tests/test_rllib.py new file mode 100644 index 00000000..d6eb62f5 --- /dev/null +++ b/tests/test_rllib.py @@ -0,0 +1,14 @@ +import pytest + +from godot_rl.core.utils import cant_import + +@pytest.mark.skipif(cant_import("ray"), reason="ray[rllib] is not available") +def test_rllib_training(): + from godot_rl.wrappers.ray_wrapper import rllib_training + from godot_rl.main import get_args 
+ args, extras = get_args() + args.config_file = "tests/fixtures/test_rllib.yaml" + args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64" + + + rllib_training(args, extras) \ No newline at end of file diff --git a/tests/test_sample_factory.py b/tests/test_sample_factory.py index 7f6b686d..eaa9c826 100644 --- a/tests/test_sample_factory.py +++ b/tests/test_sample_factory.py @@ -13,4 +13,5 @@ def test_sample_factory_training(): extras.append('--train_for_env_steps=1000') extras.append('--device=cpu') - sample_factory_training(args, extras) \ No newline at end of file + sample_factory_training(args, extras) + diff --git a/tests/test_sb3_onnx_export.py b/tests/test_sb3_onnx_export.py index b910cfce..4dc1c4eb 100644 --- a/tests/test_sb3_onnx_export.py +++ b/tests/test_sb3_onnx_export.py @@ -1,13 +1,10 @@ import os import pytest -from stable_baselines3 import PPO - -from godot_rl.wrappers.onnx.stable_baselines_export import ( - export_ppo_model_as_onnx, verify_onnx_export) -from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv +from godot_rl.core.utils import can_import +@pytest.mark.skipif(can_import("ray"), reason="rllib and sb3 are not compatable") @pytest.mark.parametrize( "env_name,port", [ @@ -19,6 +16,10 @@ ], ) def test_pytorch_vs_onnx(env_name, port): + from stable_baselines3 import PPO + from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv + from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx, verify_onnx_export + env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64" env = StableBaselinesGodotEnv(env_path, port=port) diff --git a/tests/test_sb3_training.py b/tests/test_sb3_training.py index 864ed482..01071732 100644 --- a/tests/test_sb3_training.py +++ b/tests/test_sb3_training.py @@ -1,15 +1,9 @@ import pytest -from godot_rl.core.godot_env import GodotEnv from godot_rl.main import get_args +from godot_rl.core.utils import can_import -try: - from 
godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training -except ImportError as e: - - def stable_baselines_training(args, extras, **kwargs): - print("Import error when trying to use sb3, this is probably not installed try pip install godot-rl[sb3]") - +@pytest.mark.skipif(can_import("ray"), reason="rllib and sb3 are not compatable") @pytest.mark.parametrize( "env_name,port", [ @@ -20,13 +14,9 @@ def stable_baselines_training(args, extras, **kwargs): ("FlyBy", 12400), ], ) -@pytest.mark.parametrize( - "n_parallel",[ - 1,2,4 - ] - -) +@pytest.mark.parametrize("n_parallel",[1,2,4]) def test_sb3_training(env_name, port, n_parallel): + from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training args, extras = get_args() args.env = "gdrl" args.env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64"