Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds resuming and saving model to sb3 example #135

Merged
merged 10 commits into from
Jul 23, 2023
99 changes: 91 additions & 8 deletions examples/stable_baselines3_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import argparse
import os
import pathlib

from stable_baselines3.common.callbacks import CheckpointCallback
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from stable_baselines3 import PPO
Expand All @@ -21,34 +24,114 @@
"--experiment_dir",
default="logs/sb3",
type=str,
help="The name of the experiment directory, in which the tensorboard logs are getting stored",
help="The name of the experiment directory, in which the tensorboard logs and checkpoints (if enabled) are getting stored."
Copy link
Owner

@edbeeching edbeeching Jul 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small nit, can the default experiment_name be lower case, "experiment"?

)
parser.add_argument(
"--experiment_name",
default="Experiment",
type=str,
help="The name of the experiment, which will be displayed in tensorboard",
help="The name of the experiment, which will be displayed in tensorboard and "
"for checkpoint directory and name (if enabled).",
)
parser.add_argument(
"--resume_model_path",
default=None,
type=str,
help="The path to a model file previously saved using --save_model_path or a checkpoint saved using "
"--save_checkpoints_frequency. Use this to resume training from a saved model.",
Ivan-267 marked this conversation as resolved.
Show resolved Hide resolved
)
parser.add_argument(
"--save_model_path",
default=None,
type=str,
help="The path to use for saving the trained sb3 model after training is complete. Saved model can be used later "
"to resume training. Extension will be set to .zip",
)
parser.add_argument(
"--save_checkpoint_frequency",
default=None,
type=int,
help=("If set, will save checkpoints every 'frequency' environment steps. "
"Requires a unique --experiment_name or --experiment_dir for each run. "
"Does not need --save_model_path to be set. "),
)
parser.add_argument(
"--onnx_export_path",
default=None,
type=str,
help="The Godot binary to use, do not include for in editor training",
)

parser.add_argument(
"--timesteps",
default=1_000_000,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh nice!

type=int,
help="The number of environment steps to train for, default is 1_000_000. If resuming from a saved model, "
"it will continue training for this amount of steps from the saved state without counting previously trained "
"steps",
)
parser.add_argument(
"--inference",
Ivan-267 marked this conversation as resolved.
Show resolved Hide resolved
action="store_true",
help="Instead of training, it will run inference on a loaded model for --timesteps steps. "
"Requires --resume_model_path to be set."
)
parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env")
parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to launch - requires --env_path to be set if > 1.")
parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to "
"launch - requires --env_path to be set if > 1.")
args, extras = parser.parse_known_args()

path_checkpoint = os.path.join(args.experiment_dir, args.experiment_name + "_checkpoints")
abs_path_checkpoint = os.path.abspath(path_checkpoint)

# Prevent overwriting existing checkpoints when starting a new experiment if checkpoint saving is enabled
if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
raise RuntimeError(abs_path_checkpoint + " folder already exists. "
"Use a different --experiment_dir, or --experiment_name,"
"or if previous checkpoints are not needed anymore, "
"remove the folder containing the checkpoints. ")

if args.inference and args.resume_model_path is None:
raise parser.error("Using --inference requires --resume_model_path to be set.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice even with proper error handling. ❤️


env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=True, n_parallel=args.n_parallel, speedup=args.speedup)
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=True, n_parallel=args.n_parallel,
speedup=args.speedup)
env = VecMonitor(env)

model = PPO("MultiInputPolicy", env, ent_coef=0.0001, verbose=2, n_steps=32, tensorboard_log=args.experiment_dir)
model.learn(1000000, tb_log_name=args.experiment_name)
if args.resume_model_path is None:
model = PPO("MultiInputPolicy", env, ent_coef=0.0001, verbose=2, n_steps=32, tensorboard_log=args.experiment_dir)
else:
path_zip = pathlib.Path(args.resume_model_path)
print("Loading model: " + os.path.abspath(path_zip))
model = PPO.load(path_zip, env=env, tensorboard_log=args.experiment_dir)

if args.inference:
obs = env.reset()
for i in range(args.timesteps):
action, _state = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
else:
if args.save_checkpoint_frequency is None:
model.learn(args.timesteps, tb_log_name=args.experiment_name)
else:
print("Checkpoint saving enabled. Checkpoints will be saved to: " + abs_path_checkpoint)
checkpoint_callback = CheckpointCallback(
save_freq=(args.save_checkpoint_frequency // env.num_envs),
save_path=path_checkpoint,
name_prefix=args.experiment_name
)
model.learn(args.timesteps, callback=checkpoint_callback, tb_log_name=args.experiment_name)

print("closing env")
env.close()

# Enforce the extension of onnx and zip when saving model to avoid potential conflicts in case of same name
# and extension used for both
if args.onnx_export_path is not None:
export_ppo_model_as_onnx(model, args.onnx_export_path)
path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
print("Exporting onnx to: " + os.path.abspath(path_onnx))
export_ppo_model_as_onnx(model, str(path_onnx))

if args.save_model_path is not None:
path_zip = pathlib.Path(args.save_model_path).with_suffix(".zip")
print("Saving model to: " + os.path.abspath(path_zip))
model.save(path_zip)