Introduce BaseScheduler abstraction #52
self.scheduler.random_state = self.random_state

# "burn" seeds from the calibrator seed generator for backward compatibility
for _ in self.scheduler.samplers:
    self._get_random_seed()
This is really just to show how the proposed change could, in theory, preserve the same behaviour.
Clearly, "burning" random seeds is a slightly odd thing to do.
Moreover, `self.scheduler.random_state = self.random_state` should be changed to `self.scheduler.random_state = self._get_random_seed()`, which necessarily departs from the old (exactly reproducible) behaviour.
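The reproducibility trade-off described in this comment can be sketched in a few lines. This is only an illustration: it uses the standard library's `random.Random` as a stand-in for the calibrator's seed generator, `SeedSource` and `get_random_seed` are hypothetical names rather than the package's API, and the "old behaviour" (one seed drawn per sampler) is an assumption inferred from the diff above.

```python
import random


class SeedSource:
    """Hypothetical stand-in for an object that hands out random seeds."""

    def __init__(self, random_state: int) -> None:
        self._rng = random.Random(random_state)

    def get_random_seed(self) -> int:
        # Each call advances the underlying generator by one draw.
        return self._rng.randrange(2**32)


n_samplers = 3

# Assumed old behaviour: one seed is drawn per sampler, then more later on.
old = SeedSource(42)
old_sampler_seeds = [old.get_random_seed() for _ in range(n_samplers)]
old_next_seed = old.get_random_seed()

# "Burning" seeds: the draws are discarded, but the generator ends up in
# the same position, so later seeds match the old run exactly.
burned = SeedSource(42)
for _ in range(n_samplers):
    burned.get_random_seed()
burned_next_seed = burned.get_random_seed()

# Seeding the scheduler from the same stream consumes one extra draw,
# so every later seed is shifted with respect to the old run.
shifted = SeedSource(42)
scheduler_seed = shifted.get_random_seed()
for _ in range(n_samplers):
    shifted.get_random_seed()
shifted_next_seed = shifted.get_random_seed()
```

In this sketch `burned_next_seed` equals `old_next_seed`, while `shifted_next_seed` comes from a later position in the stream, which is why the second option cannot reproduce old runs exactly.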
Codecov Report
@@            Coverage Diff             @@
##             main      #52      +/-   ##
==========================================
+ Coverage   96.86%   96.92%   +0.06%
==========================================
  Files          31       34       +3
  Lines        1499     1563      +64
==========================================
+ Hits         1452     1515      +63
- Misses         47       48       +1
# overwrite the list of samplers
self.samplers = samplers
self.scheduler._samplers = tuple(samplers)  # pylint: disable=protected-access
This should be illegal; we inherit it from the not-so-clean design of including `samplers_id_table` in the calibrator. Anyway, that is a topic for another PR: let's get this feature in first while keeping the rest working as usual.
@@ -476,7 +500,7 @@ def create_checkpoint(self, file_name: Union[str, os.PathLike]) -> None:
             self.random_state,
             self.random_generator.bit_generator.state,
             model_name,
-            self.samplers,
+            self.scheduler,
Breaking change: we now pickle `scheduler` rather than `samplers`.
@@ -107,7 +107,7 @@ def save_calibrator_state(  # pylint: disable=too-many-arguments,too-many-locals
     initial_random_seed: Optional[int],
     random_generator_state: Mapping,
     model_name: str,
-    samplers: Sequence[BaseSampler],
+    scheduler: BaseScheduler,
breaking change on checkpointing
TODO: update the SQLite checkpointing as well.
@@ -174,7 +174,7 @@ def save_calibrator_state(  # pylint: disable=too-many-arguments,too-many-locals

     # save instantiated samplers and loss functions
     with open(checkpoint_path / "samplers_pickled.pickle", "wb") as fb:
-        pickle.dump(samplers, fb)
+        pickle.dump(scheduler, fb)
Should `samplers_pickle` be renamed? Maybe not...
probably yes actually, I'll take care of that
tests/test_calibrator.py
@@ -222,7 +231,7 @@ def test_calibrator_restore_from_checkpoint_and_set_sampler() -> None:
                 type(vars_cal["param_grid"]).__name__
                 == type(cal_restored.param_grid).__name__  # noqa
             )
-        elif key == "_random_generator":
+        elif key == f"_{BaseSeedable.__name__}__random_generator":
name mangling... duh!
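For readers unfamiliar with the mechanism behind the changed test key: Python rewrites any attribute written as `__name` inside a class body to `_ClassName__name`. The sketch below uses hypothetical, stripped-down `BaseSeedable` and `Calibrator` classes (not the package's real ones) to show why the attribute shows up in `vars()` under the mangled key.

```python
import random


class BaseSeedable:
    """Hypothetical minimal base class with a double-underscore attribute."""

    def __init__(self, random_state: int) -> None:
        # Inside this class body, `__random_generator` is stored as
        # `_BaseSeedable__random_generator`.
        self.__random_generator = random.Random(random_state)


class Calibrator(BaseSeedable):
    """Hypothetical subclass; it inherits the mangled attribute unchanged."""


cal = Calibrator(random_state=0)

# The key a test has to look for, built the same way as in the diff above.
mangled_key = f"_{BaseSeedable.__name__}__random_generator"
attribute_names = sorted(vars(cal))
```

The mangled name is derived from the class in which the attribute is defined, not from the subclass, so the key contains `BaseSeedable` even though the instance is a `Calibrator`.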
f64738f to adf42bb
…ility In terms of functionality, we only had to change the number of batches, since now each batch only runs one sampler. This commit will probably be amended and/or split into smaller commits.
    new_simulated_data: NDArray[np.float64],
) -> None:
    """Update the state of the scheduler after each batch."""
    self._batch_id += 1
@marcofavoritobi what about making this update function the default (non-abstract) method for all schedulers? What's your opinion on this?
That's a good point. Actually, it would be a first step toward a "shared (read-only) calibration state object" that is accessible both from each sampler and from the scheduler.
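A minimal sketch of what this suggestion could look like, assuming a simplified interface: the base class provides a concrete `update` that only advances the batch counter, and subclasses implement just the selection rule. The method name `get_next_sampler`, the reduced `update` signature, and the use of plain strings in place of sampler objects are assumptions made for illustration, not the API introduced by this PR.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence, Tuple


class BaseScheduler(ABC):
    """Hypothetical scheduler base class with a default update step."""

    def __init__(self, samplers: Sequence[str]) -> None:
        self._samplers: Tuple[str, ...] = tuple(samplers)
        self._batch_id = 0

    @property
    def samplers(self) -> Tuple[str, ...]:
        # Read-only view, so callers do not need protected access.
        return self._samplers

    @abstractmethod
    def get_next_sampler(self) -> str:
        """Return the sampler to run in the current batch."""

    def update(self) -> None:
        """Default (non-abstract) update: advance the batch counter."""
        self._batch_id += 1


class RoundRobinScheduler(BaseScheduler):
    """Cycle through the samplers, one per batch."""

    def get_next_sampler(self) -> str:
        return self._samplers[self._batch_id % len(self._samplers)]


scheduler = RoundRobinScheduler(["sampler_a", "sampler_b", "sampler_c"])
order: List[str] = []
for _ in range(5):
    order.append(scheduler.get_next_sampler())
    scheduler.update()
```

Under these assumptions a subclass that needs extra bookkeeping can still override `update` and call `super().update()`, while simple schedulers such as the round-robin one inherit the default unchanged.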
b5dcde5 to 4b10e30
…' to 'scheduler' in json checkpointing
Great! A very useful feature to extend the flexibility of the package in the adaptive selection of specific samplers!
Proposed changes
Introduce `BaseScheduler` abstraction, with `RoundRobinScheduler` as default scheduler for `Calibrator`.
Fixes
n/a
Types of changes
What types of changes does your code introduce? Put an x in the boxes that apply
Checklist
Put an x in the boxes that apply.
main branch (left side). Also you should start your branch off our main.
Further comments
To be merged after #39