Skip to content

Commit

Permalink
Merge pull request #135 from edbeeching/sb3_example_add_save_and_resume
Browse files Browse the repository at this point in the history
Adds resuming and saving model, inference and viz CL arguments to sb3 example
  • Loading branch information
Ivan-267 authored Jul 23, 2023
2 parents f90378a + 82bd742 commit 2c348c8
Showing 1 changed file with 103 additions and 8 deletions.
111 changes: 103 additions & 8 deletions examples/stable_baselines3_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import argparse
import os
import pathlib

from stable_baselines3.common.callbacks import CheckpointCallback
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from stable_baselines3 import PPO
Expand All @@ -21,34 +24,126 @@
"--experiment_dir",
default="logs/sb3",
type=str,
help="The name of the experiment directory, in which the tensorboard logs are getting stored",
help="The name of the experiment directory, in which the tensorboard logs and checkpoints (if enabled) are "
"getting stored."
)
# Name of this run. Besides labelling the tensorboard log, it is also used to
# build the checkpoint directory name and the checkpoint file prefix.
# Only the updated help text is kept: passing `help=` twice in one call is a
# SyntaxError (repeated keyword argument).
parser.add_argument(
    "--experiment_name",
    default="experiment",
    type=str,
    help="The name of the experiment, which will be displayed in tensorboard and "
         "for checkpoint directory and name (if enabled).",
)
# --- Saving / resuming -------------------------------------------------------
parser.add_argument(
    "--resume_model_path",
    default=None,
    type=str,
    # The help text must name the real flag (--save_checkpoint_frequency).
    help="The path to a model file previously saved using --save_model_path or a checkpoint saved using "
         "--save_checkpoint_frequency. Use this to resume training or infer from a saved model.",
)
parser.add_argument(
    "--save_model_path",
    default=None,
    type=str,
    help="The path to use for saving the trained sb3 model after training is complete. Saved model can be used later "
         "to resume training. Extension will be set to .zip",
)
parser.add_argument(
    "--save_checkpoint_frequency",
    default=None,
    type=int,
    help=("If set, will save checkpoints every 'frequency' environment steps. "
          "Requires a unique --experiment_name or --experiment_dir for each run. "
          "Does not need --save_model_path to be set. "),
)
parser.add_argument(
    "--onnx_export_path",
    default=None,
    type=str,
    # Describes the onnx export (previously a copy of the Godot binary help text).
    help="If set, exports the model as an onnx file to this path once the run is complete. "
         "Extension will be set to .onnx",
)

# --- Run mode ----------------------------------------------------------------
parser.add_argument(
    "--timesteps",
    default=1_000_000,
    type=int,
    help="The number of environment steps to train for, default is 1_000_000. If resuming from a saved model, "
         "it will continue training for this amount of steps from the saved state without counting previously trained "
         "steps",
)
parser.add_argument(
    "--inference",
    default=False,
    action="store_true",
    help="Instead of training, it will run inference on a loaded model for --timesteps steps. "
         "Requires --resume_model_path to be set."
)
parser.add_argument(
    "--viz",
    action="store_true",
    help="If set, the window(s) with the Godot environment(s) will be displayed, otherwise "
         "training will run without rendering the game. Does not apply to in-editor training.",
    default=False
)
parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env")
# --n_parallel is registered exactly once: adding the same option string twice
# makes argparse raise a "conflicting option string" ArgumentError.
parser.add_argument(
    "--n_parallel",
    default=1,
    type=int,
    help="How many instances of the environment executable to "
         "launch - requires --env_path to be set if > 1.",
)
# parse_known_args() does not fail on unrecognised options; they are returned in `extras`.
args, extras = parser.parse_known_args()

# Checkpoints of a run are stored in <experiment_dir>/<experiment_name>_checkpoints.
path_checkpoint = os.path.join(args.experiment_dir, args.experiment_name + "_checkpoints")
abs_path_checkpoint = os.path.abspath(path_checkpoint)

# Prevent overwriting existing checkpoints when starting a new experiment if checkpoint saving is enabled
if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint):
    raise RuntimeError(
        abs_path_checkpoint + " folder already exists. "
        "Use a different --experiment_dir, or --experiment_name, "
        "or if previous checkpoints are not needed anymore, "
        "remove the folder containing the checkpoints. "
    )

if args.inference and args.resume_model_path is None:
    # parser.error() prints the usage message and exits the process itself,
    # so there is nothing to `raise` here.
    parser.error("Using --inference requires --resume_model_path to be set.")

if args.env_path is None and args.viz:
    print("Info: Using --viz without --env_path set has no effect, in-editor training will always render.")

# The environment is created once, after all argument validation has passed,
# and window visibility follows --viz instead of being always on.
env = StableBaselinesGodotEnv(
    env_path=args.env_path,
    show_window=args.viz,
    n_parallel=args.n_parallel,
    speedup=args.speedup,
)
env = VecMonitor(env)

# Start from a fresh PPO model, or restore one saved earlier via
# --save_model_path / checkpointing when --resume_model_path is given.
if args.resume_model_path is None:
    model = PPO("MultiInputPolicy", env, ent_coef=0.0001, verbose=2, n_steps=32, tensorboard_log=args.experiment_dir)
else:
    path_zip = pathlib.Path(args.resume_model_path)
    print("Loading model: " + os.path.abspath(path_zip))
    model = PPO.load(path_zip, env=env, tensorboard_log=args.experiment_dir)

if args.inference:
    # Inference only: step the environment with the loaded policy, no learning.
    obs = env.reset()
    for _ in range(args.timesteps):
        action, _state = model.predict(obs, deterministic=True)
        obs, _reward, _done, _info = env.step(action)
else:
    checkpoint_callback = None  # learn() treats None as "no callback"
    if args.save_checkpoint_frequency is not None:
        print("Checkpoint saving enabled. Checkpoints will be saved to: " + abs_path_checkpoint)
        checkpoint_callback = CheckpointCallback(
            # The requested frequency is in environment steps, so it is divided
            # by the number of parallel envs. It is clamped to at least 1: a
            # frequency smaller than num_envs would otherwise yield 0 and break
            # the callback's modulo check.
            save_freq=max(args.save_checkpoint_frequency // env.num_envs, 1),
            save_path=path_checkpoint,
            name_prefix=args.experiment_name
        )
    model.learn(args.timesteps, callback=checkpoint_callback, tb_log_name=args.experiment_name)

print("closing env")
env.close()

# Enforce the extension of onnx and zip when saving model to avoid potential conflicts in case of same name
# and extension used for both
if args.onnx_export_path is not None:
    # Export only once, to the path with the enforced .onnx suffix; exporting
    # to the raw user path as well would defeat the conflict protection above.
    path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
    print("Exporting onnx to: " + os.path.abspath(path_onnx))
    export_ppo_model_as_onnx(model, str(path_onnx))

if args.save_model_path is not None:
    path_zip = pathlib.Path(args.save_model_path).with_suffix(".zip")
    print("Saving model to: " + os.path.abspath(path_zip))
    model.save(path_zip)

0 comments on commit 2c348c8

Please sign in to comment.