Skip to content

Commit

Permalink
Slow down render fps to make the video nicer
Browse files Browse the repository at this point in the history
  • Loading branch information
cr-xu committed Feb 22, 2024
1 parent 9601117 commit 2f41bbd
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 55 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
cheetah-accelerator>=0.5.18
gymnasium
imageio==2.4.1
ipywidgets
jupyterlab
Expand Down
58 changes: 7 additions & 51 deletions utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from datetime import datetime

import gym
import gymnasium as gym
import imageio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from gym.wrappers import RecordVideo
from gymnasium.wrappers import RecordVideo
from IPython import display
from moviepy.editor import VideoFileClip
from stable_baselines3 import PPO
Expand All @@ -29,13 +27,14 @@ def evaluate_ares_ea_agent(run_name, n=200, include_position=False):
err_sigma_xs = []
err_sigma_ys = []
for _ in range(n):
done = False
observation = env.reset()
terminated = False
truncated = False
observation, info = env.reset()
i = 0
while not done:
while not (terminated or truncated):
i += 1
action, _ = loaded_model.predict(observation)
observation, reward, done, info = env.step(action)
observation, reward, terminated, truncated, info = env.step(action)
err_sigma_xs.append(info["err_sigma_x"])
err_sigma_ys.append(info["err_sigma_y"])
step_indicies.append(i)
Expand Down Expand Up @@ -166,46 +165,3 @@ def record_video(env):

def show_video(filename):
    """Return an inline-playable video widget for *filename* in the notebook."""
    video_widget = display.Video(filename)
    return video_widget
"""
Wrapper for recording episode data such as observations, rewards, infos and actions.
"""

def __init__(self, env):
    """Wrap *env* for episode recording.

    No episode data exists until the first ``reset()`` completes.
    """
    super().__init__(env)

    # Flips to True after the first reset(); gates the archiving of
    # previous-episode buffers in reset().
    self.has_previously_run = False

def reset(self):
    """Reset the wrapped env, archiving the just-finished episode's records.

    On every reset after the first, the buffers of the episode that just
    ended are moved to ``previous_*`` attributes before fresh buffers are
    created, so the last episode stays inspectable.

    Returns the initial observation from the wrapped environment.
    """
    if self.has_previously_run:
        # Snapshot the finished episode under previous_* attributes.
        for attr in ("observations", "rewards", "infos", "actions",
                     "t_start", "steps_taken"):
            setattr(self, "previous_" + attr, getattr(self, attr))
        self.previous_t_end = datetime.now()

    first_observation = self.env.reset()

    # Start fresh buffers for the new episode; the initial observation is
    # recorded immediately so observations always leads actions by one.
    self.observations = [first_observation]
    self.rewards = []
    self.infos = []
    self.actions = []
    self.t_start = datetime.now()
    self.t_end = None
    self.steps_taken = 0

    self.has_previously_run = True

    return first_observation

def step(self, action):
    """Step the wrapped env with *action*, recording the full transition.

    Returns the wrapped environment's ``(observation, reward, done, info)``
    unchanged.
    """
    observation, reward, done, info = self.env.step(action)

    # Append each transition element to its episode buffer in-place.
    for buffer, value in ((self.observations, observation),
                          (self.rewards, reward),
                          (self.infos, info),
                          (self.actions, action)):
        buffer.append(value)
    self.steps_taken += 1

    return observation, reward, done, info
8 changes: 4 additions & 4 deletions utils/train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Slightly modified version of ea_train.py to make it run in the workshop context.


import pathlib
import re
from functools import partial

import cheetah
Expand Down Expand Up @@ -183,7 +181,9 @@ def make_env(config, record_video=False, monitor_filename=None):
env = Monitor(env, filename=monitor_filename)
if record_video:
env = RecordVideo(
env, video_folder=f"utils/recordings/{config['wandb_run_name']}"
env,
video_folder=f"utils/recordings/{config['wandb_run_name']}",
disable_logger=True,
)
return env

Expand Down Expand Up @@ -212,7 +212,7 @@ class ARESEA(gym.Env):
Rendering mode. Choose from `"rgb_array"` or `"human"`.
"""

metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 2}
metadata = {"render.modes": ["human", "rgb_array"], "render_fps": 4}

def __init__(
self,
Expand Down

0 comments on commit 2f41bbd

Please sign in to comment.