From 265fe2b5073d47ca534b02789377f4f91140dbf2 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 18 Jul 2023 22:04:09 +0200 Subject: [PATCH 01/18] adds python 3.10 to CI tests --- .github/workflows/test-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 7a32e7ed..ef407614 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -10,7 +10,7 @@ jobs: tests: strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: [3.7, 3.8, 3.9, 3.10] os: ['ubuntu-latest', 'windows-latest'] runs-on: ${{ matrix.os }} steps: From d41e2a464e0d52f1863a1b8a62e46bd1bdadf29b Mon Sep 17 00:00:00 2001 From: edbeeching Date: Tue, 18 Jul 2023 22:06:40 +0200 Subject: [PATCH 02/18] updates python version in tests --- .github/workflows/test-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index ef407614..f46900aa 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -10,7 +10,7 @@ jobs: tests: strategy: matrix: - python-version: [3.7, 3.8, 3.9, 3.10] + python-version: [3.7, 3.8, 3.9, 3.10.10] os: ['ubuntu-latest', 'windows-latest'] runs-on: ${{ matrix.os }} steps: From 0638d03fa2587fd7633b02963d22dd9dfdab4440 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Wed, 19 Jul 2023 17:03:03 +0200 Subject: [PATCH 03/18] moves to gymnasium --- godot_rl/core/godot_env.py | 2 +- godot_rl/core/utils.py | 2 +- godot_rl/wrappers/stable_baselines_wrapper.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/godot_rl/core/godot_env.py b/godot_rl/core/godot_env.py index daef8d6d..6fed3637 100644 --- a/godot_rl/core/godot_env.py +++ b/godot_rl/core/godot_env.py @@ -8,7 +8,7 @@ from sys import platform import numpy as np -from gym import spaces +from gymnasium import spaces from typing import Optional from godot_rl.core.utils import ActionSpaceProcessor, convert_macos_path diff --git a/godot_rl/core/utils.py b/godot_rl/core/utils.py index aff5b6f3..5fcb38ec 100644 --- a/godot_rl/core/utils.py +++ b/godot_rl/core/utils.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym import numpy as np import re diff --git a/godot_rl/wrappers/stable_baselines_wrapper.py b/godot_rl/wrappers/stable_baselines_wrapper.py index 2134c6bc..814753a2 100644 --- a/godot_rl/wrappers/stable_baselines_wrapper.py +++ b/godot_rl/wrappers/stable_baselines_wrapper.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym import numpy as np from stable_baselines3 import PPO from stable_baselines3.common.vec_env.base_vec_env import VecEnv @@ -109,8 +109,8 @@ def env_is_wrapped(self, wrapper_class: type, indices: Optional[List[int]] = Non def env_method(self): raise NotImplementedError() - def get_attr(self): - raise NotImplementedError() + def get_attr(self, attr_name: str, indices = None) -> List[Any]: + raise AttributeError("get attr not implemented in godot-rl StableBaselinesWrapper") def seed(self): raise NotImplementedError() From 4153088b4ca18d5fe7d352043749e0331bb8a2da Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 20 Jul 2023 10:26:15 +0200 Subject: [PATCH 04/18] fixes test --- tests/test_sb3_training.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_sb3_training.py b/tests/test_sb3_training.py index 5d8130f0..864ed482 100644 --- a/tests/test_sb3_training.py +++ b/tests/test_sb3_training.py @@ -30,6 +30,7 @@ def test_sb3_training(env_name, port, n_parallel): args, extras = get_args() args.env = "gdrl" args.env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64" + args.experiment_name = f"test_{env_name}_{n_parallel}" starting_port = port + n_parallel - stable_baselines_training(args, extras, n_steps=1000, port=starting_port, n_parallel=n_parallel) + stable_baselines_training(args, extras, n_steps=10, port=starting_port, n_parallel=n_parallel) From b9e94bcfb7f8432455b5c1d746e65a571ce871f8 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 20 Jul 2023 10:28:18 +0200 Subject: [PATCH 05/18] updates setup with sb3 2.0 --- setup.cfg | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/setup.cfg b/setup.cfg index e3c54c2b..ddba42b9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,8 +15,8 @@ install_requires = tensorboard wget huggingface_hub>=0.10 - gym==0.26.2 - stable-baselines3==1.2.0 + gymnasium + stable-baselines3 huggingface_sb3 onnx onnxruntime @@ -44,14 +44,8 @@ dev = isort>=5.0.0 pyyaml>=5.3.1 -sb3 = - gym==0.26.2 - stable-baselines3==1.2.0 - huggingface_sb3 - sf = - sample-factory==2.0.3 - gym==0.26.2 + sample-factory rllib = numpy==1.23.5 @@ -64,11 +58,7 @@ clean-rl = all = numpy==1.23.5 - gym==0.26.2 - stable-baselines3==1.2.0 - sample-factory==2.0.3 + sample-factory ray==2.2.0 ray[rllib] - - huggingface_sb3 tensorflow_probability From d4c819cc2846dfd6f62725163f5309c1f191c406 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 20 Jul 2023 10:33:18 +0200 Subject: [PATCH 06/18] updates to gymnasium --- tests/test_action_space_preprocessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_action_space_preprocessor.py b/tests/test_action_space_preprocessor.py index 5685e58f..4496bbe5 100644 --- a/tests/test_action_space_preprocessor.py +++ b/tests/test_action_space_preprocessor.py @@ -1,5 +1,5 @@ import pytest -from gym.spaces import Tuple, Dict, Box, Discrete +from gymnasium.spaces import Tuple, Dict, Box, Discrete from godot_rl.core.godot_env import GodotEnv from godot_rl.core.utils import ActionSpaceProcessor From 87d90b4836aaaad547c29db6cfbbf18a0e01d5a4 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 20 Jul 2023 15:06:07 +0200 Subject: [PATCH 07/18] removes render warning --- godot_rl/wrappers/stable_baselines_wrapper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/godot_rl/wrappers/stable_baselines_wrapper.py b/godot_rl/wrappers/stable_baselines_wrapper.py index 814753a2..03dad780 100644 --- a/godot_rl/wrappers/stable_baselines_wrapper.py +++ b/godot_rl/wrappers/stable_baselines_wrapper.py @@ -110,7 +110,9 @@ def env_method(self): raise NotImplementedError() def get_attr(self, attr_name: str, indices = None) -> List[Any]: - raise AttributeError("get attr not implemented in godot-rl StableBaselinesWrapper") + if attr_name == "render_mode": + return [None for _ in range(self.num_envs)] + raise AttributeError("get attr not fully implemented in godot-rl StableBaselinesWrapper") def seed(self): raise NotImplementedError() From cfbe0b83996c482f3e48f26a22349837ac1decd2 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 20 Jul 2023 15:42:13 +0200 Subject: [PATCH 08/18] removes capitalization --- examples/stable_baselines3_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/stable_baselines3_example.py b/examples/stable_baselines3_example.py index d13a85b9..2d49a997 100644 --- a/examples/stable_baselines3_example.py +++ b/examples/stable_baselines3_example.py @@ -25,7 +25,7 @@ ) parser.add_argument( "--experiment_name", - default="Experiment", + default="experiment", type=str, help="The name of the experiment, which will be displayed in tensorboard", ) From f2c26d9c7eedcd46ab486f5cdd288b42640b01f6 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 20 Jul 2023 15:42:45 +0200 Subject: [PATCH 09/18] adds warning for when using sf in editor mode --- godot_rl/main.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/godot_rl/main.py b/godot_rl/main.py index 329049a7..be20f8ba 100644 --- a/godot_rl/main.py +++ b/godot_rl/main.py @@ -60,10 +60,18 @@ def get_args(): parser.add_argument("--export", default=False, action="store_true", help="wheter to export the model") parser.add_argument("--num_gpus", default=None, type=int, help="Number of GPUs to use [only for rllib]") parser.add_argument("--experiment_dir", default=None, type=str, help="The name of the the experiment directory, in which the tensorboard logs are getting stored") - parser.add_argument("--experiment_name", default=None, type=str, help="The name of the the experiment, which will be displayed in tensborboard") + parser.add_argument("--experiment_name", default="experiment", type=str, help="The name of the the experiment, which will be displayed in tensborboard") parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process") - - return parser.parse_known_args() + + args, extras = parser.parse_known_args() + if args.experiment_dir is None: + args.experiment_dir = f"logs/{args.trainer}" + + if args.trainer == "sf" and args.env_path is None: + print("WARNING: the sample-factory intergration is not designed to run in interactive mode, please export you game to use this trainer") + + + return args, extras def main(): From b45e6c1bc59bc84fbba3d825f263c75b1a9e3a48 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 20 Jul 2023 15:43:38 +0200 Subject: [PATCH 10/18] updates wrapper to comply with gymnasium API, removes some unneeded features --- godot_rl/wrappers/sample_factory_wrapper.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/godot_rl/wrappers/sample_factory_wrapper.py b/godot_rl/wrappers/sample_factory_wrapper.py index 7b0eb20e..ba7d54db 100644 --- a/godot_rl/wrappers/sample_factory_wrapper.py +++ b/godot_rl/wrappers/sample_factory_wrapper.py @@ -1,9 +1,7 @@ import argparse -import sys from functools import partial import random import numpy as np -from gym import spaces from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args from sample_factory.envs.env_utils import register_env from sample_factory.train import run_rl @@ -11,9 +9,9 @@ from godot_rl.core.godot_env import GodotEnv from godot_rl.core.utils import lod_to_dol +from gymnasium import Env - -class SampleFactoryEnvWrapperBatched(GodotEnv): +class SampleFactoryEnvWrapperBatched(GodotEnv, Env): @property def unwrapped(self): return self @@ -22,7 +20,7 @@ def unwrapped(self): def num_agents(self): return self.num_envs - def reset(self, seed=None): + def reset(self, seed=None, options=None): obs, info = super().reset(seed=seed) obs = lod_to_dol(obs) return {k: np.array(v) for k, v in obs.items()}, info @@ -45,7 +43,7 @@ def render(): return -class SampleFactoryEnvWrapperNonBatched(GodotEnv): +class SampleFactoryEnvWrapperNonBatched(GodotEnv, Env): @property def unwrapped(self): return self @@ -53,14 +51,12 @@ def unwrapped(self): @property def num_agents(self): return self.num_envs - - def reset(self, seed=None): + def reset(self, seed=None, options=None): obs, info = super().reset(seed=seed) return self.to_numpy(obs), info def step(self, action): obs, reward, term, trunc, info = super().step(action, order_ij=True) - return self.to_numpy(obs), np.array(reward), np.array(term), np.array(trunc) * 0, info @staticmethod @@ -78,7 +74,7 @@ def render(): def make_godot_env_func(env_path, full_env_name, cfg=None, env_config=None, render_mode=None, speedup=1, viz=False): seed = 0 - port = 21008 + cfg.base_port + port = cfg.base_port print("BASE PORT ", cfg.base_port) show_window = False if env_config: @@ -156,7 +152,7 @@ def add_gdrl_env_args(_env, p: argparse.ArgumentParser, evaluation=False): # apparently env.render(mode="human") is not supported anymore and we need to specify the render mode in # the env actor p.add_argument("--render_mode", default="human", type=str, help="") - p.add_argument("--base_port", default=0, type=int, help="") + p.add_argument("--base_port", default=GodotEnv.DEFAULT_PORT, type=int, help="") p.add_argument( "--env_agents", @@ -192,7 +188,7 @@ def parse_gdrl_args(argv=None, evaluation=False): def sample_factory_training(args, extras): register_gdrl_env(args) cfg = parse_gdrl_args(argv=extras, evaluation=args.eval) - cfg.base_port = random.randint(20000, 22000) + #cfg.base_port = random.randint(20000, 22000) status = run_rl(cfg) return status From baaa703cc70db9f5b55dfe9b10d7874164e23ff7 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 20 Jul 2023 17:17:46 +0200 Subject: [PATCH 11/18] adds tests, bumps version --- examples/sample_factory_example.py | 26 +++++++++++++++++++++ godot_rl/core/utils.py | 15 +++++++++++- godot_rl/wrappers/sample_factory_wrapper.py | 2 +- pyproject.toml | 2 +- tests/test_sample_factory.py | 15 ++++++++++++ 5 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 examples/sample_factory_example.py create mode 100644 tests/test_sample_factory.py diff --git a/examples/sample_factory_example.py b/examples/sample_factory_example.py new file mode 100644 index 00000000..8c30aa8a --- /dev/null +++ b/examples/sample_factory_example.py @@ -0,0 +1,26 @@ +import argparse +from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training, sample_factory_enjoy + + +def get_args(): + parser = argparse.ArgumentParser(allow_abbrev=False) + parser.add_argument("--env_path", default=None, type=str, help="Godot binary to use") + parser.add_argument("--eval", default=False, action="store_true", help="whether to eval the model") + parser.add_argument("--speedup", default=1, type=int, help="whether to speed up the physics in the env") + parser.add_argument("--export", default=False, action="store_true", help="wheter to export the model") + parser.add_argument("--viz", default=False, action="store_true", help="Whether to visualize one process") + + return parser.parse_known_args() + + + +def main(): + args, extras = get_args() + if args.eval: + sample_factory_enjoy(args, extras) + else: + sample_factory_training(args, extras) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/godot_rl/core/utils.py b/godot_rl/core/utils.py index 5fcb38ec..edca98d7 100644 --- a/godot_rl/core/utils.py +++ b/godot_rl/core/utils.py @@ -1,6 +1,9 @@ +import importlib +import re + import gymnasium as gym import numpy as np -import re + def lod_to_dol(lod): @@ -103,3 +106,13 @@ def to_original_dist(self, action): raise NotImplementedError return original_action + +def can_import(module_name): + return not cant_import(module_name) + +def cant_import(module_name): + try: + importlib.import_module(module_name) + return False + except ImportError: + return True \ No newline at end of file diff --git a/godot_rl/wrappers/sample_factory_wrapper.py b/godot_rl/wrappers/sample_factory_wrapper.py index ba7d54db..a5da8e7e 100644 --- a/godot_rl/wrappers/sample_factory_wrapper.py +++ b/godot_rl/wrappers/sample_factory_wrapper.py @@ -179,7 +179,7 @@ def parse_gdrl_args(argv=None, evaluation=False): add_gdrl_env_args(partial_cfg.env, parser, evaluation=evaluation) gdrl_override_defaults(partial_cfg.env, parser) final_cfg = parse_full_cfg(parser, argv) - args, _ = parser.parse_known_args() + args, _ = parser.parse_known_args(argv) final_cfg.train_dir = args.experiment_dir or "logs/sf" final_cfg.experiment = args.experiment_name or final_cfg.experiment return final_cfg diff --git a/pyproject.toml b/pyproject.toml index 7ac83575..a7cc47bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "godot_rl" -version = "0.4.8" +version = "0.5.1" authors = [ { name="Edward Beeching", email="edbeeching@gmail.com" }, ] diff --git a/tests/test_sample_factory.py b/tests/test_sample_factory.py new file mode 100644 index 00000000..a6340bfa --- /dev/null +++ b/tests/test_sample_factory.py @@ -0,0 +1,15 @@ +import pytest + +from godot_rl.core.utils import cant_import +from examples.sample_factory_example import get_args + +@pytest.mark.skipif(cant_import("sample_factory"), reason="sample_factory is not available") +def test_sample_factory_training(): + from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training + args, extras = get_args() + args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64" + extras = [] + extras.append('--env=gdrl') + extras.extend(['--train_for_env_steps=1000']) + + sample_factory_training(args, extras) \ No newline at end of file From 6057678d1ced0930ec6efac7bb19e41579a5a5c3 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Fri, 21 Jul 2023 15:26:50 +0200 Subject: [PATCH 12/18] updates ci config to include sf --- .github/workflows/test-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index f46900aa..d0e2c41c 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -23,7 +23,7 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - pip install .[test] + pip install .[test, sf] - name: Download examples run: | make download_examples From 3e71be97dddb2239e179a74d410e6fdfe0860b42 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Sat, 22 Jul 2023 21:13:39 +0200 Subject: [PATCH 13/18] removes whitespace --- .github/workflows/test-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index d0e2c41c..48fa35a8 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -23,7 +23,7 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - pip install .[test, sf] + pip install .[test,sf] - name: Download examples run: | make download_examples From 90fa41970ca5289686f369b9a300b820a87bd1a1 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Sat, 22 Jul 2023 21:36:19 +0200 Subject: [PATCH 14/18] separates windows and ubuntu tests in CI --- .github/workflows/test-ci.yml | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 48fa35a8..10a4ea53 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -7,11 +7,11 @@ on: branches: [ main ] jobs: - tests: + tests_ubuntu: strategy: matrix: python-version: [3.7, 3.8, 3.9, 3.10.10] - os: ['ubuntu-latest', 'windows-latest'] + os: ['ubuntu-latest'] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v3 @@ -31,3 +31,27 @@ jobs: - name: Test with pytest run: | make test + tests_windows: + strategy: + matrix: + python-version: [3.7, 3.8, 3.9, 3.10.10] + os: ['windows-latest'] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install .[test] + - name: Download examples + run: | + make download_examples + + - name: Test with pytest + run: | + make test From 70a656b514bbc124d6895763322d6c79f7883795 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Sun, 23 Jul 2023 08:40:47 +0200 Subject: [PATCH 15/18] drops test support for python 3.7 --- .github/workflows/test-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 10a4ea53..e0f8aa46 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -34,7 +34,7 @@ jobs: tests_windows: strategy: matrix: - python-version: [3.7, 3.8, 3.9, 3.10.10] + python-version: [3.8, 3.9, 3.10.10] os: ['windows-latest'] runs-on: ${{ matrix.os }} steps: From 885a9db29858d64c013eaad58de9785f5cc7c749 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Sun, 23 Jul 2023 08:43:17 +0200 Subject: [PATCH 16/18] fixed imports for sf test --- tests/test_sample_factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_sample_factory.py b/tests/test_sample_factory.py index a6340bfa..871d549a 100644 --- a/tests/test_sample_factory.py +++ b/tests/test_sample_factory.py @@ -1,11 +1,11 @@ import pytest from godot_rl.core.utils import cant_import -from examples.sample_factory_example import get_args - + @pytest.mark.skipif(cant_import("sample_factory"), reason="sample_factory is not available") def test_sample_factory_training(): from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training + from examples.sample_factory_example import get_args args, extras = get_args() args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64" extras = [] From 081d001ce8ab87e9f0756877de68f956e33a5b6e Mon Sep 17 00:00:00 2001 From: edbeeching Date: Sun, 23 Jul 2023 08:59:46 +0200 Subject: [PATCH 17/18] drops 3.7 from ubuntu tests --- .github/workflows/test-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index e0f8aa46..5bc6933b 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -10,7 +10,7 @@ jobs: tests_ubuntu: strategy: matrix: - python-version: [3.7, 3.8, 3.9, 3.10.10] + python-version: [3.8, 3.9, 3.10.10] os: ['ubuntu-latest'] runs-on: ${{ matrix.os }} steps: From e6251f0c651fa1493c286aa24a6da0ed35a45b70 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Sun, 23 Jul 2023 14:01:59 +0200 Subject: [PATCH 18/18] fixing sf tests for CI --- tests/test_sample_factory.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_sample_factory.py b/tests/test_sample_factory.py index 871d549a..7f6b686d 100644 --- a/tests/test_sample_factory.py +++ b/tests/test_sample_factory.py @@ -10,6 +10,7 @@ def test_sample_factory_training(): args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64" extras = [] extras.append('--env=gdrl') - extras.extend(['--train_for_env_steps=1000']) + extras.append('--train_for_env_steps=1000') + extras.append('--device=cpu') sample_factory_training(args, extras) \ No newline at end of file