diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index a3cb6c0b..1faa87ac 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -77,7 +77,7 @@ jobs: - name: Clean up dependencies run: | pip uninstall -y stable-baselines3 gymnasium - pip install .[rllib] + pip install ray[rllib] - name: Download examples run: | make download_examples @@ -105,7 +105,7 @@ jobs: - name: Clean up dependencies run: | pip uninstall -y stable-baselines3 gymnasium - pip install .[rllib] + pip install ray[rllib] - name: Download examples run: | make download_examples diff --git a/godot_rl/wrappers/onnx/stable_baselines_export.py b/godot_rl/wrappers/onnx/stable_baselines_export.py index 015a06fa..f39d0b32 100644 --- a/godot_rl/wrappers/onnx/stable_baselines_export.py +++ b/godot_rl/wrappers/onnx/stable_baselines_export.py @@ -59,7 +59,7 @@ def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10): onnx.checker.check_model(onnx_model) sb3_model = ppo.policy.to("cpu") - ort_sess = ort.InferenceSession(onnx_model_path) + ort_sess = ort.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider']) for i in range(num_tests): obs = dict(ppo.observation_space.sample()) diff --git a/godot_rl/wrappers/ray_wrapper.py b/godot_rl/wrappers/ray_wrapper.py index 01a6c195..e163fdd4 100644 --- a/godot_rl/wrappers/ray_wrapper.py +++ b/godot_rl/wrappers/ray_wrapper.py @@ -1,3 +1,4 @@ +import os import pathlib from typing import Callable, List, Optional, Tuple @@ -174,7 +175,7 @@ def rllib_training(args, extras): checkpoint_freq=checkpoint_freq, checkpoint_at_end=not args.eval, restore=args.restore, - local_dir=args.experiment_dir or "logs/rllib", + local_dir=os.path.abspath(args.experiment_dir or "logs/rllib"), trial_name_creator=lambda trial: f"{args.experiment_name}" if args.experiment_name else f"{trial.trainable_name}_{trial.trial_id}" ) if args.export: diff --git a/setup.cfg b/setup.cfg index a6dc219f..c0329fec 100644 
--- a/setup.cfg +++ b/setup.cfg @@ -16,7 +16,7 @@ install_requires = wget huggingface_hub>=0.10 gymnasium - stable-baselines3 + stable-baselines3>=2.0.0 huggingface_sb3 onnx onnxruntime @@ -48,7 +48,6 @@ sf = sample-factory rllib = - gymnasium==0.26.3 ray[rllib] cleanrl = diff --git a/tests/test_rllib.py b/tests/test_rllib.py index d6eb62f5..574d2b82 100644 --- a/tests/test_rllib.py +++ b/tests/test_rllib.py @@ -11,4 +11,4 @@ def test_rllib_training(): args.env_path = "examples/godot_rl_JumperHard/bin/JumperHard.x86_64" - rllib_training(args, extras) \ No newline at end of file + rllib_training(args, extras)