Skip to content

Commit

Permalink
Merge branch 'develop' into jax_sciml
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich authored Dec 6, 2024
2 parents cfb0b5a + 449041d commit 2b698fb
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 77 deletions.
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ filterwarnings =
ignore:Conservation laws for non-constant species in models with Species-AssignmentRules are currently not supported and will be turned off.:UserWarning
ignore:Conservation laws for non-constant species in combination with parameterized stoichiometric coefficients are not currently supported and will be turned off.:UserWarning
ignore:Support for PEtab2.0 is experimental!:UserWarning
ignore:The JAX module is experimental and the API may change in the future.:ImportWarning
# hundreds of SBML <=5.17 warnings
ignore:.*inspect.getargspec\(\) is deprecated.*:DeprecationWarning
# pysb warnings
Expand Down
11 changes: 4 additions & 7 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
" results (dict): Simulation results from run_simulations.\n",
" \"\"\"\n",
" # Extract the simulation results for the specific condition\n",
" sim_results = results[simulation_condition][1]\n",
" sim_results = results[simulation_condition]\n",
"\n",
" # Create a new figure for the state trajectories\n",
" plt.figure(figsize=(8, 6))\n",
Expand Down Expand Up @@ -357,27 +357,25 @@
"simulation_condition = (\"model1_data1\",)\n",
"\n",
"# Load condition-specific data\n",
"ts_preeq, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
"ts_init, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
" simulation_condition\n",
"]\n",
"\n",
"# Load parameters for the specified condition\n",
"p = jax_problem.load_parameters(simulation_condition[0])\n",
"# Disable preequilibration\n",
"p_preeq = jnp.array([])\n",
"\n",
"\n",
"# Define a function to compute the gradient with respect to dynamic timepoints\n",
"@eqx.filter_jacfwd\n",
"def grad_ts_dyn(tt):\n",
" return jax_problem.model.simulate_condition(\n",
" p=p,\n",
" p_preeq=p_preeq,\n",
" ts_preeq=ts_preeq,\n",
" ts_init=ts_init,\n",
" ts_dyn=tt,\n",
" ts_posteq=ts_posteq,\n",
" my=jnp.array(my),\n",
" iys=jnp.array(iys),\n",
" x_preeq=jnp.array([]),\n",
" solver=diffrax.Kvaerno5(),\n",
" controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
" max_steps=2**10,\n",
Expand Down Expand Up @@ -489,7 +487,6 @@
"amici_model = import_petab_problem(\n",
" petab_problem,\n",
" verbose=False,\n",
" compile_=True,\n",
" jax=False, # load the amici model this time\n",
")\n",
"\n",
Expand Down
17 changes: 16 additions & 1 deletion python/sdist/amici/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
"""Interface to facilitate AMICI generated models using JAX"""
"""
JAX
---
This module provides an interface to generate and use AMICI models with JAX. Please note that this module is
experimental, the API may substantially change in the future. Use at your own risk and do not expect backward
compatibility.
"""

from warnings import warn

from amici.jax.petab import JAXProblem, run_simulations
from amici.jax.model import JAXModel
from amici.jax.nn import generate_equinox

warn(
"The JAX module is experimental and the API may change in the future.",
ImportWarning,
stacklevel=2,
)

__all__ = ["JAXModel", "JAXProblem", "run_simulations", "generate_equinox"]

Check warning on line 22 in python/sdist/amici/jax/__init__.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/__init__.py#L22

Added line #L22 was not covered by tests
64 changes: 42 additions & 22 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,12 @@ def _sigmays(
def simulate_condition(
self,
p: jt.Float[jt.Array, "np"],
p_preeq: jt.Float[jt.Array, "*np"],
ts_preeq: jt.Float[jt.Array, "nt_preeq"],
ts_init: jt.Float[jt.Array, "nt_preeq"],
ts_dyn: jt.Float[jt.Array, "nt_dyn"],
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
x_preeq: jt.Float[jt.Array, "nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
Expand All @@ -445,12 +445,9 @@ def simulate_condition(
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param p_preeq:
parameters for pre-equilibration ordered according to ids in :ivar parameter_ids:. May be empty to
disable pre-equilibration.
:param ts_preeq:
time points for pre-equilibration. Usually valued 0.0, but needs to be shaped according to
the number of observables that are evaluated after pre-equilibration.
:param ts_init:
time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to
the number of observables that are evaluated before dynamic simulation.
:param ts_dyn:
time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order.
Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time
Expand Down Expand Up @@ -487,24 +484,16 @@ def simulate_condition(
output according to `ret` and statistics
"""
# Pre-equilibration
if p_preeq.shape[0] > 0:
x0 = self._x0(p_preeq)
tcl = self._tcl(x0, p_preeq)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p_preeq, tcl, current_x, solver, controller, max_steps
)
if x_preeq.shape[0] > 0:
current_x = self._x_solver(x_preeq)
# update tcl with new parameters
tcl = self._tcl(self._x_rdata(current_x, tcl), p)
tcl = self._tcl(x_preeq, p)
else:
x0 = self._x0(p)
current_x = self._x_solver(x0)
stats_preeq = None

tcl = self._tcl(x0, p)
x_preq = jnp.repeat(
current_x.reshape(1, -1), ts_preeq.shape[0], axis=0
)
x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0)

# Dynamic simulation
if ts_dyn.shape[0] > 0:
Expand Down Expand Up @@ -537,7 +526,7 @@ def simulate_condition(
current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_preeq, ts_dyn, ts_posteq), axis=0)
ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
Expand All @@ -556,11 +545,42 @@ def simulate_condition(
}[ret], dict(
ts=ts,
x=x,
stats_preeq=stats_preeq,
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)

@eqx.filter_jit
def preequilibrate_condition(
self,
p: jt.Float[jt.Array, "np"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: int | jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]:
r"""
Simulate a condition.
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param solver:
ODE solver
:param controller:
step size controller
:param max_steps:
maximum number of solver steps
:return:
pre-equilibrated state variables and statistics
"""
# Pre-equilibration
x0 = self._x0(p)
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p, tcl, current_x, solver, controller, max_steps
)

return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)


def safe_log(x: jnp.float_) -> jnp.float_:
"""
Expand Down
5 changes: 2 additions & 3 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,10 @@ def _generate_jax_code(self) -> None:
for net in self.hybridisation.keys()
),
}
outdir = self.model_path / (self.model_name + "_jax")
outdir.mkdir(parents=True, exist_ok=True)

apply_template(
Path(amiciModulePath) / "jax" / "jax.template.py",
outdir / "__init__.py",
self.model_path / "__init__.py",
tpl_data,
)

Expand Down Expand Up @@ -280,6 +278,7 @@ def set_paths(self, output_dir: str | Path | None = None) -> None:
output_dir = Path(os.getcwd()) / f"amici-{self.model_name}"

self.model_path = Path(output_dir).resolve()
self.model_path.mkdir(parents=True, exist_ok=True)

def set_name(self, model_name: str) -> None:
"""
Expand Down
68 changes: 55 additions & 13 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _get_parameter_mappings(
def _get_measurements(
self, simulation_conditions: pd.DataFrame
) -> dict[
tuple[str],
tuple[str, ...],
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]:
"""
Expand Down Expand Up @@ -412,20 +412,22 @@ def run_simulation(
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722
ret: str = "llh",
) -> tuple[jnp.float_, dict]:
"""
Run a simulation for a given simulation condition.
:param simulation_condition:
Tuple of simulation conditions to run the simulation for. can be a single string (simulation only) or a
tuple of strings (pre-equilibration followed by simulation).
Simulation condition to run simulation for.
:param solver:
ODE solver to use for simulation
:param controller:
Step size controller to use for simulation
:param max_steps:
Maximum number of steps to take during simulation
:param x_preeq:
Pre-equilibration state if available
:param ret:
which output to return. Valid values are
- `llh`: log-likelihood (default)
Expand All @@ -445,19 +447,14 @@ def run_simulation(
simulation_condition
]
p = self.load_parameters(simulation_condition[0])
p_preeq = (
self.load_parameters(simulation_condition[1])
if len(simulation_condition) > 1
else jnp.array([])
)
return self.model.simulate_condition(
p=p,
p_preeq=p_preeq,
ts_preeq=jax.lax.stop_gradient(jnp.array(ts_preeq)),
ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),
my=jax.lax.stop_gradient(jnp.array(my)),
iys=jax.lax.stop_gradient(jnp.array(iys)),
x_preeq=x_preeq,
solver=solver,
controller=controller,
max_steps=max_steps,
Expand All @@ -467,10 +464,39 @@ def run_simulation(
ret=ret,
)

def run_preequilibration(
self,
simulation_condition: str,
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821
"""
Run a pre-equilibration simulation for a given simulation condition.
:param simulation_condition:
Simulation condition to run simulation for.
:param solver:
ODE solver to use for simulation
:param controller:
Step size controller to use for simulation
:param max_steps:
Maximum number of steps to take during simulation
:return:
Pre-equilibration state
"""
p = self.load_parameters(simulation_condition)
return self.model.preequilibrate_condition(
p=p,
solver=solver,
controller=controller,
max_steps=max_steps,
)


def run_simulations(
problem: JAXProblem,
simulation_conditions: Iterable[tuple] | None = None,
simulation_conditions: Iterable[tuple[str, ...]] | None = None,
solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
rtol=1e-8,
Expand Down Expand Up @@ -513,12 +539,28 @@ def run_simulations(
if simulation_conditions is None:
simulation_conditions = problem.get_all_simulation_conditions()

preeqs = {
sc: problem.run_preequilibration(sc, solver, controller, max_steps)
# only run preequilibration once per condition
for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
}

results = {
sc: problem.run_simulation(sc, solver, controller, max_steps, ret)
sc: problem.run_simulation(
sc,
solver,
controller,
max_steps,
preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
ret,
)
for sc in simulation_conditions
}
if ret == "llh":
output = sum(llh for llh, _ in results.values())

Check warning on line 560 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L559-L560

Added lines #L559 - L560 were not covered by tests
else:
output = {sc: res for sc, (res, _) in results.items()}
return output, results
return output, {

Check warning on line 563 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L562-L563

Added lines #L562 - L563 were not covered by tests
sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1]
for sc, res in results.items()
}
24 changes: 22 additions & 2 deletions python/sdist/amici/petab/import_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,9 @@ def _can_import_model(
Check whether a module of that name can already be imported.
"""
# try to import (in particular checks version)
suffix = "_jax" if jax else ""
try:
model_module = amici.import_model_module(
model_name + suffix, model_output_dir
*_get_package_name_and_path(model_name, model_output_dir, jax)
)
except ModuleNotFoundError:
return False
Expand Down Expand Up @@ -271,3 +270,24 @@ def check_model(
"the current model might also resolve this. Parameters: "
f"{amici_ids_free_required.difference(amici_ids_free)}"
)


def _get_package_name_and_path(
model_name: str, model_output_dir: str | Path, jax: bool = False
) -> tuple[str, Path]:
"""
Get the package name and path for the generated model module.
:param model_name:
Name of the model
:param model_output_dir:
Target directory for the generated model module
:param jax:
Whether to generate the paths for a JAX or CPP model
:return:
"""
if jax:
outdir = Path(model_output_dir)
return outdir.stem, outdir.parent
else:
return model_name, Path(model_output_dir)
Loading

0 comments on commit 2b698fb

Please sign in to comment.