Skip to content

Commit

Permalink
more code quality fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
edbeeching committed Jan 23, 2024
1 parent 39f677a commit ed8407f
Show file tree
Hide file tree
Showing 20 changed files with 102 additions and 107 deletions.
12 changes: 5 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
.PHONY: quality style test unity-test

check_dirs := tests godot_rl

# Format source code automatically
style:
black --line-length 120 --target-version py310 tests godot_rl
isort -w 120 tests godot_rl
black --line-length 120 --target-version py310 tests godot_rl examples
isort -w 120 tests godot_rl examples
# Check that source code meets quality standards
quality:
black --check --line-length 120 --target-version py310 tests godot_rl
isort -w 120 --check-only tests godot_rl
flake8 --max-line-length 120 tests godot_rl
black --check --line-length 120 --target-version py310 tests godot_rl examples
isort -w 120 --check-only tests godot_rl examples
flake8 --max-line-length 120 tests godot_rl examples

# Run tests for the library
test:
Expand Down
4 changes: 3 additions & 1 deletion examples/clean_rl_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
import pathlib
import random
import time
from distutils.util import strtobool
from collections import deque
from distutils.util import strtobool

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter

from godot_rl.wrappers.clean_rl_wrapper import CleanRLGodotEnv


Expand Down
3 changes: 2 additions & 1 deletion examples/sample_factory_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
from godot_rl.wrappers.sample_factory_wrapper import sample_factory_training, sample_factory_enjoy

from godot_rl.wrappers.sample_factory_wrapper import sample_factory_enjoy, sample_factory_training


def get_args():
Expand Down
10 changes: 6 additions & 4 deletions examples/stable_baselines3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import pathlib
from typing import Callable

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor

from godot_rl.core.utils import can_import
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

# To download the env source and binary:
# 1. gdrl.env_from_hub -r edbeeching/godot_rl_BallChase
Expand Down Expand Up @@ -214,7 +215,8 @@ def func(progress_remaining: float) -> float:
model.learn(**learn_arguments)
except KeyboardInterrupt:
print(
"Training interrupted by user. Will save if --save_model_path was used and/or export if --onnx_export_path was used."
"""Training interrupted by user. Will save if --save_model_path was
used and/or export if --onnx_export_path was used."""
)

close_env()
Expand Down
20 changes: 8 additions & 12 deletions examples/stable_baselines3_hp_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,30 @@
You can run this example as follows:
$ python examples/stable_baselines3_hp_tuning.py --env_path=<path/to/your/env> --speedup=8 --n_parallel=1
Feel free to copy this script and update, add or remove the hp values to your liking.
Feel free to copy this script and update, add or remove the hp values to your liking.
"""

try:
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
except ImportError as e:
print(e)
print("You need to install optuna to use the hyperparameter tuning script. Try: pip install optuna")
exit()

from typing import Any
from typing import Dict
import argparse
from typing import Any, Dict

import gymnasium as gym

from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
from godot_rl.core.godot_env import GodotEnv

import torch
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor

import torch
import torch.nn as nn

import argparse
from godot_rl.core.godot_env import GodotEnv
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
Expand Down
12 changes: 6 additions & 6 deletions godot_rl/core/godot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,17 @@ def check_platform(self, filename: str):
# Linux
assert (
pathlib.Path(filename).suffix == ".x86_64"
), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .x86_64 file"
), f"Incorrect file suffix for {filename=} {pathlib.Path(filename).suffix=}. Please provide a .x86_64 file"
elif platform == "darwin":
# OSX
assert (
pathlib.Path(filename).suffix == ".app"
), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .app file"
), f"Incorrect file suffix for {filename=} {pathlib.Path(filename).suffix=}. Please provide a .app file"
elif platform == "win32":
# Windows...
assert (
pathlib.Path(filename).suffix == ".exe"
), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .exe file"
), f"Incorrect file suffix for {filename=} {pathlib.Path(filename).suffix=}. Please provide a .exe file"
else:
assert 0, f"unknown filetype {pathlib.Path(filename).suffix}"

Expand All @@ -132,7 +132,7 @@ def from_numpy(self, action, order_ij=False):
env_action = {}

for j, k in enumerate(self._action_space.keys()):
if order_ij == True:
if order_ij is True:
v = action[i][j]
else:
v = action[j][i]
Expand Down Expand Up @@ -263,7 +263,7 @@ def _launch_env(self, env_path, port, show_window, framerate, seed, action_repea

launch_cmd = f"{path} --port={port} --env_seed={seed}"

if show_window == False:
if show_window is False:
launch_cmd += " --disable-render-loop --headless"
if framerate is not None:
launch_cmd += f" --fixed-fps {framerate}"
Expand Down Expand Up @@ -382,7 +382,7 @@ def _clear_socket(self):
data = self.connection.recv(4)
if not data:
break
except BlockingIOError as e:
except BlockingIOError:
pass
self.connection.setblocking(True)

Expand Down
2 changes: 1 addition & 1 deletion godot_rl/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def convert_macos_path(env_path):
"""

filenames = re.findall(r"[^\/]+(?=\.)", env_path)
assert len(filenames) == 1, f"An error occured while converting the env path for MacOS."
assert len(filenames) == 1, "An error occured while converting the env path for MacOS."
return env_path + "/Contents/MacOS/" + filenames[0]


Expand Down
15 changes: 7 additions & 8 deletions godot_rl/download_utils/download_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,35 @@

import os
import shutil
from sys import platform
from zipfile import ZipFile

import wget

BANCHES = {"4": "main", "3": "godot3.5"}
BRANCHES = {"4": "main", "3": "godot3.5"}

BASE_URL = "https://github.com/edbeeching/godot_rl_agents_examples"


def download_examples():
# select branch
print("Select Godot version:")
for key in BANCHES.keys():
print(f"{key} : {BANCHES[key]}")
for key in BRANCHES.keys():
print(f"{key} : {BRANCHES[key]}")

branch = input("Enter your choice: ")
BRANCH = BANCHES[branch]
BRANCH = BRANCHES[branch]
os.makedirs("examples", exist_ok=True)
URL = f"{BASE_URL}/archive/refs/heads/{BRANCH}.zip"
print(f"downloading examples from {URL}")
wget.download(URL, out="")
print()
print(f"unzipping")
print("unzipping")
with ZipFile(f"{BRANCH}.zip", "r") as zipObj:
# Extract all the contents of zip file in different directory
zipObj.extractall("examples/")
print(f"cleaning up")
print("cleaning up")
os.remove(f"{BRANCH}.zip")
print(f"moving files")
print("moving files")
for file in os.listdir(f"examples/godot_rl_agents_examples-{BRANCH}"):
shutil.move(f"examples/godot_rl_agents_examples-{BRANCH}/{file}", "examples")
os.rmdir(f"examples/godot_rl_agents_examples-{BRANCH}")
5 changes: 2 additions & 3 deletions godot_rl/download_utils/download_godot_editor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import shutil
from sys import platform
from zipfile import ZipFile

Expand Down Expand Up @@ -50,9 +49,9 @@ def download_editor():
print(f"downloading editor {FILENAME} for platform: {platform}")
wget.download(URL, out="")
print()
print(f"unzipping")
print("unzipping")
with ZipFile(FILENAME, "r") as zipObj:
# Extract all the contents of zip file in different directory
zipObj.extractall("editor/")
print(f"cleaning up")
print("cleaning up")
os.remove(FILENAME)
2 changes: 1 addition & 1 deletion godot_rl/download_utils/from_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main():
parser.add_argument(
"-r",
"--hf_repository",
help="Repo id of the dataset / environment repository from the Hugging Face Hub in the form user_name/repo_name",
help="Repo id of the dataset / environment repo from the Hugging Face Hub in the form user_name/repo_name",
type=str,
)
parser.add_argument(
Expand Down
27 changes: 12 additions & 15 deletions godot_rl/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""
This is the main entrypoint to the Godot RL Agents interface
Example usage is best found in the documentation:
Example usage is best found in the documentation:
https://github.com/edbeeching/godot_rl_agents/blob/main/docs/EXAMPLE_ENVIRONMENTS.md
Hyperparameters and training algorithm can be defined in a .yaml file, see ppo_test.yaml as an example.
Interactive Training:
With the Godot editor open, type gdrl in the terminal to launch training and
With the Godot editor open, type gdrl in the terminal to launch training and
then press PLAY in the Godot editor. Training can be stopped with CTRL+C or
by pressing STOP in the editor.
Expand All @@ -25,34 +25,33 @@
try:
from godot_rl.wrappers.ray_wrapper import rllib_training
except ImportError as e:
error_message = str(e)

def rllib_training(args, extras):
print(
"Import error when trying to use rllib. If you have not installed the package, try: pip install godot-rl[rllib]"
)
print("Otherwise try fixing the error above.")
print("Import error importing rllib. If you have not installed the package, try: pip install godot-rl[rllib]")
print("Otherwise try fixing the error.", error_message)


try:
from godot_rl.wrappers.stable_baselines_wrapper import stable_baselines_training
except ImportError as e:
error_message = str(e)

def stable_baselines_training(args, extras):
print(
"Import error when trying to use sb3. If you have not installed the package, try: pip install godot-rl[sb3]"
)
print("Otherwise try fixing the error above.")
print("Import error importing sb3. If you have not installed the package, try: pip install godot-rl[sb3]")
print("Otherwise try fixing the error.", error_message)


try:
from godot_rl.wrappers.sample_factory_wrapper import sample_factory_enjoy, sample_factory_training
except ImportError as e:
error_message = str(e)

def sample_factory_training(args, extras):
print(
"Import error when trying to use sample-factory If you have not installed the package, try: pip install godot-rl[sf]"
"Import error importing sample-factory If you have not installed the package, try: pip install godot-rl[sf]"
)
print("Otherwise try fixing the error above.")
print("Otherwise try fixing the error.", error_message)


def get_args():
Expand Down Expand Up @@ -89,9 +88,7 @@ def get_args():
args.experiment_dir = f"logs/{args.trainer}"

if args.trainer == "sf" and args.env_path is None:
print(
"WARNING: the sample-factory intergration is not designed to run in interactive mode, please export you game to use this trainer"
)
print("WARNING: the sample-factory intergration is not designed to run in interactive mode, export you game")

return args, extras

Expand Down
2 changes: 1 addition & 1 deletion godot_rl/wrappers/clean_rl_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional

import gymnasium as gym
import numpy as np
Expand Down
Loading

0 comments on commit ed8407f

Please sign in to comment.