Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for discrete DOFs #64

Merged
merged 4 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/agent.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ The blop ``Agent`` takes care of the entire optimization loop, from data acquisi
from blop import DOF, Objective, Agent

dofs = [
DOF(name="x1", description="the first DOF", search_bounds=(-10, 10))
DOF(name="x2", description="another DOF", search_bounds=(-5, 5))
DOF(name="x3", description="ayet nother DOF", search_bounds=(0, 1))
DOF(name="x1", description="the first DOF", search_domain=(-10, 10))
DOF(name="x2", description="another DOF", search_domain=(-5, 5))
DOF(name="x3", description="yet nother DOF", search_domain=(0, 1))
mrakitin marked this conversation as resolved.
Show resolved Hide resolved
]

objective = [
Expand Down
31 changes: 27 additions & 4 deletions docs/source/dofs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ A degree of freedom is a variable that affects our optimization objective. We ca

from blop import DOF

dof = DOF(name="x1", description="my first DOF", search_bounds=(lower, upper))
dof = DOF(name="x1", description="my first DOF", search_domain=(lower, upper))

This will instantiate a bunch of stuff under the hood, so that our agent knows how to move things and where to search.
Typically, this will correspond to a real, physical device available in Python. In that case, we can pass the DOF an ophyd device in place of a name
Expand All @@ -16,7 +16,7 @@ Typically, this will correspond to a real, physical device available in Python.

from blop import DOF

dof = DOF(device=my_ophyd_device, description="a real piece of hardware", search_bounds=(lower, upper))
dof = DOF(device=my_ophyd_device, description="a real piece of hardware", search_domain=(lower, upper))

In this case, the agent will control the device as it sees fit, moving it between the search bounds.

Expand All @@ -27,7 +27,30 @@ In this case, we can define a read-only DOF as

from blop import DOF

dof = DOF(device=a_read_only_ophyd_device, description="a thermometer or something", read_only=True, trust_bounds=(lower, upper))
dof = DOF(device=a_read_only_ophyd_device, description="a thermometer or something", read_only=True, trust_domain=(lower, upper))

and the agent will use the received values to model its objective, but won't try to move it.
We can also pass a set of ``trust_bounds``, so that our agent will ignore experiments where the DOF value jumps outside of the interval.
We can also pass a set of ``trust_domain``, so that our agent will ignore experiments where the DOF value jumps outside of the interval.


Discrete degrees of freedom
---------------------------

In addition to degrees of freedom that vary continuously between a lower and upper bound, we can define discrete degrees of freedom.
One kind is a binary degree of freedom, where the input can take one of two values, e.g.

.. code-block:: python

discrete_dof = DOF(name="x1", description="A discrete DOF", type="discrete", search_domain={"in", "out"})

Another is an ordinal degree of freedom, which takes more than two discrete values but has some ordering, e.g.

.. code-block:: python

ordinal_dof = DOF(name="x1", description="An ordinal DOF", type="ordinal", search_domain={"low", "medium", "high"})

The last is a categorical degree of freedom, which can take many different discrete values with no ordering, e.g.

.. code-block:: python

categorical_dof = DOF(name="x1", description="A categorical DOF", type="categorical", search_domain={"banana", "mango", "papaya"})
4 changes: 2 additions & 2 deletions docs/source/tutorials/himmelblau.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@
"from blop import DOF\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x2\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x1\", search_domain=(-6, 6)),\n",
" DOF(name=\"x2\", search_domain=(-6, 6)),\n",
"]"
]
},
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/hyperparameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@
"from blop import DOF, Objective, Agent\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x2\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x1\", search_domain=(-6, 6)),\n",
" DOF(name=\"x2\", search_domain=(-6, 6)),\n",
"]\n",
"\n",
"objectives = [\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/passive-dofs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
"\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", search_bounds=(-5.0, 5.0)),\n",
" DOF(name=\"x2\", search_bounds=(-5.0, 5.0)),\n",
" DOF(name=\"x3\", search_bounds=(-5.0, 5.0), active=False),\n",
" DOF(name=\"x1\", search_domain=(-5.0, 5.0)),\n",
" DOF(name=\"x2\", search_domain=(-5.0, 5.0)),\n",
" DOF(name=\"x3\", search_domain=(-5.0, 5.0), active=False),\n",
" DOF(device=BrownianMotion(name=\"brownian1\"), read_only=True),\n",
" DOF(device=BrownianMotion(name=\"brownian2\"), read_only=True, active=False),\n",
"]\n",
Expand Down
194 changes: 194 additions & 0 deletions scripts/gui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import asyncio

import databroker
import matplotlib as mpl
import numpy as np
from bluesky.callbacks import best_effort
from bluesky.run_engine import RunEngine
from databroker import Broker
from nicegui import ui
mrakitin marked this conversation as resolved.
Show resolved Hide resolved

from blop import DOF, Agent, Objective
from blop.utils import functions

# MongoDB backend:
db = Broker.named("temp") # mongodb backend
try:
databroker.assets.utils.install_sentinels(db.reg.config, version=1)
except Exception:
pass

loop = asyncio.new_event_loop()
loop.set_debug(True)
RE = RunEngine({}, loop=loop)
RE.subscribe(db.insert)

bec = best_effort.BestEffortCallback()
RE.subscribe(bec)

bec.disable_baseline()
bec.disable_heading()
bec.disable_table()
bec.disable_plots()


dofs = [
DOF(name="x1", description="x1", search_domain=(-5.0, 5.0)),
DOF(name="x2", description="x2", search_domain=(-5.0, 5.0)),
]

objectives = [Objective(name="himmelblau", target="min")]

agent = Agent(
dofs=dofs,
objectives=objectives,
digestion=functions.himmelblau_digestion,
db=db,
verbose=True,
tolerate_acquisition_errors=False,
)

agent.acqf_index = 0

agent.acqf_number = 2


with ui.pyplot(figsize=(10, 4), dpi=160) as obj_plt:
extent = [*agent.dofs[0].search_domain, *agent.dofs[1].search_domain]

ax1 = obj_plt.fig.add_subplot(131)
ax1.set_title("Samples")
im1 = ax1.scatter([], [], cmap="magma")

ax2 = obj_plt.fig.add_subplot(132, sharex=ax1, sharey=ax1)
ax2.set_title("Posterior mean")
im2 = ax2.imshow(np.random.standard_normal(size=(32, 32)), extent=extent, cmap="magma")

ax3 = obj_plt.fig.add_subplot(133, sharex=ax1, sharey=ax1)
ax3.set_title("Posterior error")
im3 = ax3.imshow(np.random.standard_normal(size=(32, 32)), extent=extent, cmap="magma")

data_cbar = obj_plt.fig.colorbar(mappable=im1, ax=[ax1, ax2], location="bottom", aspect=32)
err_cbar = obj_plt.fig.colorbar(mappable=im3, ax=[ax3], location="bottom", aspect=16)

for ax in [ax1, ax2, ax3]:
ax.set_xlabel(agent.dofs[0].label)
ax.set_ylabel(agent.dofs[1].label)


acqf_configs = {
0: {"name": "qr", "long_name": r"quasi-random sampling"},
1: {"name": "qei", "long_name": r"$q$-expected improvement"},
2: {"name": "qpi", "long_name": r"$q$-probability of improvement"},
3: {"name": "qucb", "long_name": r"$q$-upper confidence bound"},
}

with ui.pyplot(figsize=(10, 3), dpi=160) as acq_plt:
extent = [*agent.dofs[0].search_domain, *agent.dofs[1].search_domain]

acqf_plt_objs = {}

for iax, config in acqf_configs.items():
if iax == 0:
continue

acqf = config["name"]

acqf_plt_objs[acqf] = {}

acqf_plt_objs[acqf]["ax"] = ax = acq_plt.fig.add_subplot(1, len(acqf_configs) - 1, iax)

ax.set_title(config["long_name"])
acqf_plt_objs[acqf]["im"] = ax.imshow([[]], extent=extent, cmap="gray_r")
acqf_plt_objs[acqf]["hist"] = ax.scatter([], [])
acqf_plt_objs[acqf]["best"] = ax.scatter([], [])

ax.set_xlabel(agent.dofs[0].label)
ax.set_ylabel(agent.dofs[1].label)


acqf_button_options = {index: config["name"] for index, config in acqf_configs.items()}

v = ui.checkbox("visible", value=True)
with ui.column().bind_visibility_from(v, "value"):
ui.toggle(acqf_button_options).bind_value(agent, "acqf_index")
ui.number().bind_value(agent, "acqf_number")


def reset():
agent.reset()

print(agent.table)


def learn():
acqf_config = acqf_configs[agent.acqf_index]

acqf = acqf_config["name"]

n = int(agent.acqf_number) if acqf != "qr" else 16

ui.notify(f"sampling {n} points with acquisition function \"{acqf_config['long_name']}\"")

RE(agent.learn(acqf, n=n))

with obj_plt:
obj = agent.objectives[0]

x_samples = agent.train_inputs().detach().numpy()
y_samples = agent.train_targets(obj.name).detach().numpy()[..., 0]

x = agent.sample(method="grid", n=20000) # (n, n, 1, d)
p = obj.model.posterior(x)

m = p.mean.squeeze(-1, -2).detach().numpy()
e = p.variance.sqrt().squeeze(-1, -2).detach().numpy()

im1.set_offsets(x_samples)
im1.set_array(y_samples)
im1.set_cmap("magma")

im2.set_data(m.T[::-1])
im3.set_data(e.T[::-1])

obj_norm = mpl.colors.Normalize(vmin=np.nanmin(y_samples), vmax=np.nanmax(y_samples))
err_norm = mpl.colors.LogNorm(vmin=np.nanmin(e), vmax=np.nanmax(e))

im1.set_norm(obj_norm)
im2.set_norm(obj_norm)
im3.set_norm(err_norm)

for ax in [ax1, ax2, ax3]:
ax.set_xlim(*agent.dofs[0].search_domain)
ax.set_ylim(*agent.dofs[1].search_domain)

with acq_plt:
x = agent.sample(method="grid", n=20000) # (n, n, 1, d)
x_samples = agent.train_inputs().detach().numpy()

for acqf in acqf_plt_objs.keys():
ax = acqf_plt_objs[acqf]["ax"]

acqf_obj = getattr(agent, acqf)(x).detach().numpy()

acqf_norm = mpl.colors.Normalize(vmin=np.nanmin(acqf_obj), vmax=np.nanmax(acqf_obj))
acqf_plt_objs[acqf]["im"].set_data(acqf_obj.T[::-1])
acqf_plt_objs[acqf]["im"].set_norm(acqf_norm)

res = agent.ask(acqf, n=int(agent.acqf_number))

acqf_plt_objs[acqf]["hist"].remove()
acqf_plt_objs[acqf]["hist"] = ax.scatter(*x_samples.T, ec="b", fc="none", marker="o")

acqf_plt_objs[acqf]["best"].remove()
acqf_plt_objs[acqf]["best"] = ax.scatter(*res["points"].T, c="r", marker="x", s=64)

ax.set_xlim(*agent.dofs[0].search_domain)
ax.set_ylim(*agent.dofs[1].search_domain)


ui.button("Learn", on_click=learn)

ui.button("Reset", on_click=reset)

ui.run(port=8004)
4 changes: 2 additions & 2 deletions src/blop/_version.py
mrakitin marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = "0.6.2.dev0"
__version_tuple__ = version_tuple = (0, 6, 2, "dev0")
__version__ = version = "0.5.1.dev48"
__version_tuple__ = version_tuple = (0, 5, 1, "dev48")
Loading
Loading