Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid silent preequilibration failure in JAX #2631

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,22 @@ class JAXModel(eqx.Module):
JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements
routines for simulation and evaluation of derived quantities, model specific implementations need to be provided by
classes inheriting from JAXModel.

:ivar api_version:
API version of the derived class. Needs to match the API version of the base class (MODEL_API_VERSION).
:ivar MODEL_API_VERSION:
API version of the base class.
:ivar jax_py_file:
Path to the JAX model file.
:ivar ss_tol_scale_factor:
Tolerance scale factor for the steady state termination check. Multiplied with tolerances of the user-provided
step size controller.
"""

MODEL_API_VERSION = "0.0.2"
api_version: str
jax_py_file: Path
ss_tol_scale_factor: jnp.float_ = 10.0

def __init__(self):
if self.api_version != self.MODEL_API_VERSION:
Expand Down Expand Up @@ -278,10 +289,23 @@ def _eq(
stepsize_controller=controller,
max_steps=max_steps,
adjoint=diffrax.DirectAdjoint(),
event=diffrax.Event(cond_fn=diffrax.steady_state_event()),
event=diffrax.Event(
cond_fn=diffrax.steady_state_event(
rtol=controller.rtol * self.ss_tol_scale_factor,
atol=controller.atol * self.ss_tol_scale_factor,
)
),
throw=False,
)
return sol.ys[-1, :], sol.stats
# If the event was triggered, the event mask is True and the solution is the steady state. Otherwise, the
# solution is the last state and the event mask is False. In the latter case, we return inf for the steady
# state.
ys = jnp.where(
sol.event_mask,
sol.ys[-1, :],
jnp.inf * jnp.ones_like(sol.ys[-1, :]),
)
return ys, sol.stats

def _solve(
self,
Expand Down Expand Up @@ -610,6 +634,10 @@ def preequilibrate_condition(

:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param x_reinit:
re-initialized state vector. If not provided, the state vector is not re-initialized.
:param mask_reinit:
mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
:param solver:
ODE solver
:param controller:
Expand Down
25 changes: 24 additions & 1 deletion python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import diffrax
import numpy as np
from beartype import beartype
from petab.v1.C import PREEQUILIBRATION_CONDITION_ID, SIMULATION_CONDITION_ID

from amici.pysb_import import pysb2amici, pysb2jax
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from amici.petab.petab_import import import_petab_problem
from amici.jax import JAXProblem, ReturnValue
from amici.jax import JAXProblem, ReturnValue, run_simulations
from numpy.testing import assert_allclose
from test_petab_objective import lotka_volterra # noqa: F401

Expand Down Expand Up @@ -268,6 +269,28 @@ def check_fields_jax(
)


def test_preequilibration_failure(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
# oscillating system, preequilibation should fail when interaction is active
with TemporaryDirectoryWinSafe(prefix="normal") as model_dir:
jax_model = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
r = run_simulations(jax_problem)
assert not np.isinf(r[0].item())
petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = (
petab_problem.measurement_df[SIMULATION_CONDITION_ID]
)
with TemporaryDirectoryWinSafe(prefix="failure") as model_dir:
jax_model = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
r = run_simulations(jax_problem)
assert np.isinf(r[0].item())


@skip_on_valgrind
def test_serialisation(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
Expand Down
Loading