rl_scheduler.py
# Black-box ABM Calibration Kit (Black-it)
# Copyright (C) 2021-2024 Banca d'Italia
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""This module implements the 'RLScheduler' scheduler."""

from __future__ import annotations

import threading
from typing import TYPE_CHECKING, cast

import numpy as np

from black_it.samplers.base import BaseSampler
from black_it.samplers.halton import HaltonSampler
from black_it.schedulers.base import BaseScheduler

if TYPE_CHECKING:
    from collections.abc import Sequence
    from queue import Queue

    from numpy._typing import NDArray

    from black_it.schedulers.rl.agents.base import Agent
    from black_it.schedulers.rl.envs.base import CalibrationEnv


class RLScheduler(BaseScheduler):
    """This class implements an RL-based scheduler.

    It is agnostic with respect to the RL algorithm being used.
    """

def __init__(
self,
samplers: Sequence[BaseSampler],
agent: Agent,
env: CalibrationEnv,
random_state: int | None = None,
) -> None:
"""Initialize the scheduler."""
self._original_samplers = samplers
new_samplers, self._halton_sampler_id = self._add_or_get_bootstrap_sampler(
samplers,
)
self._agent = agent
self._env = env
super().__init__(new_samplers, random_state)
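        # Cross-wire the environment's queues: the env's out-queue delivers the sampler
        # index chosen by the agent, while the env's in-queue receives the current best
        # (parameters, loss) pair produced by the calibration (see 'update' below).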
self._in_queue: Queue = self._env._out_queue # noqa: SLF001
self._out_queue: Queue = self._env._in_queue # noqa: SLF001
        self._best_param: NDArray[np.float64] | None = None
        self._best_loss: float | None = None
self._agent_thread: threading.Thread | None = None
self._stopped: bool = True

    def _set_random_state(self, random_state: int | None) -> None:
"""Set the random state (private use)."""
super()._set_random_state(random_state)
for sampler in self.samplers:
sampler.random_state = self._get_random_seed()
self._agent.random_state = self._get_random_seed()
self._env.reset(seed=self._get_random_seed())

    @classmethod
def _add_or_get_bootstrap_sampler(
cls,
samplers: Sequence[BaseSampler],
) -> tuple[Sequence[BaseSampler], int]:
"""Add or retrieve a sampler for bootstrapping.
Many samplers do require some "bootstrapping" of the calibration process, i.e. a set of parameters
whose loss has been already evaluated, e.g. samplers based on ML surrogates or on evolutionary approaches.
Therefore, this scheduler must guarantee that the first proposed sampler is one that does not need previous
model evaluations in input. One of such samplers is the Halton sampler
Therefore, this function checks that the HaltonSampler is present in the set of samplers. If so, it returns
the same set of samplers, and the index corresponding to that sampler in the sequence. Otherwise, a new
instance of HaltonSampler is added to the list as first element.
Args:
samplers: the list of available samplers
Returns:
The pair (new_samplers, halton_sampler_id).
"""
sampler_types = {type(s): i for i, s in enumerate(samplers)}
if HaltonSampler in sampler_types:
return samplers, sampler_types[HaltonSampler]
new_sampler = HaltonSampler(batch_size=1)
return tuple(list(samplers) + cast(list[BaseSampler], [new_sampler])), len(
samplers,
)

    def _train(self) -> None:
"""Run the training loop."""
state = self._env.reset()
while not self._stopped:
# Get the action chosen by the agent
action = self._agent.policy(state)
# Interact with the environment
next_state, reward, _, _, _ = self._env.step(action)
# Learn from interaction
self._agent.learn(state, action, reward, next_state)
state = next_state

    def start_session(self) -> None:
"""Set up the scheduler for a new session."""
if not self._stopped:
msg = "cannot start session: the session has already started"
raise ValueError(msg)
self._stopped = False
self._agent_thread = threading.Thread(target=self._train)
self._agent_thread.start()

    def get_next_sampler(self) -> BaseSampler:
"""Get the next sampler."""
if self._best_loss is None:
# first call, return halton sampler
return self.samplers[self._halton_sampler_id]
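        # Block until the agent thread (via the environment) publishes the index of the next sampler to run.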
chosen_sampler_id = self._in_queue.get()
return self.samplers[chosen_sampler_id]

    def update(
self,
batch_id: int, # noqa: ARG002
new_params: NDArray[np.float64],
new_losses: NDArray[np.float64],
new_simulated_data: NDArray[np.float64], # noqa: ARG002
) -> None:
"""Update the RL scheduler."""
best_new_loss = float(np.min(new_losses))
if self._best_loss is None:
self._best_loss = best_new_loss
self._best_param = new_params[np.argmin(new_losses)]
self._env._curr_best_loss = best_new_loss # noqa: SLF001
return
if best_new_loss < cast(float, self._best_loss):
self._best_loss = best_new_loss
self._best_param = new_params[np.argmin(new_losses)]
self._out_queue.put((self._best_param, self._best_loss))

    def end_session(self) -> None:
        """Tear down the scheduler at the end of the session."""
        if self._stopped:
            msg = "cannot end session: the session has not started yet"
            raise ValueError(msg)
self._stopped = True
self._out_queue.put(None)
cast(threading.Thread, self._agent_thread).join()
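

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed): how an RLScheduler might be
# wired together. The concrete Agent and CalibrationEnv subclasses are
# hypothetical placeholders; only the interface used by this module
# (policy/learn on the agent, reset/step on the environment) is assumed.
#
#     from black_it.samplers.best_batch import BestBatchSampler
#     from black_it.samplers.random_forest import RandomForestSampler
#
#     agent = MyAgent(...)                  # some subclass of Agent (hypothetical)
#     env = MyCalibrationEnv(...)           # some subclass of CalibrationEnv (hypothetical)
#     scheduler = RLScheduler(
#         samplers=[RandomForestSampler(batch_size=4), BestBatchSampler(batch_size=4)],
#         agent=agent,
#         env=env,
#         random_state=42,
#     )
#
#     # A calibration loop (normally driven by black-it's Calibrator) would then call:
#     scheduler.start_session()
#     sampler = scheduler.get_next_sampler()   # the Halton sampler on the first call
#     # ... sample parameters, evaluate losses, then report them back:
#     # scheduler.update(batch_id, new_params, new_losses, new_simulated_data)
#     scheduler.end_session()
# ---------------------------------------------------------------------------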