Skip to content

Commit

Permalink
Merge pull request #156 from edbeeching/fix-tests
Browse files Browse the repository at this point in the history
fixing broken tests on main
  • Loading branch information
edbeeching authored Nov 28, 2023
2 parents 523b0a7 + 871c244 commit a5363cb
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 56 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.6.2"
authors = [
{ name="Edward Beeching", email="[email protected]" },
]
dynamic = ["license", "scripts", "dependencies", "optional-dependencies"]
description = "A Deep Reinforcement Learning package for the Godot game engine"
readme = "README.md"
requires-python = ">=3.7"
Expand Down
76 changes: 22 additions & 54 deletions tests/test_godot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,11 @@
@pytest.mark.parametrize(
"env_name,port,n_agents",
[
(
"BallChase",
12008,
16,
),
(
"FPS",
12009,
8,
),
(
"JumperHard",
12010,
16,
),
(
"Racer",
12011,
8,
),
(
"FlyBy",
12012,
16,
),
("BallChase", 12008, 16),
("FPS", 12009, 8),
("JumperHard", 12010, 16),
("Racer", 12011, 8),
("FlyBy", 12012, 16),
],
)
def test_env_ij(env_name, port, n_agents):
Expand All @@ -57,40 +37,24 @@ def test_env_ij(env_name, port, n_agents):
assert isinstance(
reward[0], (float, int)
), f"The reward returned by 'step()' must be a float or int, and is {reward[0]} of type {type(reward[0])}"
assert isinstance(term[0], bool), f"The 'done' signal {term[0]} {type(term[0])} must be a boolean"
assert isinstance(info[0], dict), "The 'info' returned by 'step()' must be a python dictionary"
assert isinstance(
term[0], bool
), f"The 'done' signal {term[0]} {type(term[0])} must be a boolean"
assert isinstance(
info[0], dict
), "The 'info' returned by 'step()' must be a python dictionary"

env.close()


@pytest.mark.parametrize(
"env_name,port,n_agents",
[
(
"BallChase",
13008,
16,
),
(
"FPS",
13009,
8,
),
(
"JumperHard",
13010,
16,
),
(
"Racer",
13011,
8,
),
(
"FlyBy",
13012,
16,
),
("BallChase", 13008, 16),
("FPS", 13009, 8),
("JumperHard", 13010, 16),
("Racer", 13011, 8),
("FlyBy", 13012, 16),
],
)
def test_env_ji(env_name, port, n_agents):
Expand Down Expand Up @@ -118,7 +82,11 @@ def test_env_ji(env_name, port, n_agents):
assert isinstance(
reward[0], (float, int)
), f"The reward returned by 'step()' must be a float or int, and is {reward[0]} of type {type(reward[0])}"
assert isinstance(term[0], bool), f"The 'done' signal {term[0]} {type(term[0])} must be a boolean"
assert isinstance(info[0], dict), "The 'info' returned by 'step()' must be a python dictionary"
assert isinstance(
term[0], bool
), f"The 'done' signal {term[0]} {type(term[0])} must be a boolean"
assert isinstance(
info[0], dict
), "The 'info' returned by 'step()' must be a python dictionary"

env.close()
8 changes: 6 additions & 2 deletions tests/test_sb3_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from godot_rl.main import get_args
from godot_rl.core.utils import can_import


@pytest.mark.skipif(can_import("ray"), reason="rllib and sb3 are not compatable")
@pytest.mark.parametrize(
"env_name,port",
Expand All @@ -14,13 +15,16 @@
("FlyBy", 12400),
],
)
@pytest.mark.parametrize("n_parallel",[1,2,4])
@pytest.mark.parametrize("n_parallel", [1, 2, 4])
def test_sb3_training(env_name, port, n_parallel):
from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training

args, extras = get_args()
args.env = "gdrl"
args.env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64"
args.experiment_name = f"test_{env_name}_{n_parallel}"
starting_port = port + n_parallel

stable_baselines_training(args, extras, n_steps=10, port=starting_port, n_parallel=n_parallel)
stable_baselines_training(
args, extras, n_steps=2, port=starting_port, n_parallel=n_parallel
)

0 comments on commit a5363cb

Please sign in to comment.