-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds resuming and saving model to sb3 example #135
Merged
Merged
Changes from 7 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
25c9514
Adds --resume_model_path and --save_model_path
Ivan-267 10d8130
Adds timesteps and implements saving
Ivan-267 7b86692
Adds auto-checkpoint saving and inference
Ivan-267 602147b
CL args help text update
Ivan-267 7792893
Added error message when using inference without resume_model_path
Ivan-267 e6e6214
Removes a left-over print from testing
Ivan-267 537fdee
Adds infer to resume training description
Ivan-267 91c4e01
Default experiment name changed to lowercase
Ivan-267 52651e0
Adds --viz argument for changing rendering mode
Ivan-267 82bd742
Add default=False to inference
Ivan-267 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
import argparse | ||
import os | ||
import pathlib | ||
|
||
from stable_baselines3.common.callbacks import CheckpointCallback | ||
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv | ||
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx | ||
from stable_baselines3 import PPO | ||
|
@@ -21,34 +24,114 @@ | |
"--experiment_dir", | ||
default="logs/sb3", | ||
type=str, | ||
help="The name of the experiment directory, in which the tensorboard logs are getting stored", | ||
help="The name of the experiment directory, in which the tensorboard logs and checkpoints (if enabled) are getting stored." | ||
) | ||
parser.add_argument( | ||
"--experiment_name", | ||
default="Experiment", | ||
type=str, | ||
help="The name of the experiment, which will be displayed in tensorboard", | ||
help="The name of the experiment, which will be displayed in tensorboard and " | ||
"for checkpoint directory and name (if enabled).", | ||
) | ||
parser.add_argument( | ||
"--resume_model_path", | ||
default=None, | ||
type=str, | ||
help="The path to a model file previously saved using --save_model_path or a checkpoint saved using " | ||
"--save_checkpoints_frequency. Use this to resume training or infer from a saved model.", | ||
) | ||
parser.add_argument( | ||
"--save_model_path", | ||
default=None, | ||
type=str, | ||
help="The path to use for saving the trained sb3 model after training is complete. Saved model can be used later " | ||
"to resume training. Extension will be set to .zip", | ||
) | ||
parser.add_argument( | ||
"--save_checkpoint_frequency", | ||
default=None, | ||
type=int, | ||
help=("If set, will save checkpoints every 'frequency' environment steps. " | ||
"Requires a unique --experiment_name or --experiment_dir for each run. " | ||
"Does not need --save_model_path to be set. "), | ||
) | ||
parser.add_argument( | ||
"--onnx_export_path", | ||
default=None, | ||
type=str, | ||
help="The Godot binary to use, do not include for in editor training", | ||
) | ||
|
||
parser.add_argument( | ||
"--timesteps", | ||
default=1_000_000, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh nice! |
||
type=int, | ||
help="The number of environment steps to train for, default is 1_000_000. If resuming from a saved model, " | ||
"it will continue training for this amount of steps from the saved state without counting previously trained " | ||
"steps", | ||
) | ||
parser.add_argument( | ||
"--inference", | ||
Ivan-267 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
action="store_true", | ||
help="Instead of training, it will run inference on a loaded model for --timesteps steps. " | ||
"Requires --resume_model_path to be set." | ||
) | ||
parser.add_argument("--speedup", default=1, type=int, help="Whether to speed up the physics in the env") | ||
parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to launch - requires --env_path to be set if > 1.") | ||
parser.add_argument("--n_parallel", default=1, type=int, help="How many instances of the environment executable to " | ||
"launch - requires --env_path to be set if > 1.") | ||
args, extras = parser.parse_known_args() | ||
|
||
path_checkpoint = os.path.join(args.experiment_dir, args.experiment_name + "_checkpoints") | ||
abs_path_checkpoint = os.path.abspath(path_checkpoint) | ||
|
||
# Prevent overwriting existing checkpoints when starting a new experiment if checkpoint saving is enabled | ||
if args.save_checkpoint_frequency is not None and os.path.isdir(path_checkpoint): | ||
raise RuntimeError(abs_path_checkpoint + " folder already exists. " | ||
"Use a different --experiment_dir, or --experiment_name," | ||
"or if previous checkpoints are not needed anymore, " | ||
"remove the folder containing the checkpoints. ") | ||
|
||
if args.inference and args.resume_model_path is None: | ||
raise parser.error("Using --inference requires --resume_model_path to be set.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice even with proper error handling. ❤️ |
||
|
||
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=True, n_parallel=args.n_parallel, speedup=args.speedup) | ||
env = StableBaselinesGodotEnv(env_path=args.env_path, show_window=True, n_parallel=args.n_parallel, | ||
speedup=args.speedup) | ||
env = VecMonitor(env) | ||
|
||
model = PPO("MultiInputPolicy", env, ent_coef=0.0001, verbose=2, n_steps=32, tensorboard_log=args.experiment_dir) | ||
model.learn(1000000, tb_log_name=args.experiment_name) | ||
if args.resume_model_path is None: | ||
model = PPO("MultiInputPolicy", env, ent_coef=0.0001, verbose=2, n_steps=32, tensorboard_log=args.experiment_dir) | ||
else: | ||
path_zip = pathlib.Path(args.resume_model_path) | ||
print("Loading model: " + os.path.abspath(path_zip)) | ||
model = PPO.load(path_zip, env=env, tensorboard_log=args.experiment_dir) | ||
|
||
if args.inference: | ||
obs = env.reset() | ||
for i in range(args.timesteps): | ||
action, _state = model.predict(obs, deterministic=True) | ||
obs, reward, done, info = env.step(action) | ||
else: | ||
if args.save_checkpoint_frequency is None: | ||
model.learn(args.timesteps, tb_log_name=args.experiment_name) | ||
else: | ||
print("Checkpoint saving enabled. Checkpoints will be saved to: " + abs_path_checkpoint) | ||
checkpoint_callback = CheckpointCallback( | ||
save_freq=(args.save_checkpoint_frequency // env.num_envs), | ||
save_path=path_checkpoint, | ||
name_prefix=args.experiment_name | ||
) | ||
model.learn(args.timesteps, callback=checkpoint_callback, tb_log_name=args.experiment_name) | ||
|
||
print("closing env") | ||
env.close() | ||
|
||
# Enforce the extension of onnx and zip when saving model to avoid potential conflicts in case of same name | ||
# and extension used for both | ||
if args.onnx_export_path is not None: | ||
export_ppo_model_as_onnx(model, args.onnx_export_path) | ||
path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx") | ||
print("Exporting onnx to: " + os.path.abspath(path_onnx)) | ||
export_ppo_model_as_onnx(model, str(path_onnx)) | ||
|
||
if args.save_model_path is not None: | ||
path_zip = pathlib.Path(args.save_model_path).with_suffix(".zip") | ||
print("Saving model to: " + os.path.abspath(path_zip)) | ||
model.save(path_zip) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
small nit, can the default
experiment_name
be lower case, "experiment"?