Update package versions and notebook to latest versions
cr-xu committed Feb 22, 2024
1 parent c6e188b commit 9601117
Showing 6 changed files with 173 additions and 123 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -9,5 +9,5 @@ opencv-python
seaborn
pyyaml
RISE
stable-baselines3==1.6.0
stable-baselines3

122 changes: 63 additions & 59 deletions tutorial.ipynb

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions utils/ares_ea_lattice.json
@@ -0,0 +1,24 @@
{
"version": "cheetah-0.6",
"title": "unnamed",
"info": "This is a placeholder lattice description",
"root": "unnamed",
"elements": {
"AREASOLA1": ["Marker", {}],
"Drift_AREASOLA1": ["Drift", {"length": 0.17504000663757324}],
"AREAMQZM1": ["Quadrupole", {"length": 0.12200000137090683, "k1": 0.0, "misalignment": [0.0, 0.0], "tilt": 0.0}],
"Drift_AREAMQZM1": ["Drift", {"length": 0.42800000309944153}],
"AREAMQZM2": ["Quadrupole", {"length": 0.12200000137090683, "k1": 0.0, "misalignment": [0.0, 0.0], "tilt": 0.0}],
"Drift_AREAMQZM2": ["Drift", {"length": 0.20399999618530273}],
"AREAMCVM1": ["VerticalCorrector", {"length": 0.019999999552965164, "angle": 0.0}],
"Drift_AREAMCVM1": ["Drift", {"length": 0.20399999618530273}],
"AREAMQZM3": ["Quadrupole", {"length": 0.12200000137090683, "k1": 0.0, "misalignment": [0.0, 0.0], "tilt": 0.0}],
"Drift_AREAMQZM3": ["Drift", {"length": 0.17900000512599945}],
"AREAMCHM1": ["HorizontalCorrector", {"length": 0.019999999552965164, "angle": 0.0}],
"Drift_AREAMCHM1": ["Drift", {"length": 0.44999998807907104}],
"AREABSCR1": ["Screen", {"resolution": [2448, 2040], "pixel_size": [3.3198000437550945e-06, 2.4468999981763773e-06], "binning": 1, "misalignment": [0.0, 0.0], "is_active": true}]
},
"lattices": {
"unnamed": ["AREASOLA1", "Drift_AREASOLA1", "AREAMQZM1", "Drift_AREAMQZM1", "AREAMQZM2", "Drift_AREAMQZM2", "AREAMCVM1", "Drift_AREAMCVM1", "AREAMQZM3", "Drift_AREAMQZM3", "AREAMCHM1", "Drift_AREAMCHM1", "AREABSCR1"]
}
}
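
For reference, this LatticeJSON file replaces the pickled segment in utils/lattice.pkl removed below; the updated utils/train.py loads it with cheetah.Segment.from_lattice_json. A minimal sketch of using the file, assuming cheetah 0.6 is installed (the element names come from the JSON above; the beam parameters are illustrative only):

import cheetah
import torch

# Build the segment from the LatticeJSON description added in this commit.
segment = cheetah.Segment.from_lattice_json("utils/ares_ea_lattice.json")

# Elements are attributes named as in the JSON, with tensor-valued parameters.
segment.AREAMQZM1.k1 = torch.tensor(10.0, dtype=torch.float32)

# Track an example beam and read out the active screen at the end of the line.
incoming = cheetah.ParameterBeam.from_parameters(energy=torch.tensor(80e6))
segment.track(incoming)
print(segment.AREABSCR1.get_read_beam())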
Binary file removed utils/lattice.pkl
115 changes: 68 additions & 47 deletions utils/train.py
@@ -2,16 +2,17 @@


import pathlib
import pickle
import re
from functools import partial

import cheetah
import cv2
import gym
import gymnasium as gym
import numpy as np
import torch
import yaml
from gym import spaces
from gym.wrappers import (
from gymnasium import spaces
from gymnasium.wrappers import (
FilterObservation,
FlattenObservation,
FrameStack,
@@ -20,6 +21,7 @@
TimeLimit,
)
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize

@@ -132,10 +134,11 @@ def train(config):
verbose=0,
)

eval_callback = EvalCallback(eval_env=eval_env, eval_freq=500, n_eval_episodes=5)

model.learn(
total_timesteps=config["total_timesteps"],
eval_env=eval_env,
eval_freq=500,
callback=[eval_callback],
)

model.save(f"utils/models/{run_name}/model")
@@ -205,9 +208,11 @@ class ARESEA(gym.Env):
target_beam_mode : str
Setting of target beam on `reset`. Choose from `"constant"` or `"random"`. The
`"constant"` setting requires `target_beam_values` to be set.
render_mode : str
Rendering mode. Choose from `"rgb_array"` or `"human"`.
"""

metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 2}
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 2}

def __init__(
self,
@@ -225,6 +230,7 @@ def __init__(
target_sigma_y_threshold=2.4469e-6,
threshold_hold=1,
time_reward=-0.0,
render_mode="rgb_array",
):
self.abort_if_off_screen = abort_if_off_screen
self.action_mode = action_mode
@@ -240,6 +246,7 @@ def __init__(
self.target_sigma_y_threshold = target_sigma_y_threshold
self.threshold_hold = threshold_hold
self.time_reward = time_reward
self.render_mode = render_mode

# Create action space
if self.action_mode == "direct":
@@ -268,11 +275,15 @@ def __init__(
low=np.array([-np.inf, 0, -np.inf, 0], dtype=np.float32),
high=np.array([np.inf, np.inf, np.inf, np.inf], dtype=np.float32),
),
"magnets": self.action_space
if self.action_mode.startswith("direct")
else spaces.Box(
low=np.array([-72, -72, -6.1782e-3, -72, -6.1782e-3], dtype=np.float32),
high=np.array([72, 72, 6.1782e-3, 72, 6.1782e-3], dtype=np.float32),
"magnets": (
self.action_space
if self.action_mode.startswith("direct")
else spaces.Box(
low=np.array(
[-72, -72, -6.1782e-3, -72, -6.1782e-3], dtype=np.float32
),
high=np.array([72, 72, 6.1782e-3, 72, 6.1782e-3], dtype=np.float32),
)
),
"target": spaces.Box(
low=np.array([-np.inf, 0, -np.inf, 0], dtype=np.float32),
@@ -285,7 +296,7 @@ def __init__(
# Setup the accelerator (either simulation or the actual machine)
self.setup_accelerator()

def reset(self):
def reset(self, **kwargs):
self.reset_accelerator()

if self.magnet_init_mode == "constant":
@@ -323,7 +334,7 @@ def reset(self):
}
observation.update(self.get_accelerator_observation())

return observation
return observation, {}

def step(self, action):
# Perform action
@@ -417,9 +428,10 @@ def step(self, action):

self.previous_beam = current_beam

return observation, reward, done, info
return observation, reward, done, False, info

def render(self, mode="human"):
def render(self):
mode = self.render_mode
assert mode == "rgb_array" or mode == "human"

binning = self.get_binning()
@@ -720,14 +732,16 @@ class ARESEACheetah(ARESEA):
def __init__(
self,
incoming_mode="constant",
incoming_values=np.array([80e6,-5e-4,6e-5,3e-4,3e-5,4e-4,4e-5,1e-4,4e-5,0,4e-6]),
incoming_values=np.array(
[80e6, -5e-4, 6e-5, 3e-4, 3e-5, 4e-4, 4e-5, 1e-4, 4e-5, 0, 4e-6]
),
misalignment_mode="constant",
misalignment_values=np.zeros(8),
abort_if_off_screen=False,
action_mode="delta",
include_screen_image_in_info=False,
magnet_init_mode="constant",
magnet_init_values=[10,-10,0,10,0],
magnet_init_values=[10, -10, 0, 10, 0],
reward_mode="negative_objective",
target_beam_mode="random",
target_beam_values=None,
@@ -761,17 +775,23 @@ def __init__(
self.misalignment_values = misalignment_values

# Create particle simulation
with open("utils/lattice.pkl", "rb") as f:
self.simulation = pickle.load(f)
# with open("utils/lattice.pkl", "rb") as f:
# self.simulation = pickle.load(f)
self.simulation = cheetah.Segment.from_lattice_json(
"utils/ares_ea_lattice.json"
)

def is_beam_on_screen(self):
screen = self.simulation.AREABSCR1
beam_position = np.array([screen.read_beam.mu_x, screen.read_beam.mu_y])
out_beam = screen.get_read_beam()
beam_position = np.array([out_beam.mu_x, out_beam.mu_y])
limits = np.array(screen.resolution) / 2 * np.array(screen.pixel_size)
extended_limits = (
limits + np.array([screen.read_beam.sigma_x, screen.read_beam.sigma_y]) * 2
)
return np.all(np.abs(beam_position) < extended_limits)
return np.all(np.abs(beam_position) < limits)
# limits = np.array(screen.resolution) / 2 * np.array(screen.pixel_size)
# extended_limits = (
# limits + np.array([screen.read_beam.sigma_x, screen.read_beam.sigma_y]) * 2
# )
# return np.all(np.abs(beam_position) < extended_limits)

def get_magnets(self):
return np.array(
@@ -785,11 +805,11 @@ def get_magnets(self):
)

def set_magnets(self, magnets):
self.simulation.AREAMQZM1.k1 = magnets[0]
self.simulation.AREAMQZM2.k1 = magnets[1]
self.simulation.AREAMCVM1.angle = magnets[2]
self.simulation.AREAMQZM3.k1 = magnets[3]
self.simulation.AREAMCHM1.angle = magnets[4]
self.simulation.AREAMQZM1.k1 = torch.tensor(magnets[0], dtype=torch.float32)
self.simulation.AREAMQZM2.k1 = torch.tensor(magnets[1], dtype=torch.float32)
self.simulation.AREAMCVM1.angle = torch.tensor(magnets[2], dtype=torch.float32)
self.simulation.AREAMQZM3.k1 = torch.tensor(magnets[3], dtype=torch.float32)
self.simulation.AREAMCHM1.angle = torch.tensor(magnets[4], dtype=torch.float32)
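
For context, element parameters such as k1 and angle are stored as torch tensors in cheetah 0.6, so the plain floats used previously are now wrapped explicitly, as above. A small illustrative sketch of the same pattern (the helper function is not part of the repository):

import numpy as np
import torch

def as_f32(value: float) -> torch.Tensor:
    # Wrap a plain Python/NumPy float as a float32 tensor for cheetah elements.
    return torch.tensor(value, dtype=torch.float32)

magnets = np.array([10.0, -10.0, 0.0, 10.0, 0.0])
k1 = as_f32(magnets[0])  # e.g. simulation.AREAMQZM1.k1 = as_f32(magnets[0])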

def reset_accelerator(self):
# New domain randomisation
@@ -800,17 +820,17 @@ def reset_accelerator(self):
else:
raise ValueError(f'Invalid value "{self.incoming_mode}" for incoming_mode')
self.incoming = cheetah.ParameterBeam.from_parameters(
energy=incoming_parameters[0],
mu_x=incoming_parameters[1],
mu_xp=incoming_parameters[2],
mu_y=incoming_parameters[3],
mu_yp=incoming_parameters[4],
sigma_x=incoming_parameters[5],
sigma_xp=incoming_parameters[6],
sigma_y=incoming_parameters[7],
sigma_yp=incoming_parameters[8],
sigma_s=incoming_parameters[9],
sigma_p=incoming_parameters[10],
energy=torch.tensor(incoming_parameters[0]),
mu_x=torch.tensor(incoming_parameters[1]),
mu_xp=torch.tensor(incoming_parameters[2]),
mu_y=torch.tensor(incoming_parameters[3]),
mu_yp=torch.tensor(incoming_parameters[4]),
sigma_x=torch.tensor(incoming_parameters[5]),
sigma_xp=torch.tensor(incoming_parameters[6]),
sigma_y=torch.tensor(incoming_parameters[7]),
sigma_yp=torch.tensor(incoming_parameters[8]),
sigma_s=torch.tensor(incoming_parameters[9]),
sigma_p=torch.tensor(incoming_parameters[10]),
)

if self.misalignment_mode == "constant":
@@ -827,15 +847,16 @@ def reset_accelerator(self):
self.simulation.AREABSCR1.misalignment = misalignments[6:8]

def update_accelerator(self):
self.simulation(self.incoming)
self.simulation.track(self.incoming)

def get_beam_parameters(self):
out_beam = self.simulation.AREABSCR1.get_read_beam()
return np.array(
[
self.simulation.AREABSCR1.read_beam.mu_x,
self.simulation.AREABSCR1.read_beam.sigma_x,
self.simulation.AREABSCR1.read_beam.mu_y,
self.simulation.AREABSCR1.read_beam.sigma_y,
out_beam.mu_x,
out_beam.sigma_x,
out_beam.mu_y,
out_beam.sigma_y,
]
)

@@ -876,7 +897,7 @@ def get_misalignments(self):
def get_screen_image(self):
# Beam image to look like real image by dividing by goodlooking number and
# scaling to 12 bits
return self.simulation.AREABSCR1.reading / 1e9 * 2**12
return (self.simulation.AREABSCR1.reading).numpy() / 1e9 * 2**12

def get_binning(self):
return np.array(self.simulation.AREABSCR1.binning)
33 changes: 17 additions & 16 deletions utils/utils.py
@@ -1,9 +1,9 @@
import pickle
from datetime import datetime

import gym
import gymnasium as gym
import numpy as np
from gym import spaces
from gymnasium import spaces


class FilterAction(gym.ActionWrapper):
@@ -47,15 +47,15 @@ def __init__(self, env, path):
with open(path, "rb") as file_handler:
self.vec_normalize = pickle.load(file_handler)

def reset(self):
observation = self.env.reset()
return self.vec_normalize.normalize_obs(observation)
def reset(self, **kwargs):
observation, info = self.env.reset(**kwargs)
return self.vec_normalize.normalize_obs(observation), info

def step(self, action):
observation, reward, done, info = self.env.step(action)
observation, reward, done, truncated, info = self.env.step(action)
observation = self.vec_normalize.normalize_obs(observation)
reward = self.vec_normalize.normalize_reward(reward)
return observation, reward, done, info
return observation, reward, done, truncated, info


class PolishedDonkeyCompatibility(gym.Wrapper):
@@ -104,12 +104,13 @@ def __init__(self, env):
high=np.array([30, 30, 30, 3e-3, 6e-3], dtype=np.float32) * 0.1,
)

def reset(self):
return self.observation(super().reset())
def reset(self, **kwargs):
observation, info = super().reset(**kwargs)
return self.observation(observation), info

def step(self, action):
observation, reward, done, info = super().step(self.action(action))
return self.observation(observation), reward, done, info
observation, reward, done, truncated, info = super().step(self.action(action))
return self.observation(observation), reward, done, truncated, info

def observation(self, observation):
return np.array(
@@ -152,7 +153,7 @@ def __init__(self, env):

self.has_previously_run = False

def reset(self):
def reset(self, **kwargs):
if self.has_previously_run:
self.previous_observations = self.observations
self.previous_rewards = self.rewards
@@ -162,7 +163,7 @@ def reset(self):
self.previous_t_end = datetime.now()
self.previous_steps_taken = self.steps_taken

observation = self.env.reset()
observation, info = self.env.reset(**kwargs)

self.observations = [observation]
self.rewards = []
@@ -174,15 +175,15 @@ def reset(self):

self.has_previously_run = True

return observation
return observation, info

def step(self, action):
observation, reward, done, info = self.env.step(action)
observation, reward, done, truncated, info = self.env.step(action)

self.observations.append(observation)
self.rewards.append(reward)
self.infos.append(info)
self.actions.append(action)
self.steps_taken += 1

return observation, reward, done, info
return observation, reward, done, truncated, info
